Add AGENTS.md documentation for AI agent guidance

AGENTS.md (new file, 223 lines)
@@ -0,0 +1,223 @@
# AGENTS.md

This file provides guidance to agents when working with code in this repository.

## Project Overview

**watsonx-openai-proxy** is an OpenAI-compatible API proxy for IBM watsonx.ai. It enables any tool or application that supports the OpenAI API format to work seamlessly with watsonx.ai models.

### Core Purpose

- Provide a drop-in replacement for OpenAI API endpoints
- Translate OpenAI API requests into watsonx.ai API calls
- Handle IBM Cloud authentication and token management automatically
- Support streaming responses via Server-Sent Events (SSE)

### Technology Stack

- **Framework**: FastAPI (async web framework)
- **Language**: Python 3.9+
- **HTTP Client**: httpx (async HTTP client)
- **Validation**: Pydantic v2 (data validation and settings)
- **Server**: uvicorn (ASGI server)

### Architecture

The codebase follows a clean, modular architecture:

```
app/
├── main.py      # FastAPI app initialization, middleware, lifespan management
├── config.py    # Settings management, model mapping, environment variables
├── routers/     # API endpoint handlers (chat, completions, embeddings, models)
├── services/    # Business logic (watsonx_service for API interactions)
├── models/      # Pydantic models for OpenAI-compatible schemas
└── utils/       # Helper functions (request/response transformers)
```

**Key Design Patterns**:

- **Service Layer**: `watsonx_service.py` encapsulates all watsonx.ai API interactions
- **Transformer Pattern**: `transformers.py` handles bidirectional conversion between OpenAI and watsonx formats
- **Singleton Services**: Global service instances (`watsonx_service`, `settings`) for shared state
- **Async/Await**: All I/O operations are asynchronous for better performance
- **Middleware**: Custom authentication middleware for optional API key validation

## Building and Running

### Prerequisites

```bash
# Python 3.9 or higher required
python --version

# IBM Cloud credentials needed:
# - IBM_CLOUD_API_KEY
# - WATSONX_PROJECT_ID
```

### Installation

```bash
# Install dependencies
pip install -r requirements.txt

# Configure environment
cp .env.example .env
# Edit .env with your IBM Cloud credentials
```

### Running the Server

```bash
# Development (with auto-reload)
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000

# Production (with workers)
uvicorn app.main:app --host 0.0.0.0 --port 8000 --workers 4

# Using the Python module
python -m app.main
```

### Docker Deployment

```bash
# Build image
docker build -t watsonx-openai-proxy .

# Run container
docker run -p 8000:8000 --env-file .env watsonx-openai-proxy

# Using docker-compose
docker-compose up
```

### Testing

```bash
# Install test dependencies
pip install pytest pytest-asyncio httpx

# Run tests
pytest tests/

# Run with coverage
pytest tests/ --cov=app
```

## Development Conventions

### Code Style

- **Async First**: Use `async`/`await` for all I/O operations (HTTP requests, file operations)
- **Type Hints**: All functions should have type annotations for parameters and return values
- **Docstrings**: Use Google-style docstrings for functions and classes
- **Logging**: Use the `logging` module with appropriate log levels (info, warning, error)

### Error Handling

- Catch exceptions at the router level and return OpenAI-compatible error responses
- Use `HTTPException` with proper status codes and error details
- Log errors with full context using `logger.error(..., exc_info=True)`
- Return structured error responses matching OpenAI's error format
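
The error envelope these bullets describe can be sketched as a small helper. The function name `openai_error` is illustrative; the repository's actual handlers may be structured differently:

```python
from typing import Optional


def openai_error(message: str,
                 err_type: str = "invalid_request_error",
                 code: Optional[str] = None) -> dict:
    """Build a response body matching OpenAI's error envelope.

    Hypothetical helper for illustration; not code from this repository.
    """
    return {"error": {"message": message, "type": err_type, "code": code}}
```

In a router this would typically be returned as `JSONResponse(status_code=..., content=openai_error(...))`.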

### Configuration Management

- All configuration via environment variables (`.env` file)
- Use `pydantic-settings` for type-safe configuration
- Model mapping via `MODEL_MAP_*` environment variables
- Settings accessed through the global `settings` instance

### Token Management

- Bearer tokens automatically refreshed every 50 minutes (expire at 60 minutes)
- Token refresh on 401 errors from watsonx.ai
- Thread-safe token refresh using `asyncio.Lock`
- Initial token obtained during application startup
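
A minimal sketch of the lock-guarded refresh pattern described above. `TokenManager` and `_fetch_token` are illustrative stand-ins, not this repository's actual class; the real IAM call is replaced with a placeholder:

```python
import asyncio
import time


class TokenManager:
    """Illustrative sketch of lock-guarded bearer-token refresh.

    The asyncio.Lock ensures only one coroutine refreshes the token at a
    time; concurrent callers wait and then reuse the cached token.
    """

    REFRESH_INTERVAL = 50 * 60  # refresh every 50 minutes; tokens expire at 60

    def __init__(self):
        self._token = None
        self._obtained_at = 0.0
        self._lock = asyncio.Lock()

    async def _fetch_token(self) -> str:
        # Placeholder for the real IAM exchange (POST the IBM Cloud API key
        # to the IAM token endpoint and read the bearer token back).
        return "example-bearer-token"

    async def get_token(self) -> str:
        async with self._lock:
            stale = time.monotonic() - self._obtained_at > self.REFRESH_INTERVAL
            if self._token is None or stale:
                self._token = await self._fetch_token()
                self._obtained_at = time.monotonic()
            return self._token
```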

### API Compatibility

- Maintain strict OpenAI API compatibility in request/response formats
- Use Pydantic models from `openai_models.py` for validation
- Transform requests/responses using functions in `transformers.py`
- Support both streaming and non-streaming responses

### Adding New Endpoints

1. Create a router in `app/routers/` (e.g., `new_endpoint.py`)
2. Define Pydantic models in `app/models/openai_models.py`
3. Add transformation logic in `app/utils/transformers.py`
4. Add a watsonx.ai API method in `app/services/watsonx_service.py`
5. Register the router in `app/main.py` using `app.include_router()`

### Streaming Responses

- Use `StreamingResponse` with `media_type="text/event-stream"`
- Format chunks as Server-Sent Events using `format_sse_event()`
- Always send a `[DONE]` message at the end of the stream
- Handle errors gracefully and send error events in SSE format
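
The SSE wire format can be sketched as follows. This is a simplified stand-in for the repository's `format_sse_event()`, whose actual signature may differ:

```python
import json


def format_sse_event(data: dict) -> str:
    """Serialize one chunk as a Server-Sent Event: a data line plus a blank line."""
    return f"data: {json.dumps(data)}\n\n"


def sse_done() -> str:
    """Terminal sentinel that every OpenAI-compatible stream must end with."""
    return "data: [DONE]\n\n"
```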

### Model Mapping

- Map OpenAI model names to watsonx models via environment variables
- Format: `MODEL_MAP_<OPENAI_MODEL>=<WATSONX_MODEL_ID>`
- Example: `MODEL_MAP_GPT4=ibm/granite-4-h-small`
- Mapping applied in `settings.map_model()` before API calls
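
The naming convention can be sketched like this, mirroring the normalization done in `app/config.py` (the standalone function name is illustrative):

```python
import re


def env_var_to_model_name(env_key: str) -> str:
    """Derive the OpenAI-style model name from a MODEL_MAP_* variable name."""
    name = env_key.replace("MODEL_MAP_", "")
    # Insert a hyphen between an uppercase run and trailing digits: GPT4 -> GPT-4
    name = re.sub(r"([A-Z]+)(\d+)", r"\1-\2", name)
    return name.lower().replace("_", "-")
```

So `MODEL_MAP_GPT4=...` is looked up when a client requests model `gpt-4`.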

### Security Considerations

- Optional API key authentication via the `API_KEY` environment variable
- Middleware validates the Bearer token in the Authorization header
- The IBM Cloud API key is stored in environment variables, never in code
- CORS configured via `ALLOWED_ORIGINS` (default: `*`)

### Logging Best Practices

- Use structured logging with context (model names, request IDs)
- Log level controlled by the `LOG_LEVEL` environment variable
- Log token refresh events at INFO level
- Log API errors at ERROR level with full traceback
- Include request/response details for debugging

### Dependencies

- Keep `requirements.txt` minimal and pinned to specific versions
- FastAPI and Pydantic are core dependencies - avoid breaking changes
- httpx for async HTTP - prefer it over requests/aiohttp
- Use `uvicorn[standard]` for a production-ready server

## Important Implementation Notes

### watsonx.ai API Specifics

- Base URL format: `https://{cluster}.ml.cloud.ibm.com/ml/v1`
- API version parameter: `version=2024-02-13` (required on all requests)
- Chat endpoint: `/text/chat` (non-streaming) or `/text/chat_stream` (streaming)
- Text generation: `/text/generation`
- Embeddings: `/text/embeddings`

### Request/Response Transformation

- OpenAI messages → watsonx messages: direct mapping with role/content
- watsonx responses → OpenAI format: extract choices, usage, and metadata
- Streaming chunks: parse SSE format, transform delta objects
- Generate unique IDs: `chatcmpl-{uuid}` for chat, `cmpl-{uuid}` for completions
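
The response-side transformation can be sketched as below. Field names follow the OpenAI chat completion schema; the repository's `transformers.py` is the authoritative version:

```python
import time
import uuid


def to_openai_chat_response(model: str, text: str,
                            finish_reason: str = "stop") -> dict:
    """Wrap generated text in an OpenAI-compatible chat completion envelope.

    Illustrative sketch; usage accounting and extra metadata are omitted.
    """
    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": finish_reason,
        }],
    }
```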

### Common Pitfalls

- Don't forget to refresh tokens before they expire (50-minute interval)
- Always close the httpx client on shutdown (`await watsonx_service.close()`)
- Handle both string and list formats for the `stop` parameter
- Validate that model IDs exist in watsonx.ai before making requests
- Set appropriate timeouts for long-running generation requests (300s default)
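
For example, normalizing the `stop` parameter might look like this hedged sketch (not the repository's exact code):

```python
from typing import List, Optional, Union


def normalize_stop(stop: Optional[Union[str, List[str]]]) -> List[str]:
    """Accept None, a single string, or a list of strings; always return a list.

    OpenAI clients may send `stop` in any of these shapes, so normalize
    before forwarding to watsonx.ai.
    """
    if stop is None:
        return []
    if isinstance(stop, str):
        return [stop]
    return list(stop)
```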

### Performance Optimization

- Reuse the httpx client instance (don't create one per request)
- Use connection pooling (httpx's default behavior)
- Consider worker processes for production (`--workers 4`)
- Monitor token refresh to avoid rate limiting

## Environment Variables Reference

### Required

- `IBM_CLOUD_API_KEY`: IBM Cloud API key for authentication
- `WATSONX_PROJECT_ID`: watsonx.ai project ID

### Optional

- `WATSONX_CLUSTER`: Region (default: `us-south`)
- `HOST`: Server host (default: `0.0.0.0`)
- `PORT`: Server port (default: `8000`)
- `LOG_LEVEL`: Logging level (default: `info`)
- `API_KEY`: Optional proxy authentication key
- `ALLOWED_ORIGINS`: CORS origins (default: `*`)
- `MODEL_MAP_*`: Model name mappings

## API Endpoints

- `GET /` - API information and available endpoints
- `GET /health` - Health check (bypasses authentication)
- `GET /docs` - Interactive Swagger UI documentation
- `POST /v1/chat/completions` - Chat completions (streaming supported)
- `POST /v1/completions` - Text completions (legacy)
- `POST /v1/embeddings` - Generate embeddings
- `GET /v1/models` - List available models
- `GET /v1/models/{model_id}` - Get specific model info

Dockerfile (new file, 20 lines)
@@ -0,0 +1,20 @@
FROM python:3.11-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY app ./app

# Expose port
EXPOSE 8000

# Health check (raise_for_status ensures a non-2xx response fails the check)
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import httpx; httpx.get('http://localhost:8000/health').raise_for_status()"

# Run the application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

README.md (new file, 353 lines)
@@ -0,0 +1,353 @@
# watsonx-openai-proxy

OpenAI-compatible API proxy for IBM watsonx.ai. This proxy allows you to use watsonx.ai models with any tool or application that supports the OpenAI API format.

## Features

- ✅ **Full OpenAI API Compatibility**: Drop-in replacement for the OpenAI API
- ✅ **Chat Completions**: `/v1/chat/completions` with streaming support
- ✅ **Text Completions**: `/v1/completions` (legacy endpoint)
- ✅ **Embeddings**: `/v1/embeddings` for text embeddings
- ✅ **Model Listing**: `/v1/models` endpoint
- ✅ **Streaming Support**: Server-Sent Events (SSE) for real-time responses
- ✅ **Model Mapping**: Map OpenAI model names to watsonx models
- ✅ **Automatic Token Management**: Handles IBM Cloud authentication automatically
- ✅ **CORS Support**: Configurable cross-origin resource sharing
- ✅ **Optional API Key Authentication**: Secure your proxy with an API key

## Quick Start

### Prerequisites

- Python 3.9 or higher
- IBM Cloud account with watsonx.ai access
- IBM Cloud API key
- watsonx.ai Project ID

### Installation

1. Clone or download this directory:

```bash
cd watsonx-openai-proxy
```

2. Install dependencies:

```bash
pip install -r requirements.txt
```

3. Configure environment variables:

```bash
cp .env.example .env
# Edit .env with your credentials
```

4. Run the server:

```bash
python -m app.main
```

Or with uvicorn:

```bash
uvicorn app.main:app --host 0.0.0.0 --port 8000
```

The server will start at `http://localhost:8000`.

## Configuration

### Environment Variables

Create a `.env` file with the following variables:

```bash
# Required: IBM Cloud Configuration
IBM_CLOUD_API_KEY=your_ibm_cloud_api_key_here
WATSONX_PROJECT_ID=your_watsonx_project_id_here
WATSONX_CLUSTER=us-south  # Options: us-south, eu-de, eu-gb, jp-tok, au-syd, ca-tor

# Optional: Server Configuration
HOST=0.0.0.0
PORT=8000
LOG_LEVEL=info

# Optional: API Key for Proxy Authentication
API_KEY=your_optional_api_key_for_proxy_authentication

# Optional: CORS Configuration
ALLOWED_ORIGINS=*  # Comma-separated, or * for all

# Optional: Model Mapping
MODEL_MAP_GPT4=ibm/granite-4-h-small
MODEL_MAP_GPT35=ibm/granite-3-8b-instruct
MODEL_MAP_GPT4_TURBO=meta-llama/llama-3-3-70b-instruct
MODEL_MAP_TEXT_EMBEDDING_ADA_002=ibm/slate-125m-english-rtrvr
```

### Model Mapping

You can map OpenAI model names to watsonx models using environment variables:

```bash
MODEL_MAP_<OPENAI_MODEL_NAME>=<WATSONX_MODEL_ID>
```

For example:

- `MODEL_MAP_GPT4=ibm/granite-4-h-small` maps `gpt-4` to `ibm/granite-4-h-small`
- `MODEL_MAP_GPT35_TURBO=ibm/granite-3-8b-instruct` maps `gpt-3.5-turbo` to `ibm/granite-3-8b-instruct`

## Usage

### With the OpenAI Python SDK

```python
from openai import OpenAI

# Point the client at your proxy
client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="your-proxy-api-key"  # Optional, if you set API_KEY in .env
)

# Use as normal
response = client.chat.completions.create(
    model="ibm/granite-3-8b-instruct",  # Or use a mapped name like "gpt-4"
    messages=[
        {"role": "user", "content": "Hello, how are you?"}
    ]
)

print(response.choices[0].message.content)
```

### With Streaming

```python
stream = client.chat.completions.create(
    model="ibm/granite-3-8b-instruct",
    messages=[{"role": "user", "content": "Tell me a story"}],
    stream=True
)

for chunk in stream:
    if chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
```

### With cURL

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer your-proxy-api-key" \
  -d '{
    "model": "ibm/granite-3-8b-instruct",
    "messages": [
      {"role": "user", "content": "Hello!"}
    ]
  }'
```

### Embeddings

```python
response = client.embeddings.create(
    model="ibm/slate-125m-english-rtrvr",
    input="Your text to embed"
)

print(response.data[0].embedding)
```

## Available Endpoints

- `GET /` - API information
- `GET /health` - Health check
- `GET /docs` - Interactive API documentation (Swagger UI)
- `POST /v1/chat/completions` - Chat completions
- `POST /v1/completions` - Text completions (legacy)
- `POST /v1/embeddings` - Generate embeddings
- `GET /v1/models` - List available models
- `GET /v1/models/{model_id}` - Get model information

## Supported Models

The proxy supports all watsonx.ai models available in your project, including:

### Chat Models

- IBM Granite models (3.x, 4.x series)
- Meta Llama models (3.x, 4.x series)
- Mistral models
- Other models available on watsonx.ai

### Embedding Models

- `ibm/slate-125m-english-rtrvr`
- `ibm/slate-30m-english-rtrvr`

See the `/v1/models` endpoint for the complete list.

## Authentication

### Proxy Authentication (Optional)

If you set `API_KEY` in your `.env` file, clients must provide it:

```python
client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="your-proxy-api-key"
)
```

### IBM Cloud Authentication

The proxy handles IBM Cloud authentication automatically using your `IBM_CLOUD_API_KEY`. Bearer tokens are:

- Automatically obtained on startup
- Refreshed every 50 minutes (tokens expire after 60 minutes)
- Refreshed on 401 errors

## Deployment

### Docker (Recommended)

Create a `Dockerfile` (credentials are passed at run time via `--env-file`, so the `.env` file is not baked into the image):

```dockerfile
FROM python:3.11-slim

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY app ./app

EXPOSE 8000

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
```

Build and run:

```bash
docker build -t watsonx-openai-proxy .
docker run -p 8000:8000 --env-file .env watsonx-openai-proxy
```

### Production Deployment

For production, consider:

1. **Use a production ASGI server**: The included uvicorn is suitable, but configure workers:
   ```bash
   uvicorn app.main:app --host 0.0.0.0 --port 8000 --workers 4
   ```

2. **Set up HTTPS**: Use a reverse proxy like nginx or Caddy

3. **Configure CORS**: Set `ALLOWED_ORIGINS` to specific domains

4. **Enable API key authentication**: Set `API_KEY` in the environment

5. **Monitor logs**: Set `LOG_LEVEL=info` or `warning` in production

6. **Use environment secrets**: Don't commit the `.env` file; use secret management

## Troubleshooting

### 401 Unauthorized

- Check that `IBM_CLOUD_API_KEY` is valid
- Verify your IBM Cloud account has watsonx.ai access
- Check server logs for token refresh errors

### Model Not Found

- Verify the model ID exists in watsonx.ai
- Check that your project has access to the model
- Use the `/v1/models` endpoint to see available models

### Connection Errors

- Verify `WATSONX_CLUSTER` matches your project's region
- Check firewall/network settings
- Ensure watsonx.ai services are accessible

### Streaming Issues

- Some models may not support streaming
- Check that your client library supports SSE (Server-Sent Events)
- Verify the network doesn't buffer streaming responses

## Development

### Running Tests

```bash
# Install dev dependencies
pip install pytest pytest-asyncio httpx

# Run tests
pytest tests/
```

### Code Structure

```
watsonx-openai-proxy/
├── app/
│   ├── main.py                 # FastAPI application
│   ├── config.py               # Configuration management
│   ├── routers/                # API endpoint routers
│   │   ├── chat.py             # Chat completions
│   │   ├── completions.py      # Text completions
│   │   ├── embeddings.py       # Embeddings
│   │   └── models.py           # Model listing
│   ├── services/               # Business logic
│   │   └── watsonx_service.py  # watsonx.ai API client
│   ├── models/                 # Pydantic models
│   │   └── openai_models.py    # OpenAI-compatible schemas
│   └── utils/                  # Utilities
│       └── transformers.py     # Request/response transformers
├── tests/                      # Test files
├── requirements.txt            # Python dependencies
├── .env.example                # Environment template
└── README.md                   # This file
```

## Contributing

Contributions are welcome! Please:

1. Fork the repository
2. Create a feature branch
3. Make your changes
4. Add tests if applicable
5. Submit a pull request

## License

Apache 2.0 License - See the LICENSE file for details.

## Related Projects

- [watsonx-unofficial-aisdk-provider](../wxai-provider/) - Vercel AI SDK provider for watsonx.ai
- [OpenCode watsonx plugin](../.opencode/plugins/) - Token management plugin for OpenCode

## Disclaimer

This is **not an official IBM product**. It's a community-maintained proxy for integrating watsonx.ai with OpenAI-compatible tools. watsonx.ai is a trademark of IBM.

## Support

For issues and questions:

- Check the [Troubleshooting](#troubleshooting) section
- Review server logs (`LOG_LEVEL=debug` for detailed logs)
- Open an issue in the repository
- Consult the [IBM watsonx.ai documentation](https://www.ibm.com/docs/en/watsonx-as-a-service)

app/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
"""watsonx-openai-proxy application package."""

__version__ = "1.0.0"

app/config.py (new file, 91 lines)
@@ -0,0 +1,91 @@
"""Configuration management for watsonx-openai-proxy."""

import os
from typing import Dict, Optional
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    # IBM Cloud Configuration
    ibm_cloud_api_key: str
    watsonx_project_id: str
    watsonx_cluster: str = "us-south"

    # Server Configuration
    host: str = "0.0.0.0"
    port: int = 8000
    log_level: str = "info"

    # API Configuration
    api_key: Optional[str] = None
    allowed_origins: str = "*"

    # Token Management
    token_refresh_interval: int = 3000  # 50 minutes in seconds

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="allow",  # Allow extra fields for model mapping
    )

    @property
    def watsonx_base_url(self) -> str:
        """Construct the watsonx.ai base URL from the cluster."""
        return f"https://{self.watsonx_cluster}.ml.cloud.ibm.com/ml/v1"

    @property
    def cors_origins(self) -> list[str]:
        """Parse CORS origins from a comma-separated string."""
        if self.allowed_origins == "*":
            return ["*"]
        return [origin.strip() for origin in self.allowed_origins.split(",")]

    def get_model_mapping(self) -> Dict[str, str]:
        """Extract model mappings from environment variables.

        Looks for variables like MODEL_MAP_GPT4=ibm/granite-4-h-small
        and builds a mapping dict.
        """
        import re

        mapping = {}

        # Check os.environ first
        for key, value in os.environ.items():
            if key.startswith("MODEL_MAP_"):
                model_name = key.replace("MODEL_MAP_", "")
                model_name = re.sub(r'([A-Z]+)(\d+)', r'\1-\2', model_name)
                openai_model = model_name.lower().replace("_", "-")
                mapping[openai_model] = value

        # Also check pydantic's extra fields (from the .env file).
        # These come in lowercase: model_map_gpt4 instead of MODEL_MAP_GPT4.
        extra = getattr(self, '__pydantic_extra__', {}) or {}
        for key, value in extra.items():
            if key.startswith("model_map_"):
                # Convert back to uppercase for processing
                model_name = key.replace("model_map_", "").upper()
                model_name = re.sub(r'([A-Z]+)(\d+)', r'\1-\2', model_name)
                openai_model = model_name.lower().replace("_", "-")
                mapping[openai_model] = value

        return mapping

    def map_model(self, openai_model: str) -> str:
        """Map an OpenAI model name to a watsonx model ID.

        Args:
            openai_model: OpenAI model name (e.g., "gpt-4", "gpt-3.5-turbo")

        Returns:
            The corresponding watsonx model ID, or the original name if no mapping exists.
        """
        model_map = self.get_model_mapping()
        return model_map.get(openai_model, openai_model)


# Global settings instance
settings = Settings()
157
app/main.py
Normal file
157
app/main.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
"""Main FastAPI application for watsonx-openai-proxy."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from fastapi import FastAPI, Request, status
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from app.config import settings
|
||||||
|
from app.routers import chat, completions, embeddings, models
|
||||||
|
from app.services.watsonx_service import watsonx_service
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=getattr(logging, settings.log_level.upper()),
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""Manage application lifespan events."""
|
||||||
|
# Startup
|
||||||
|
logger.info("Starting watsonx-openai-proxy...")
|
||||||
|
logger.info(f"Cluster: {settings.watsonx_cluster}")
|
||||||
|
logger.info(f"Project ID: {settings.watsonx_project_id[:8]}...")
|
||||||
|
|
||||||
|
# Initialize token
|
||||||
|
try:
|
||||||
|
await watsonx_service._refresh_token()
|
||||||
|
logger.info("Initial bearer token obtained successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to obtain initial bearer token: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Shutdown
|
||||||
|
logger.info("Shutting down watsonx-openai-proxy...")
|
||||||
|
await watsonx_service.close()
|
||||||
|
|
||||||
|
|
||||||
|
# Create FastAPI app
|
||||||
|
app = FastAPI(
|
||||||
|
title="watsonx-openai-proxy",
|
||||||
|
description="OpenAI-compatible API proxy for IBM watsonx.ai",
|
||||||
|
version="1.0.0",
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add CORS middleware
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.cors_origins,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Optional API key authentication middleware
|
||||||
|
@app.middleware("http")
async def authenticate(request: Request, call_next):
    """Authenticate requests if an API key is configured."""
    if settings.api_key:
        # Skip authentication for the health check
        if request.url.path == "/health":
            return await call_next(request)

        # Check the Authorization header
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={
                    "error": {
                        "message": "Missing Authorization header",
                        "type": "authentication_error",
                        "code": "missing_authorization",
                    }
                },
            )

        # Validate the API key
        if not auth_header.startswith("Bearer "):
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={
                    "error": {
                        "message": "Invalid Authorization header format",
                        "type": "authentication_error",
                        "code": "invalid_authorization_format",
                    }
                },
            )

        token = auth_header[7:]  # Remove the "Bearer " prefix
        if token != settings.api_key:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={
                    "error": {
                        "message": "Invalid API key",
                        "type": "authentication_error",
                        "code": "invalid_api_key",
                    }
                },
            )

    return await call_next(request)


# Include routers
app.include_router(chat.router, tags=["Chat"])
app.include_router(completions.router, tags=["Completions"])
app.include_router(embeddings.router, tags=["Embeddings"])
app.include_router(models.router, tags=["Models"])


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "service": "watsonx-openai-proxy",
        "cluster": settings.watsonx_cluster,
    }


@app.get("/")
async def root():
    """Root endpoint with API information."""
    return {
        "service": "watsonx-openai-proxy",
        "description": "OpenAI-compatible API proxy for IBM watsonx.ai",
        "version": "1.0.0",
        "endpoints": {
            "chat": "/v1/chat/completions",
            "completions": "/v1/completions",
            "embeddings": "/v1/embeddings",
            "models": "/v1/models",
            "health": "/health",
        },
        "documentation": "/docs",
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        log_level=settings.log_level,
        reload=False,
    )
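The middleware's bearer-token check can be exercised in isolation. A minimal sketch, assuming the extraction logic is pulled into a helper (`extract_bearer_token` is a hypothetical name, not part of this codebase):

```python
from typing import Optional


def extract_bearer_token(auth_header: Optional[str]) -> Optional[str]:
    """Return the token from an "Authorization: Bearer <token>" header, or None.

    Mirrors the middleware above: a missing header or a non-Bearer scheme
    is rejected, otherwise the "Bearer " prefix (7 characters) is stripped.
    """
    if not auth_header or not auth_header.startswith("Bearer "):
        return None
    return auth_header[7:]
```

A request is accepted only when the extracted token compares equal to the configured `settings.api_key`.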
31
app/models/__init__.py
Normal file
@@ -0,0 +1,31 @@
"""OpenAI-compatible data models."""

from app.models.openai_models import (
    ChatMessage,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionChunk,
    CompletionRequest,
    CompletionResponse,
    EmbeddingRequest,
    EmbeddingResponse,
    ModelsResponse,
    ModelInfo,
    ErrorResponse,
    ErrorDetail,
)

__all__ = [
    "ChatMessage",
    "ChatCompletionRequest",
    "ChatCompletionResponse",
    "ChatCompletionChunk",
    "CompletionRequest",
    "CompletionResponse",
    "EmbeddingRequest",
    "EmbeddingResponse",
    "ModelsResponse",
    "ModelInfo",
    "ErrorResponse",
    "ErrorDetail",
]
213
app/models/openai_models.py
Normal file
@@ -0,0 +1,213 @@
"""OpenAI-compatible request and response models."""

from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field


# ============================================================================
# Chat Completions Models
# ============================================================================

class ChatMessage(BaseModel):
    """A chat message in the conversation."""
    role: Literal["system", "user", "assistant", "function", "tool"]
    content: Optional[str] = None
    name: Optional[str] = None
    function_call: Optional[Dict[str, Any]] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None


class FunctionCall(BaseModel):
    """Function call specification."""
    name: str
    arguments: str


class ToolCall(BaseModel):
    """Tool call specification."""
    id: str
    type: Literal["function"]
    function: FunctionCall


class ChatCompletionRequest(BaseModel):
    """OpenAI chat completion request."""
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(default=1.0, ge=0, le=2)
    top_p: Optional[float] = Field(default=1.0, ge=0, le=1)
    n: Optional[int] = Field(default=1, ge=1)
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = Field(default=None, ge=1)
    presence_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    frequency_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    functions: Optional[List[Dict[str, Any]]] = None
    function_call: Optional[Union[str, Dict[str, str]]] = None
    tools: Optional[List[Dict[str, Any]]] = None
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None


class ChatCompletionChoice(BaseModel):
    """A single chat completion choice."""
    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None
    logprobs: Optional[Dict[str, Any]] = None


class ChatCompletionUsage(BaseModel):
    """Token usage information."""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    """OpenAI chat completion response."""
    id: str
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: ChatCompletionUsage
    system_fingerprint: Optional[str] = None


class ChatCompletionChunkDelta(BaseModel):
    """Delta content in a streaming response."""
    role: Optional[str] = None
    content: Optional[str] = None
    function_call: Optional[Dict[str, Any]] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None


class ChatCompletionChunkChoice(BaseModel):
    """A single streaming chunk choice."""
    index: int
    delta: ChatCompletionChunkDelta
    finish_reason: Optional[str] = None
    logprobs: Optional[Dict[str, Any]] = None


class ChatCompletionChunk(BaseModel):
    """OpenAI streaming chat completion chunk."""
    id: str
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int
    model: str
    choices: List[ChatCompletionChunkChoice]
    system_fingerprint: Optional[str] = None


# ============================================================================
# Completions Models (Legacy)
# ============================================================================

class CompletionRequest(BaseModel):
    """OpenAI completion request (legacy)."""
    model: str
    prompt: Union[str, List[str], List[int], List[List[int]]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = Field(default=16, ge=1)
    temperature: Optional[float] = Field(default=1.0, ge=0, le=2)
    top_p: Optional[float] = Field(default=1.0, ge=0, le=1)
    n: Optional[int] = Field(default=1, ge=1)
    stream: Optional[bool] = False
    logprobs: Optional[int] = Field(default=None, ge=0, le=5)
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    frequency_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    best_of: Optional[int] = Field(default=1, ge=1)
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None


class CompletionChoice(BaseModel):
    """A single completion choice."""
    text: str
    index: int
    logprobs: Optional[Dict[str, Any]] = None
    finish_reason: Optional[str] = None


class CompletionResponse(BaseModel):
    """OpenAI completion response."""
    id: str
    object: Literal["text_completion"] = "text_completion"
    created: int
    model: str
    choices: List[CompletionChoice]
    usage: ChatCompletionUsage


# ============================================================================
# Embeddings Models
# ============================================================================

class EmbeddingRequest(BaseModel):
    """OpenAI embedding request."""
    model: str
    input: Union[str, List[str], List[int], List[List[int]]]
    encoding_format: Optional[Literal["float", "base64"]] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None


class EmbeddingData(BaseModel):
    """A single embedding result."""
    object: Literal["embedding"] = "embedding"
    embedding: List[float]
    index: int


class EmbeddingUsage(BaseModel):
    """Token usage for embeddings."""
    prompt_tokens: int
    total_tokens: int


class EmbeddingResponse(BaseModel):
    """OpenAI embedding response."""
    object: Literal["list"] = "list"
    data: List[EmbeddingData]
    model: str
    usage: EmbeddingUsage


# ============================================================================
# Models List
# ============================================================================

class ModelInfo(BaseModel):
    """Information about a single model."""
    id: str
    object: Literal["model"] = "model"
    created: int
    owned_by: str


class ModelsResponse(BaseModel):
    """List of available models."""
    object: Literal["list"] = "list"
    data: List[ModelInfo]


# ============================================================================
# Error Models
# ============================================================================

class ErrorDetail(BaseModel):
    """Error detail information."""
    message: str
    type: str
    param: Optional[str] = None
    code: Optional[str] = None


class ErrorResponse(BaseModel):
    """OpenAI error response."""
    error: ErrorDetail
5
app/routers/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""API routers for watsonx-openai-proxy."""

from app.routers import chat, completions, embeddings, models

__all__ = ["chat", "completions", "embeddings", "models"]
156
app/routers/chat.py
Normal file
@@ -0,0 +1,156 @@
"""Chat completions endpoint router."""

import json
import uuid
from typing import Optional, Union
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from app.models.openai_models import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ErrorResponse,
    ErrorDetail,
)
from app.services.watsonx_service import watsonx_service
from app.utils.transformers import (
    transform_messages_to_watsonx,
    transform_tools_to_watsonx,
    transform_watsonx_to_openai_chat,
    transform_watsonx_to_openai_chat_chunk,
    format_sse_event,
)
from app.config import settings
import logging

logger = logging.getLogger(__name__)

router = APIRouter()


@router.post(
    "/v1/chat/completions",
    response_model=Union[ChatCompletionResponse, ErrorResponse],
    responses={
        200: {"model": ChatCompletionResponse},
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def create_chat_completion(
    request: ChatCompletionRequest,
    http_request: Request,
):
    """Create a chat completion using the OpenAI-compatible API.

    This endpoint accepts OpenAI-formatted requests and translates them
    to watsonx.ai API calls.
    """
    try:
        # Map the model name if needed
        watsonx_model = settings.map_model(request.model)
        logger.info(f"Chat completion request: {request.model} -> {watsonx_model}")

        # Transform messages
        watsonx_messages = transform_messages_to_watsonx(request.messages)

        # Transform tools if present
        watsonx_tools = transform_tools_to_watsonx(request.tools)

        # Handle streaming
        if request.stream:
            return StreamingResponse(
                stream_chat_completion(
                    watsonx_model,
                    watsonx_messages,
                    request,
                    watsonx_tools,
                ),
                media_type="text/event-stream",
            )

        # Non-streaming response
        watsonx_response = await watsonx_service.chat_completion(
            model_id=watsonx_model,
            messages=watsonx_messages,
            temperature=request.temperature or 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p or 1.0,
            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
            tools=watsonx_tools,
        )

        # Transform the response
        openai_response = transform_watsonx_to_openai_chat(
            watsonx_response,
            request.model,
        )

        return openai_response

    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": "internal_error",
                }
            },
        )


async def stream_chat_completion(
    watsonx_model: str,
    watsonx_messages: list,
    request: ChatCompletionRequest,
    watsonx_tools: Optional[list] = None,
):
    """Stream chat completion responses.

    Args:
        watsonx_model: The watsonx model ID
        watsonx_messages: Transformed messages
        request: Original OpenAI request
        watsonx_tools: Transformed tools

    Yields:
        Server-Sent Events with chat completion chunks
    """
    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"

    try:
        async for chunk in watsonx_service.chat_completion_stream(
            model_id=watsonx_model,
            messages=watsonx_messages,
            temperature=request.temperature or 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p or 1.0,
            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
            tools=watsonx_tools,
        ):
            # Transform the chunk to OpenAI format
            openai_chunk = transform_watsonx_to_openai_chat_chunk(
                chunk,
                request.model,
                request_id,
            )

            # Send as SSE
            yield format_sse_event(openai_chunk.model_dump_json())

        # Send the [DONE] message
        yield format_sse_event("[DONE]")

    except Exception as e:
        logger.error(f"Error in streaming chat completion: {str(e)}", exc_info=True)
        error_response = ErrorResponse(
            error=ErrorDetail(
                message=str(e),
                type="internal_error",
                code="stream_error",
            )
        )
        yield format_sse_event(error_response.model_dump_json())
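`format_sse_event` is imported from `app/utils/transformers.py`, which is not part of this hunk. Judging from how it is called above (with a JSON payload and with the literal `"[DONE]"`), it presumably applies Server-Sent Events framing; a hedged sketch of what such a helper typically looks like:

```python
def format_sse_event(data: str) -> str:
    """Wrap a payload in SSE wire framing: "data: <payload>\n\n".

    This is a sketch of the assumed behavior, not the actual
    implementation from app/utils/transformers.py.
    """
    return f"data: {data}\n\n"
```

The trailing blank line (the second `\n`) is what terminates each SSE event, and `data: [DONE]` is the OpenAI convention for ending a stream.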
109
app/routers/completions.py
Normal file
@@ -0,0 +1,109 @@
"""Text completions endpoint router (legacy)."""

import uuid
from typing import Union
from fastapi import APIRouter, HTTPException, Request
from app.models.openai_models import (
    CompletionRequest,
    CompletionResponse,
    ErrorResponse,
    ErrorDetail,
)
from app.services.watsonx_service import watsonx_service
from app.utils.transformers import transform_watsonx_to_openai_completion
from app.config import settings
import logging

logger = logging.getLogger(__name__)

router = APIRouter()


@router.post(
    "/v1/completions",
    response_model=Union[CompletionResponse, ErrorResponse],
    responses={
        200: {"model": CompletionResponse},
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def create_completion(
    request: CompletionRequest,
    http_request: Request,
):
    """Create a text completion using the OpenAI-compatible API (legacy).

    This endpoint accepts OpenAI-formatted completion requests and translates
    them to watsonx.ai text generation API calls.
    """
    try:
        # Map the model name if needed
        watsonx_model = settings.map_model(request.model)
        logger.info(f"Completion request: {request.model} -> {watsonx_model}")

        # Handle the prompt (can be a string or a list)
        if isinstance(request.prompt, list):
            if len(request.prompt) == 0:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": {
                            "message": "Prompt cannot be empty",
                            "type": "invalid_request_error",
                            "code": "invalid_prompt",
                        }
                    },
                )
            # For now, just use the first prompt
            # TODO: Handle multiple prompts with the n parameter
            prompt = request.prompt[0] if isinstance(request.prompt[0], str) else ""
        else:
            prompt = request.prompt

        # Note: Streaming is not implemented for completions yet
        if request.stream:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "Streaming not supported for completions endpoint",
                        "type": "invalid_request_error",
                        "code": "streaming_not_supported",
                    }
                },
            )

        # Call watsonx text generation
        watsonx_response = await watsonx_service.text_generation(
            model_id=watsonx_model,
            prompt=prompt,
            temperature=request.temperature or 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p or 1.0,
            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
        )

        # Transform the response
        openai_response = transform_watsonx_to_openai_completion(
            watsonx_response,
            request.model,
        )

        return openai_response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in completion: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": "internal_error",
                }
            },
        )
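The chat, completions, and streaming code paths all repeat the same inline expression to normalize the OpenAI `stop` parameter (`str | list | None`) into the list form watsonx expects. Factored out as a helper (`normalize_stop` is a hypothetical name, offered as a possible refactor, not something in this diff):

```python
from typing import List, Optional, Union


def normalize_stop(stop: Optional[Union[str, List[str]]]) -> Optional[List[str]]:
    """Normalize OpenAI's stop parameter to a list of stop sequences or None.

    Equivalent to the inline expression used by the routers:
    stop if isinstance(stop, list) else [stop] if stop else None
    """
    if not stop:
        return None
    return stop if isinstance(stop, list) else [stop]
```

Each router could then pass `stop=normalize_stop(request.stop)` instead of repeating the conditional.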
114
app/routers/embeddings.py
Normal file
@@ -0,0 +1,114 @@
"""Embeddings endpoint router."""

from typing import Union
from fastapi import APIRouter, HTTPException, Request
from app.models.openai_models import (
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorResponse,
    ErrorDetail,
)
from app.services.watsonx_service import watsonx_service
from app.utils.transformers import transform_watsonx_to_openai_embeddings
from app.config import settings
import logging

logger = logging.getLogger(__name__)

router = APIRouter()


@router.post(
    "/v1/embeddings",
    response_model=Union[EmbeddingResponse, ErrorResponse],
    responses={
        200: {"model": EmbeddingResponse},
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def create_embeddings(
    request: EmbeddingRequest,
    http_request: Request,
):
    """Create embeddings using the OpenAI-compatible API.

    This endpoint accepts OpenAI-formatted embedding requests and translates
    them to watsonx.ai embeddings API calls.
    """
    try:
        # Map the model name if needed
        watsonx_model = settings.map_model(request.model)
        logger.info(f"Embeddings request: {request.model} -> {watsonx_model}")

        # Handle the input (can be a string or a list)
        if isinstance(request.input, str):
            inputs = [request.input]
        elif isinstance(request.input, list):
            if len(request.input) == 0:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": {
                            "message": "Input cannot be empty",
                            "type": "invalid_request_error",
                            "code": "invalid_input",
                        }
                    },
                )
            # Handle a list of strings or a list of token IDs
            if isinstance(request.input[0], str):
                inputs = request.input
            else:
                # Token IDs are not supported yet
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": {
                            "message": "Token ID input not supported",
                            "type": "invalid_request_error",
                            "code": "unsupported_input_type",
                        }
                    },
                )
        else:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "Invalid input type",
                        "type": "invalid_request_error",
                        "code": "invalid_input_type",
                    }
                },
            )

        # Call watsonx embeddings
        watsonx_response = await watsonx_service.embeddings(
            model_id=watsonx_model,
            inputs=inputs,
        )

        # Transform the response
        openai_response = transform_watsonx_to_openai_embeddings(
            watsonx_response,
            request.model,
        )

        return openai_response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in embeddings: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": "internal_error",
                }
            },
        )
120
app/routers/models.py
Normal file
@@ -0,0 +1,120 @@
"""Models endpoint router."""

import time
from fastapi import APIRouter
from app.models.openai_models import ModelsResponse, ModelInfo
from app.config import settings
import logging

logger = logging.getLogger(__name__)

router = APIRouter()


# Predefined list of available models
# This can be extended or made dynamic based on the watsonx.ai API
AVAILABLE_MODELS = [
    # Granite Models
    "ibm/granite-3-1-8b-base",
    "ibm/granite-3-2-8b-instruct",
    "ibm/granite-3-3-8b-instruct",
    "ibm/granite-3-8b-instruct",
    "ibm/granite-4-h-small",
    "ibm/granite-8b-code-instruct",

    # Llama Models
    "meta-llama/llama-3-1-70b-gptq",
    "meta-llama/llama-3-1-8b",
    "meta-llama/llama-3-2-11b-vision-instruct",
    "meta-llama/llama-3-2-90b-vision-instruct",
    "meta-llama/llama-3-3-70b-instruct",
    "meta-llama/llama-3-405b-instruct",
    "meta-llama/llama-4-maverick-17b-128e-instruct-fp8",

    # Mistral Models
    "mistral-large-2512",
    "mistralai/mistral-medium-2505",
    "mistralai/mistral-small-3-1-24b-instruct-2503",

    # Other Models
    "openai/gpt-oss-120b",

    # Embedding Models
    "ibm/slate-125m-english-rtrvr",
    "ibm/slate-30m-english-rtrvr",
]


@router.get(
    "/v1/models",
    response_model=ModelsResponse,
)
async def list_models():
    """List available models in OpenAI-compatible format.

    Returns a list of models that can be used with the API.
    Includes both the actual watsonx model IDs and any mapped names.
    """
    created_time = int(time.time())
    models = []

    # Add all available watsonx models
    for model_id in AVAILABLE_MODELS:
        models.append(
            ModelInfo(
                id=model_id,
                created=created_time,
                owned_by="ibm-watsonx",
            )
        )

    # Add mapped model names (e.g., gpt-4 -> ibm/granite-4-h-small)
    model_mapping = settings.get_model_mapping()
    for openai_name, watsonx_id in model_mapping.items():
        if watsonx_id in AVAILABLE_MODELS:
            models.append(
                ModelInfo(
                    id=openai_name,
                    created=created_time,
                    owned_by="ibm-watsonx",
                )
            )

    return ModelsResponse(data=models)


@router.get(
    "/v1/models/{model_id}",
    response_model=ModelInfo,
)
async def retrieve_model(model_id: str):
    """Retrieve information about a specific model.

    Args:
        model_id: The model ID to retrieve

    Returns:
        Model information
    """
    # Map the model if needed
    watsonx_model = settings.map_model(model_id)

    # Check whether the model exists
    if watsonx_model not in AVAILABLE_MODELS:
        from fastapi import HTTPException
        raise HTTPException(
            status_code=404,
            detail={
                "error": {
                    "message": f"Model '{model_id}' not found",
                    "type": "invalid_request_error",
                    "code": "model_not_found",
                }
            },
        )

    return ModelInfo(
        id=model_id,
        created=int(time.time()),
        owned_by="ibm-watsonx",
    )
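`settings.map_model` and `settings.get_model_mapping` live in `app/config.py`, which is not part of this hunk. Based on how the routers use them (and the `gpt-4 -> ibm/granite-4-h-small` comment above), a plausible minimal sketch, where the mapping table itself is purely illustrative:

```python
# Hypothetical sketch of the config-side model mapping; the real table
# in app/config.py is configured via environment variables.
MODEL_MAPPING = {
    "gpt-4": "ibm/granite-4-h-small",  # example entry from the comment above
}


def map_model(name: str) -> str:
    """Return the watsonx model ID for an OpenAI-style name.

    Unmapped names pass through unchanged, so callers can supply
    watsonx IDs directly.
    """
    return MODEL_MAPPING.get(name, name)
```

This pass-through behavior is what lets `retrieve_model` accept either an OpenAI alias or a raw watsonx model ID.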
5
app/services/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""Services for watsonx-openai-proxy."""

from app.services.watsonx_service import watsonx_service, WatsonxService

__all__ = ["watsonx_service", "WatsonxService"]
316
app/services/watsonx_service.py
Normal file
@@ -0,0 +1,316 @@
"""Service for interacting with IBM watsonx.ai APIs."""

import asyncio
import time
from typing import AsyncIterator, Dict, List, Optional
import httpx
from app.config import settings
import logging

logger = logging.getLogger(__name__)


class WatsonxService:
    """Service for managing watsonx.ai API interactions."""

    def __init__(self):
        self.base_url = settings.watsonx_base_url
        self.project_id = settings.watsonx_project_id
        self.api_key = settings.ibm_cloud_api_key
        self._bearer_token: Optional[str] = None
        self._token_expiry: Optional[float] = None
        self._token_lock = asyncio.Lock()
        self._client: Optional[httpx.AsyncClient] = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or create the HTTP client."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=300.0)
        return self._client

    async def close(self):
        """Close the HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None

    async def _refresh_token(self) -> str:
        """Get a fresh bearer token from IBM Cloud IAM."""
        async with self._token_lock:
            # Check whether the cached token is still valid
            if self._bearer_token and self._token_expiry:
                if time.time() < self._token_expiry - 300:  # 5-minute buffer
                    return self._bearer_token

            logger.info("Refreshing IBM Cloud bearer token...")

            client = await self._get_client()
            response = await client.post(
                "https://iam.cloud.ibm.com/identity/token",
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                data=f"grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey={self.api_key}",
            )

            if response.status_code != 200:
                raise Exception(f"Failed to get bearer token: {response.text}")

            data = response.json()
            self._bearer_token = data["access_token"]
            self._token_expiry = time.time() + data.get("expires_in", 3600)

            logger.info(f"Bearer token refreshed. Expires in {data.get('expires_in', 3600)} seconds")
            return self._bearer_token

    async def _get_headers(self) -> Dict[str, str]:
        """Get headers with a valid bearer token."""
        token = await self._refresh_token()
        return {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Dict],
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        stream: bool = False,
        tools: Optional[List[Dict]] = None,
        **kwargs,
    ) -> Dict:
        """Create a chat completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID
            messages: List of chat messages
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            top_p: Nucleus sampling parameter
            stop: Stop sequences
            stream: Whether to stream the response
            tools: Tool/function definitions
            **kwargs: Additional parameters

        Returns:
            Chat completion response
        """
        headers = await self._get_headers()
        client = await self._get_client()

        # Build the watsonx request
        payload = {
            "model_id": model_id,
            "project_id": self.project_id,
            "messages": messages,
            "parameters": {
                "temperature": temperature,
                "top_p": top_p,
            },
        }

        if max_tokens:
            payload["parameters"]["max_tokens"] = max_tokens

        if stop:
            payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]

        if tools:
            payload["tools"] = tools

        url = f"{self.base_url}/text/chat"
        params = {"version": "2024-02-13"}

        if stream:
            params["stream"] = "true"

        response = await client.post(
            url,
            headers=headers,
            json=payload,
            params=params,
        )

        if response.status_code != 200:
            raise Exception(f"watsonx API error: {response.status_code} - {response.text}")

        return response.json()

    async def chat_completion_stream(
        self,
        model_id: str,
        messages: List[Dict],
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        tools: Optional[List[Dict]] = None,
        **kwargs,
    ) -> AsyncIterator[Dict]:
        """Stream a chat completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID
            messages: List of chat messages
|
||||||
|
temperature: Sampling temperature
|
||||||
|
max_tokens: Maximum tokens to generate
|
||||||
|
top_p: Nucleus sampling parameter
|
||||||
|
stop: Stop sequences
|
||||||
|
tools: Tool/function definitions
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Chat completion chunks
|
||||||
|
"""
|
||||||
|
headers = await self._get_headers()
|
||||||
|
client = await self._get_client()
|
||||||
|
|
||||||
|
# Build watsonx request
|
||||||
|
payload = {
|
||||||
|
"model_id": model_id,
|
||||||
|
"project_id": self.project_id,
|
||||||
|
"messages": messages,
|
||||||
|
"parameters": {
|
||||||
|
"temperature": temperature,
|
||||||
|
"top_p": top_p,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if max_tokens:
|
||||||
|
payload["parameters"]["max_tokens"] = max_tokens
|
||||||
|
|
||||||
|
if stop:
|
||||||
|
payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
payload["tools"] = tools
|
||||||
|
|
||||||
|
url = f"{self.base_url}/text/chat_stream"
|
||||||
|
params = {"version": "2024-02-13"}
|
||||||
|
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
params=params,
|
||||||
|
) as response:
|
||||||
|
if response.status_code != 200:
|
||||||
|
text = await response.aread()
|
||||||
|
raise Exception(f"watsonx API error: {response.status_code} - {text.decode()}")
|
||||||
|
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
data = line[6:] # Remove "data: " prefix
|
||||||
|
if data.strip() == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
import json
|
||||||
|
yield json.loads(data)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
async def text_generation(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
prompt: str,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Dict:
|
||||||
|
"""Generate text completion using watsonx.ai.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: The watsonx model ID
|
||||||
|
prompt: The input prompt
|
||||||
|
temperature: Sampling temperature
|
||||||
|
max_tokens: Maximum tokens to generate
|
||||||
|
top_p: Nucleus sampling parameter
|
||||||
|
stop: Stop sequences
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Text generation response
|
||||||
|
"""
|
||||||
|
headers = await self._get_headers()
|
||||||
|
client = await self._get_client()
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model_id": model_id,
|
||||||
|
"project_id": self.project_id,
|
||||||
|
"input": prompt,
|
||||||
|
"parameters": {
|
||||||
|
"temperature": temperature,
|
||||||
|
"top_p": top_p,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if max_tokens:
|
||||||
|
payload["parameters"]["max_new_tokens"] = max_tokens
|
||||||
|
|
||||||
|
if stop:
|
||||||
|
payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]
|
||||||
|
|
||||||
|
url = f"{self.base_url}/text/generation"
|
||||||
|
params = {"version": "2024-02-13"}
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
params=params,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f"watsonx API error: {response.status_code} - {response.text}")
|
||||||
|
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def embeddings(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
inputs: List[str],
|
||||||
|
**kwargs,
|
||||||
|
) -> Dict:
|
||||||
|
"""Generate embeddings using watsonx.ai.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: The watsonx embedding model ID
|
||||||
|
inputs: List of texts to embed
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embeddings response
|
||||||
|
"""
|
||||||
|
headers = await self._get_headers()
|
||||||
|
client = await self._get_client()
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model_id": model_id,
|
||||||
|
"project_id": self.project_id,
|
||||||
|
"inputs": inputs,
|
||||||
|
}
|
||||||
|
|
||||||
|
url = f"{self.base_url}/text/embeddings"
|
||||||
|
params = {"version": "2024-02-13"}
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
params=params,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f"watsonx API error: {response.status_code} - {response.text}")
|
||||||
|
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
|
# Global service instance
|
||||||
|
watsonx_service = WatsonxService()
|
||||||
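The SSE read loop in `chat_completion_stream` can be sketched in isolation. This is a hedged, self-contained mirror of that parsing logic (the helper name `parse_sse_lines` and the sample payloads are illustrative, not part of the repository):

```python
import json


def parse_sse_lines(lines):
    """Collect decoded JSON payloads from SSE 'data:' lines, stopping at [DONE]."""
    events = []
    for line in lines:
        if not line.startswith("data: "):
            continue  # skip comments and blank keep-alive lines
        data = line[6:]  # strip the "data: " prefix, as the service method does
        if data.strip() == "[DONE]":
            break
        try:
            events.append(json.loads(data))
        except json.JSONDecodeError:
            continue  # malformed chunks are skipped rather than raised
    return events


sample = [
    'data: {"choices": [{"delta": {"content": "Hel"}}]}',
    ': keep-alive comment',
    'data: {"choices": [{"delta": {"content": "lo"}}]}',
    'data: [DONE]',
]
chunks = parse_sse_lines(sample)
text = "".join(c["choices"][0]["delta"]["content"] for c in chunks)
# text == "Hello"
```

The same skip-on-decode-error behavior is what lets the service tolerate partial or non-JSON lines mid-stream.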
21
app/utils/__init__.py
Normal file
@@ -0,0 +1,21 @@
"""Utility functions for watsonx-openai-proxy."""

from app.utils.transformers import (
    transform_messages_to_watsonx,
    transform_tools_to_watsonx,
    transform_watsonx_to_openai_chat,
    transform_watsonx_to_openai_chat_chunk,
    transform_watsonx_to_openai_completion,
    transform_watsonx_to_openai_embeddings,
    format_sse_event,
)

__all__ = [
    "transform_messages_to_watsonx",
    "transform_tools_to_watsonx",
    "transform_watsonx_to_openai_chat",
    "transform_watsonx_to_openai_chat_chunk",
    "transform_watsonx_to_openai_completion",
    "transform_watsonx_to_openai_embeddings",
    "format_sse_event",
]
272
app/utils/transformers.py
Normal file
@@ -0,0 +1,272 @@
"""Utilities for transforming between OpenAI and watsonx formats."""

import time
import uuid
from typing import Any, Dict, List, Optional

from app.models.openai_models import (
    ChatMessage,
    ChatCompletionChoice,
    ChatCompletionUsage,
    ChatCompletionResponse,
    ChatCompletionChunk,
    ChatCompletionChunkChoice,
    ChatCompletionChunkDelta,
    CompletionChoice,
    CompletionResponse,
    EmbeddingData,
    EmbeddingUsage,
    EmbeddingResponse,
)


def transform_messages_to_watsonx(messages: List[ChatMessage]) -> List[Dict[str, Any]]:
    """Transform OpenAI messages to watsonx format.

    Args:
        messages: List of OpenAI ChatMessage objects

    Returns:
        List of watsonx-compatible message dicts
    """
    watsonx_messages = []

    for msg in messages:
        watsonx_msg = {
            "role": msg.role,
        }

        if msg.content:
            watsonx_msg["content"] = msg.content

        if msg.name:
            watsonx_msg["name"] = msg.name

        if msg.tool_calls:
            watsonx_msg["tool_calls"] = msg.tool_calls

        if msg.function_call:
            watsonx_msg["function_call"] = msg.function_call

        watsonx_messages.append(watsonx_msg)

    return watsonx_messages


def transform_tools_to_watsonx(tools: Optional[List[Dict]]) -> Optional[List[Dict]]:
    """Transform OpenAI tools to watsonx format.

    Args:
        tools: List of OpenAI tool definitions

    Returns:
        List of watsonx-compatible tool definitions
    """
    if not tools:
        return None

    # watsonx uses similar format to OpenAI for tools
    return tools


def transform_watsonx_to_openai_chat(
    watsonx_response: Dict[str, Any],
    model: str,
    request_id: Optional[str] = None,
) -> ChatCompletionResponse:
    """Transform watsonx chat response to OpenAI format.

    Args:
        watsonx_response: Response from watsonx chat API
        model: Model name to include in response
        request_id: Optional request ID, generates one if not provided

    Returns:
        OpenAI-compatible ChatCompletionResponse
    """
    response_id = request_id or f"chatcmpl-{uuid.uuid4().hex[:24]}"
    created = int(time.time())

    # Extract choices
    choices = []
    watsonx_choices = watsonx_response.get("choices", [])

    for idx, choice in enumerate(watsonx_choices):
        message_data = choice.get("message", {})

        message = ChatMessage(
            role=message_data.get("role", "assistant"),
            content=message_data.get("content"),
            tool_calls=message_data.get("tool_calls"),
            function_call=message_data.get("function_call"),
        )

        choices.append(
            ChatCompletionChoice(
                index=idx,
                message=message,
                finish_reason=choice.get("finish_reason"),
            )
        )

    # Extract usage
    usage_data = watsonx_response.get("usage", {})
    usage = ChatCompletionUsage(
        prompt_tokens=usage_data.get("prompt_tokens", 0),
        completion_tokens=usage_data.get("completion_tokens", 0),
        total_tokens=usage_data.get("total_tokens", 0),
    )

    return ChatCompletionResponse(
        id=response_id,
        created=created,
        model=model,
        choices=choices,
        usage=usage,
    )


def transform_watsonx_to_openai_chat_chunk(
    watsonx_chunk: Dict[str, Any],
    model: str,
    request_id: str,
) -> ChatCompletionChunk:
    """Transform watsonx streaming chunk to OpenAI format.

    Args:
        watsonx_chunk: Streaming chunk from watsonx
        model: Model name to include in response
        request_id: Request ID for this stream

    Returns:
        OpenAI-compatible ChatCompletionChunk
    """
    created = int(time.time())

    # Extract choices
    choices = []
    watsonx_choices = watsonx_chunk.get("choices", [])

    for idx, choice in enumerate(watsonx_choices):
        delta_data = choice.get("delta", {})

        delta = ChatCompletionChunkDelta(
            role=delta_data.get("role"),
            content=delta_data.get("content"),
            tool_calls=delta_data.get("tool_calls"),
            function_call=delta_data.get("function_call"),
        )

        choices.append(
            ChatCompletionChunkChoice(
                index=idx,
                delta=delta,
                finish_reason=choice.get("finish_reason"),
            )
        )

    return ChatCompletionChunk(
        id=request_id,
        created=created,
        model=model,
        choices=choices,
    )


def transform_watsonx_to_openai_completion(
    watsonx_response: Dict[str, Any],
    model: str,
    request_id: Optional[str] = None,
) -> CompletionResponse:
    """Transform watsonx text generation response to OpenAI completion format.

    Args:
        watsonx_response: Response from watsonx text generation API
        model: Model name to include in response
        request_id: Optional request ID, generates one if not provided

    Returns:
        OpenAI-compatible CompletionResponse
    """
    response_id = request_id or f"cmpl-{uuid.uuid4().hex[:24]}"
    created = int(time.time())

    # Extract results
    results = watsonx_response.get("results", [])
    choices = []

    for idx, result in enumerate(results):
        choices.append(
            CompletionChoice(
                text=result.get("generated_text", ""),
                index=idx,
                finish_reason=result.get("stop_reason"),
            )
        )

    # Extract usage
    usage_data = watsonx_response.get("usage", {})
    usage = ChatCompletionUsage(
        prompt_tokens=usage_data.get("prompt_tokens", 0),
        completion_tokens=usage_data.get("generated_tokens", 0),
        total_tokens=usage_data.get("prompt_tokens", 0) + usage_data.get("generated_tokens", 0),
    )

    return CompletionResponse(
        id=response_id,
        created=created,
        model=model,
        choices=choices,
        usage=usage,
    )


def transform_watsonx_to_openai_embeddings(
    watsonx_response: Dict[str, Any],
    model: str,
) -> EmbeddingResponse:
    """Transform watsonx embeddings response to OpenAI format.

    Args:
        watsonx_response: Response from watsonx embeddings API
        model: Model name to include in response

    Returns:
        OpenAI-compatible EmbeddingResponse
    """
    # Extract results
    results = watsonx_response.get("results", [])
    data = []

    for idx, result in enumerate(results):
        embedding = result.get("embedding", [])
        data.append(
            EmbeddingData(
                embedding=embedding,
                index=idx,
            )
        )

    # Calculate usage
    input_token_count = watsonx_response.get("input_token_count", 0)
    usage = EmbeddingUsage(
        prompt_tokens=input_token_count,
        total_tokens=input_token_count,
    )

    return EmbeddingResponse(
        data=data,
        model=model,
        usage=usage,
    )


def format_sse_event(data: str) -> str:
    """Format data as Server-Sent Event.

    Args:
        data: JSON string to send

    Returns:
        Formatted SSE string
    """
    return f"data: {data}\n\n"
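The completion transformer above is a straight shape-mapping. A hedged, dict-level sketch of the same mapping (plain dicts stand in for the repository's Pydantic models, and the `"object"` field value is an assumption based on OpenAI's schema, not taken from this codebase):

```python
import time
import uuid


def to_openai_completion(watsonx_response, model):
    """Map a watsonx text-generation response dict to an OpenAI-shaped dict."""
    usage = watsonx_response.get("usage", {})
    prompt = usage.get("prompt_tokens", 0)
    generated = usage.get("generated_tokens", 0)
    return {
        "id": f"cmpl-{uuid.uuid4().hex[:24]}",
        "object": "text_completion",  # assumed OpenAI object name
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "text": r.get("generated_text", ""),
                "index": i,
                "finish_reason": r.get("stop_reason"),
            }
            for i, r in enumerate(watsonx_response.get("results", []))
        ],
        "usage": {
            "prompt_tokens": prompt,
            "completion_tokens": generated,
            "total_tokens": prompt + generated,
        },
    }


resp = to_openai_completion(
    {
        "results": [{"generated_text": "Paris.", "stop_reason": "eos_token"}],
        "usage": {"prompt_tokens": 7, "generated_tokens": 3},
    },
    model="ibm/granite-3-8b-instruct",
)
# resp["choices"][0]["text"] == "Paris."; resp["usage"]["total_tokens"] == 10
```

Note the asymmetry the real transformer also handles: watsonx reports `generated_tokens` while OpenAI clients expect `completion_tokens`, so the total must be recomputed rather than copied.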
27
docker-compose.yml
Normal file
@@ -0,0 +1,27 @@
version: '3.8'

services:
  watsonx-openai-proxy:
    build: .
    ports:
      - "8000:8000"
    environment:
      - IBM_CLOUD_API_KEY=${IBM_CLOUD_API_KEY}
      - WATSONX_PROJECT_ID=${WATSONX_PROJECT_ID}
      - WATSONX_CLUSTER=${WATSONX_CLUSTER:-us-south}
      - HOST=0.0.0.0
      - PORT=8000
      - LOG_LEVEL=${LOG_LEVEL:-info}
      - API_KEY=${API_KEY:-}
      - ALLOWED_ORIGINS=${ALLOWED_ORIGINS:-*}
      - MODEL_MAP_GPT4=${MODEL_MAP_GPT4:-ibm/granite-4-h-small}
      - MODEL_MAP_GPT35=${MODEL_MAP_GPT35:-ibm/granite-3-8b-instruct}
      - MODEL_MAP_GPT4_TURBO=${MODEL_MAP_GPT4_TURBO:-meta-llama/llama-3-3-70b-instruct}
      - MODEL_MAP_TEXT_EMBEDDING_ADA_002=${MODEL_MAP_TEXT_EMBEDDING_ADA_002:-ibm/slate-125m-english-rtrvr}
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8000/health')"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 5s
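The compose file relies on `${VAR:-default}` interpolation throughout, so every setting can be overridden from the host environment or a `.env` file while sane defaults apply otherwise. A minimal sketch of that shell expansion rule (values here are illustrative):

```shell
# ${VAR:-default} expands to $VAR when it is set and non-empty,
# otherwise to the literal default — exactly how Compose fills
# LOG_LEVEL, WATSONX_CLUSTER, and the MODEL_MAP_* entries above.
unset LOG_LEVEL
echo "LOG_LEVEL=${LOG_LEVEL:-info}"   # unset -> default applies

LOG_LEVEL=debug
echo "LOG_LEVEL=${LOG_LEVEL:-info}"   # set -> the value wins
```

This is why `API_KEY=${API_KEY:-}` defaults to empty (auth disabled) rather than failing when the variable is absent.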
183
example_usage.py
Normal file
@@ -0,0 +1,183 @@
"""Example usage of watsonx-openai-proxy with OpenAI Python SDK."""

import os

from openai import OpenAI

# Configure the client to use the proxy
client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key=os.getenv("API_KEY", "not-needed-if-proxy-has-no-auth"),
)


def example_chat_completion():
    """Example: Basic chat completion."""
    print("\n=== Chat Completion Example ===")

    response = client.chat.completions.create(
        model="ibm/granite-3-8b-instruct",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the capital of France?"},
        ],
        temperature=0.7,
        max_tokens=100,
    )

    print(f"Response: {response.choices[0].message.content}")
    print(f"Tokens used: {response.usage.total_tokens}")


def example_streaming_chat():
    """Example: Streaming chat completion."""
    print("\n=== Streaming Chat Example ===")

    stream = client.chat.completions.create(
        model="ibm/granite-3-8b-instruct",
        messages=[
            {"role": "user", "content": "Tell me a short story about a robot."},
        ],
        stream=True,
        max_tokens=200,
    )

    print("Response: ", end="", flush=True)
    for chunk in stream:
        if chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)
    print()


def example_with_model_mapping():
    """Example: Using mapped model names."""
    print("\n=== Model Mapping Example ===")

    # If you configured MODEL_MAP_GPT4=ibm/granite-4-h-small in .env
    # you can use "gpt-4" and it will be mapped automatically
    response = client.chat.completions.create(
        model="gpt-4",  # This gets mapped to ibm/granite-4-h-small
        messages=[
            {"role": "user", "content": "Explain quantum computing in one sentence."},
        ],
        max_tokens=50,
    )

    print(f"Response: {response.choices[0].message.content}")


def example_embeddings():
    """Example: Generate embeddings."""
    print("\n=== Embeddings Example ===")

    response = client.embeddings.create(
        model="ibm/slate-125m-english-rtrvr",
        input=[
            "The quick brown fox jumps over the lazy dog.",
            "Machine learning is a subset of artificial intelligence.",
        ],
    )

    print(f"Generated {len(response.data)} embeddings")
    print(f"Embedding dimension: {len(response.data[0].embedding)}")
    print(f"First embedding (first 5 values): {response.data[0].embedding[:5]}")


def example_list_models():
    """Example: List available models."""
    print("\n=== List Models Example ===")

    models = client.models.list()

    print(f"Available models: {len(models.data)}")
    print("\nFirst 5 models:")
    for model in models.data[:5]:
        print(f"  - {model.id}")


def example_completion_legacy():
    """Example: Legacy completion endpoint."""
    print("\n=== Legacy Completion Example ===")

    response = client.completions.create(
        model="ibm/granite-3-8b-instruct",
        prompt="Once upon a time, in a land far away,",
        max_tokens=50,
        temperature=0.8,
    )

    print(f"Completion: {response.choices[0].text}")


def example_with_functions():
    """Example: Function calling (if supported by model)."""
    print("\n=== Function Calling Example ===")

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                        },
                    },
                    "required": ["location"],
                },
            },
        }
    ]

    try:
        response = client.chat.completions.create(
            model="ibm/granite-3-8b-instruct",
            messages=[
                {"role": "user", "content": "What's the weather like in Boston?"},
            ],
            tools=tools,
            tool_choice="auto",
        )

        message = response.choices[0].message
        if message.tool_calls:
            print(f"Function called: {message.tool_calls[0].function.name}")
            print(f"Arguments: {message.tool_calls[0].function.arguments}")
        else:
            print(f"Response: {message.content}")
    except Exception as e:
        print(f"Function calling may not be supported by this model: {e}")


if __name__ == "__main__":
    print("watsonx-openai-proxy Usage Examples")
    print("=" * 50)

    try:
        # Run examples
        example_chat_completion()
        example_streaming_chat()
        example_embeddings()
        example_list_models()
        example_completion_legacy()

        # Optional examples (may require specific configuration)
        # example_with_model_mapping()
        # example_with_functions()

        print("\n" + "=" * 50)
        print("All examples completed successfully!")

    except Exception as e:
        print(f"\nError: {e}")
        print("\nMake sure:")
        print("1. The proxy server is running (python -m app.main)")
        print("2. Your .env file is configured correctly")
        print("3. You have the OpenAI Python SDK installed (pip install openai)")
7
requirements.txt
Normal file
@@ -0,0 +1,7 @@
fastapi==0.115.0
uvicorn[standard]==0.32.0
pydantic==2.9.2
pydantic-settings==2.6.0
httpx==0.27.2
python-dotenv==1.0.1
python-multipart==0.0.12
87
tests/test_basic.py
Normal file
@@ -0,0 +1,87 @@
"""Basic tests for watsonx-openai-proxy."""

import pytest
from fastapi.testclient import TestClient

from app.main import app

client = TestClient(app)


def test_health_check():
    """Test the health check endpoint."""
    response = client.get("/health")
    assert response.status_code == 200
    data = response.json()
    assert data["status"] == "healthy"
    assert "cluster" in data


def test_root_endpoint():
    """Test the root endpoint."""
    response = client.get("/")
    assert response.status_code == 200
    data = response.json()
    assert data["service"] == "watsonx-openai-proxy"
    assert "endpoints" in data


def test_list_models():
    """Test listing available models."""
    response = client.get("/v1/models")
    assert response.status_code == 200
    data = response.json()
    assert data["object"] == "list"
    assert len(data["data"]) > 0
    assert all(model["object"] == "model" for model in data["data"])


def test_retrieve_model():
    """Test retrieving a specific model."""
    response = client.get("/v1/models/ibm/granite-3-8b-instruct")
    assert response.status_code == 200
    data = response.json()
    assert data["id"] == "ibm/granite-3-8b-instruct"
    assert data["object"] == "model"


def test_retrieve_nonexistent_model():
    """Test retrieving a model that doesn't exist."""
    response = client.get("/v1/models/nonexistent-model")
    assert response.status_code == 404


# Note: The following tests require valid IBM Cloud credentials
# and should be run with pytest markers or in integration tests

@pytest.mark.skip(reason="Requires valid IBM Cloud credentials")
def test_chat_completion():
    """Test chat completion endpoint."""
    response = client.post(
        "/v1/chat/completions",
        json={
            "model": "ibm/granite-3-8b-instruct",
            "messages": [
                {"role": "user", "content": "Hello!"}
            ],
        },
    )
    assert response.status_code == 200
    data = response.json()
    assert "choices" in data
    assert len(data["choices"]) > 0


@pytest.mark.skip(reason="Requires valid IBM Cloud credentials")
def test_embeddings():
    """Test embeddings endpoint."""
    response = client.post(
        "/v1/embeddings",
        json={
            "model": "ibm/slate-125m-english-rtrvr",
            "input": "Test text",
        },
    )
    assert response.status_code == 200
    data = response.json()
    assert "data" in data
    assert len(data["data"]) > 0