Add AGENTS.md documentation for AI agent guidance

2026-02-23 09:59:52 -05:00
commit 2e2b817435
21 changed files with 2513 additions and 0 deletions

AGENTS.md

@@ -0,0 +1,223 @@
# AGENTS.md
This file provides guidance to agents when working with code in this repository.
## Project Overview
**watsonx-openai-proxy** is an OpenAI-compatible API proxy for IBM watsonx.ai. It enables any tool or application that supports the OpenAI API format to seamlessly work with watsonx.ai models.
### Core Purpose
- Provide drop-in replacement for OpenAI API endpoints
- Translate OpenAI API requests to watsonx.ai API calls
- Handle IBM Cloud authentication and token management automatically
- Support streaming responses via Server-Sent Events (SSE)
### Technology Stack
- **Framework**: FastAPI (async web framework)
- **Language**: Python 3.9+
- **HTTP Client**: httpx (async HTTP client)
- **Validation**: Pydantic v2 (data validation and settings)
- **Server**: uvicorn (ASGI server)
### Architecture
The codebase follows a clean, modular architecture:
```
app/
├── main.py # FastAPI app initialization, middleware, lifespan management
├── config.py # Settings management, model mapping, environment variables
├── routers/ # API endpoint handlers (chat, completions, embeddings, models)
├── services/ # Business logic (watsonx_service for API interactions)
├── models/ # Pydantic models for OpenAI-compatible schemas
└── utils/ # Helper functions (request/response transformers)
```
**Key Design Patterns**:
- **Service Layer**: `watsonx_service.py` encapsulates all watsonx.ai API interactions
- **Transformer Pattern**: `transformers.py` handles bidirectional conversion between OpenAI and watsonx formats
- **Singleton Services**: Global service instances (`watsonx_service`, `settings`) for shared state
- **Async/Await**: All I/O operations are asynchronous for better performance
- **Middleware**: Custom authentication middleware for optional API key validation
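The transformer pattern above can be sketched as follows. This is an illustrative simplification, not the repository's actual `transformers.py`; the field names on the watsonx side are assumptions for the example:

```python
import time
import uuid

def watsonx_chat_to_openai(wx: dict, model: str) -> dict:
    """Sketch of the watsonx -> OpenAI direction of the transformer.
    The watsonx response field names here are assumed for illustration."""
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:24]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": i,
                "message": {
                    "role": "assistant",
                    "content": c.get("message", {}).get("content", ""),
                },
                "finish_reason": c.get("finish_reason"),
            }
            for i, c in enumerate(wx.get("choices", []))
        ],
        "usage": wx.get("usage", {}),
    }
```

The mirror-image function converts OpenAI requests into watsonx payloads, so routers never touch raw watsonx JSON directly.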
## Building and Running
### Prerequisites
```bash
# Python 3.9 or higher required
python --version
# IBM Cloud credentials needed:
# - IBM_CLOUD_API_KEY
# - WATSONX_PROJECT_ID
```
### Installation
```bash
# Install dependencies
pip install -r requirements.txt
# Configure environment
cp .env.example .env
# Edit .env with your IBM Cloud credentials
```
### Running the Server
```bash
# Development (with auto-reload)
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
# Production (with workers)
uvicorn app.main:app --host 0.0.0.0 --port 8000 --workers 4
# Using Python module
python -m app.main
```
### Docker Deployment
```bash
# Build image
docker build -t watsonx-openai-proxy .
# Run container
docker run -p 8000:8000 --env-file .env watsonx-openai-proxy
# Using docker-compose
docker-compose up
```
### Testing
```bash
# Install test dependencies
pip install pytest pytest-asyncio httpx
# Run tests
pytest tests/
# Run with coverage
pytest tests/ --cov=app
```
## Development Conventions
### Code Style
- **Async First**: Use `async`/`await` for all I/O operations (HTTP requests, file operations)
- **Type Hints**: All functions should have type annotations for parameters and return values
- **Docstrings**: Use Google-style docstrings for functions and classes
- **Logging**: Use the `logging` module with appropriate log levels (info, warning, error)
### Error Handling
- Catch exceptions at router level and return OpenAI-compatible error responses
- Use `HTTPException` with proper status codes and error details
- Log errors with full context using `logger.error(..., exc_info=True)`
- Return structured error responses matching OpenAI's error format
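A minimal sketch of that structured error body (the helper name is hypothetical; the actual routers build this dict inline before raising `HTTPException`):

```python
from typing import Optional

def openai_error_body(message: str, err_type: str = "internal_error",
                      code: Optional[str] = None,
                      param: Optional[str] = None) -> dict:
    """Build an OpenAI-style error payload. Helper name is illustrative."""
    return {
        "error": {
            "message": message,
            "type": err_type,
            "param": param,
            "code": code,
        }
    }
```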
### Configuration Management
- All configuration via environment variables (`.env` file)
- Use `pydantic-settings` for type-safe configuration
- Model mapping via `MODEL_MAP_*` environment variables
- Settings accessed through global `settings` instance
### Token Management
- Bearer tokens automatically refreshed every 50 minutes (they expire after 60 minutes)
- Token refresh on 401 errors from watsonx.ai
- Thread-safe token refresh using `asyncio.Lock`
- Initial token obtained during application startup
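The lock-guarded refresh pattern can be sketched like this; the class and method names are illustrative, not the proxy's actual internals, and the fetch is stubbed out:

```python
import asyncio
import time

class TokenManager:
    """Sketch of asyncio.Lock-guarded token refresh (names illustrative)."""

    REFRESH_AFTER = 50 * 60  # refresh at 50 minutes; IAM tokens last 60

    def __init__(self) -> None:
        self._token = None
        self._fetched_at = 0.0
        self._lock = asyncio.Lock()

    async def _fetch_token(self) -> str:
        # The real service POSTs the IBM Cloud API key to the IAM endpoint.
        return "stub-token"

    async def get_token(self) -> str:
        async with self._lock:  # serialize refreshes across coroutines
            if self._token is None or time.monotonic() - self._fetched_at > self.REFRESH_AFTER:
                self._token = await self._fetch_token()
                self._fetched_at = time.monotonic()
            return self._token
```

Because every caller goes through `get_token()`, concurrent requests never trigger duplicate refreshes.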
### API Compatibility
- Maintain strict OpenAI API compatibility in request/response formats
- Use Pydantic models from `openai_models.py` for validation
- Transform requests/responses using functions in `transformers.py`
- Support both streaming and non-streaming responses
### Adding New Endpoints
1. Create router in `app/routers/` (e.g., `new_endpoint.py`)
2. Define Pydantic models in `app/models/openai_models.py`
3. Add transformation logic in `app/utils/transformers.py`
4. Add watsonx.ai API method in `app/services/watsonx_service.py`
5. Register router in `app/main.py` using `app.include_router()`
### Streaming Responses
- Use `StreamingResponse` with `media_type="text/event-stream"`
- Format chunks as Server-Sent Events using `format_sse_event()`
- Always send `[DONE]` message at the end of stream
- Handle errors gracefully and send error events in SSE format
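The SSE framing itself is small; a sketch follows (the real `format_sse_event()` lives in `app/utils/transformers.py`, and this signature is an assumption):

```python
def format_sse_event(data: str) -> str:
    """Frame a payload as a Server-Sent Event: 'data: <payload>\n\n'."""
    return f"data: {data}\n\n"

def sse_done() -> str:
    # OpenAI-compatible streams terminate with a literal [DONE] sentinel.
    return format_sse_event("[DONE]")
```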
### Model Mapping
- Map OpenAI model names to watsonx models via environment variables
- Format: `MODEL_MAP_<OPENAI_MODEL>=<WATSONX_MODEL_ID>`
- Example: `MODEL_MAP_GPT4=ibm/granite-4-h-small`
- Mapping applied in `settings.map_model()` before API calls
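The lookup itself is a plain dict fallback, sketched here (the real logic is `settings.map_model()` in `app/config.py`):

```python
def map_model(openai_model: str, model_map: dict) -> str:
    """Return the mapped watsonx model ID, or the name unchanged
    when no mapping exists (so native watsonx IDs pass through)."""
    return model_map.get(openai_model, openai_model)
```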
### Security Considerations
- Optional API key authentication via `API_KEY` environment variable
- Middleware validates Bearer token in Authorization header
- IBM Cloud API key stored securely in environment variables
- CORS configured via `ALLOWED_ORIGINS` (default: `*`)
### Logging Best Practices
- Use structured logging with context (model names, request IDs)
- Log level controlled by `LOG_LEVEL` environment variable
- Log token refresh events at INFO level
- Log API errors at ERROR level with full traceback
- Include request/response details for debugging
### Dependencies
- Keep `requirements.txt` minimal and pinned to specific versions
- FastAPI and Pydantic are core dependencies - avoid breaking changes
- httpx for async HTTP - prefer over requests/aiohttp
- Use `uvicorn[standard]` for production-ready server
## Important Implementation Notes
### watsonx.ai API Specifics
- Base URL format: `https://{cluster}.ml.cloud.ibm.com/ml/v1`
- API version parameter: `version=2024-02-13` (required on all requests)
- Chat endpoint: `/text/chat` (non-streaming) or `/text/chat_stream` (streaming)
- Text generation: `/text/generation`
- Embeddings: `/text/embeddings`
### Request/Response Transformation
- OpenAI messages → watsonx messages: Direct mapping with role/content
- watsonx responses → OpenAI format: Extract choices, usage, and metadata
- Streaming chunks: Parse SSE format, transform delta objects
- Generate unique IDs: `chatcmpl-{uuid}` for chat, `cmpl-{uuid}` for completions
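ID generation follows the OpenAI prefix convention; the 24-hex-character suffix below matches what `app/routers/chat.py` uses for chat IDs and is assumed for completions (helper name illustrative):

```python
import uuid

def make_response_id(kind: str) -> str:
    """'chat' -> 'chatcmpl-<24 hex>', anything else -> 'cmpl-<24 hex>'."""
    prefix = "chatcmpl" if kind == "chat" else "cmpl"
    return f"{prefix}-{uuid.uuid4().hex[:24]}"
```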
### Common Pitfalls
- Don't forget to refresh tokens before they expire (50-minute interval)
- Always close httpx client on shutdown (`await watsonx_service.close()`)
- Handle both string and list formats for `stop` parameter
- Validate model IDs exist in watsonx.ai before making requests
- Set appropriate timeouts for long-running generation requests (300s default)
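For instance, the `stop` pitfall reduces to a small normalization, sketched here (the proxy currently does this inline in the routers; the helper name is hypothetical):

```python
from typing import List, Optional, Union

def normalize_stop(stop: Optional[Union[str, List[str]]]) -> Optional[List[str]]:
    """Accept a bare string or a list; always hand watsonx a list (or None)."""
    if stop is None:
        return None
    return [stop] if isinstance(stop, str) else list(stop)
```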
### Performance Optimization
- Reuse httpx client instance (don't create per request)
- Use connection pooling (httpx default behavior)
- Consider worker processes for production (`--workers 4`)
- Monitor token refresh to avoid rate limiting
## Environment Variables Reference
### Required
- `IBM_CLOUD_API_KEY`: IBM Cloud API key for authentication
- `WATSONX_PROJECT_ID`: watsonx.ai project ID
### Optional
- `WATSONX_CLUSTER`: Region (default: `us-south`)
- `HOST`: Server host (default: `0.0.0.0`)
- `PORT`: Server port (default: `8000`)
- `LOG_LEVEL`: Logging level (default: `info`)
- `API_KEY`: Optional proxy authentication key
- `ALLOWED_ORIGINS`: CORS origins (default: `*`)
- `MODEL_MAP_*`: Model name mappings
## API Endpoints
- `GET /` - API information and available endpoints
- `GET /health` - Health check (bypasses authentication)
- `GET /docs` - Interactive Swagger UI documentation
- `POST /v1/chat/completions` - Chat completions (streaming supported)
- `POST /v1/completions` - Text completions (legacy)
- `POST /v1/embeddings` - Generate embeddings
- `GET /v1/models` - List available models
- `GET /v1/models/{model_id}` - Get specific model info

Dockerfile

@@ -0,0 +1,20 @@
FROM python:3.11-slim
WORKDIR /app
# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY app ./app
# Expose port
EXPOSE 8000
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD python -c "import httpx; httpx.get('http://localhost:8000/health').raise_for_status()"
# Run the application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

README.md

@@ -0,0 +1,353 @@
# watsonx-openai-proxy
OpenAI-compatible API proxy for IBM watsonx.ai. This proxy allows you to use watsonx.ai models with any tool or application that supports the OpenAI API format.
## Features
- **Full OpenAI API Compatibility**: Drop-in replacement for OpenAI API
- **Chat Completions**: `/v1/chat/completions` with streaming support
- **Text Completions**: `/v1/completions` (legacy endpoint)
- **Embeddings**: `/v1/embeddings` for text embeddings
- **Model Listing**: `/v1/models` endpoint
- **Streaming Support**: Server-Sent Events (SSE) for real-time responses
- **Model Mapping**: Map OpenAI model names to watsonx models
- **Automatic Token Management**: Handles IBM Cloud authentication automatically
- **CORS Support**: Configurable cross-origin resource sharing
- **Optional API Key Authentication**: Secure your proxy with an API key
## Quick Start
### Prerequisites
- Python 3.9 or higher
- IBM Cloud account with watsonx.ai access
- IBM Cloud API key
- watsonx.ai Project ID
### Installation
1. Clone or download this directory:
```bash
cd watsonx-openai-proxy
```
2. Install dependencies:
```bash
pip install -r requirements.txt
```
3. Configure environment variables:
```bash
cp .env.example .env
# Edit .env with your credentials
```
4. Run the server:
```bash
python -m app.main
```
Or with uvicorn:
```bash
uvicorn app.main:app --host 0.0.0.0 --port 8000
```
The server will start at `http://localhost:8000`
## Configuration
### Environment Variables
Create a `.env` file with the following variables:
```bash
# Required: IBM Cloud Configuration
IBM_CLOUD_API_KEY=your_ibm_cloud_api_key_here
WATSONX_PROJECT_ID=your_watsonx_project_id_here
WATSONX_CLUSTER=us-south # Options: us-south, eu-de, eu-gb, jp-tok, au-syd, ca-tor
# Optional: Server Configuration
HOST=0.0.0.0
PORT=8000
LOG_LEVEL=info
# Optional: API Key for Proxy Authentication
API_KEY=your_optional_api_key_for_proxy_authentication
# Optional: CORS Configuration
ALLOWED_ORIGINS=* # Comma-separated or * for all
# Optional: Model Mapping
MODEL_MAP_GPT4=ibm/granite-4-h-small
MODEL_MAP_GPT35=ibm/granite-3-8b-instruct
MODEL_MAP_GPT4_TURBO=meta-llama/llama-3-3-70b-instruct
MODEL_MAP_TEXT_EMBEDDING_ADA_002=ibm/slate-125m-english-rtrvr
```
### Model Mapping
You can map OpenAI model names to watsonx models using environment variables:
```bash
MODEL_MAP_<OPENAI_MODEL_NAME>=<WATSONX_MODEL_ID>
```
For example:
- `MODEL_MAP_GPT4=ibm/granite-4-h-small` maps `gpt-4` to `ibm/granite-4-h-small`
- `MODEL_MAP_GPT35_TURBO=ibm/granite-3-8b-instruct` maps `gpt-3.5-turbo` to `ibm/granite-3-8b-instruct`
## Usage
### With OpenAI Python SDK
```python
from openai import OpenAI
# Point to your proxy
client = OpenAI(
base_url="http://localhost:8000/v1",
api_key="your-proxy-api-key" # Optional, if you set API_KEY in .env
)
# Use as normal
response = client.chat.completions.create(
model="ibm/granite-3-8b-instruct", # Or use mapped name like "gpt-4"
messages=[
{"role": "user", "content": "Hello, how are you?"}
]
)
print(response.choices[0].message.content)
```
### With Streaming
```python
stream = client.chat.completions.create(
model="ibm/granite-3-8b-instruct",
messages=[{"role": "user", "content": "Tell me a story"}],
stream=True
)
for chunk in stream:
if chunk.choices[0].delta.content:
print(chunk.choices[0].delta.content, end="")
```
### With cURL
```bash
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer your-proxy-api-key" \
-d '{
"model": "ibm/granite-3-8b-instruct",
"messages": [
{"role": "user", "content": "Hello!"}
]
}'
```
### Embeddings
```python
response = client.embeddings.create(
model="ibm/slate-125m-english-rtrvr",
input="Your text to embed"
)
print(response.data[0].embedding)
```
## Available Endpoints
- `GET /` - API information
- `GET /health` - Health check
- `GET /docs` - Interactive API documentation (Swagger UI)
- `POST /v1/chat/completions` - Chat completions
- `POST /v1/completions` - Text completions (legacy)
- `POST /v1/embeddings` - Generate embeddings
- `GET /v1/models` - List available models
- `GET /v1/models/{model_id}` - Get model information
## Supported Models
The proxy supports all watsonx.ai models available in your project, including:
### Chat Models
- IBM Granite models (3.x, 4.x series)
- Meta Llama models (3.x, 4.x series)
- Mistral models
- Other models available on watsonx.ai
### Embedding Models
- `ibm/slate-125m-english-rtrvr`
- `ibm/slate-30m-english-rtrvr`
See `/v1/models` endpoint for the complete list.
## Authentication
### Proxy Authentication (Optional)
If you set `API_KEY` in your `.env` file, clients must provide it:
```python
client = OpenAI(
base_url="http://localhost:8000/v1",
api_key="your-proxy-api-key"
)
```
### IBM Cloud Authentication
The proxy handles IBM Cloud authentication automatically using your `IBM_CLOUD_API_KEY`. Bearer tokens are:
- Automatically obtained on startup
- Refreshed every 50 minutes (tokens expire after 60 minutes)
- Refreshed on 401 errors
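Under the hood the exchange is a single form-encoded POST to IBM Cloud IAM. A sketch of the request shape follows; the helper name is illustrative, while the URL and grant type follow IBM Cloud IAM's documented API-key flow:

```python
IAM_TOKEN_URL = "https://iam.cloud.ibm.com/identity/token"

def iam_token_request(api_key: str) -> dict:
    """Assemble the pieces of the IAM token-exchange request."""
    return {
        "url": IAM_TOKEN_URL,
        "headers": {"Content-Type": "application/x-www-form-urlencoded"},
        "data": {
            "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
            "apikey": api_key,
        },
    }
```

The JSON response carries `access_token` and `expires_in`, which is what the 50-minute refresh schedule is derived from.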
## Deployment
### Docker (Recommended)
Create a `Dockerfile`:
```dockerfile
FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY app ./app
EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
```
Build and run:
```bash
docker build -t watsonx-openai-proxy .
docker run -p 8000:8000 --env-file .env watsonx-openai-proxy
```
### Production Deployment
For production, consider:
1. **Use a production ASGI server**: The included uvicorn is suitable, but configure workers:
```bash
uvicorn app.main:app --host 0.0.0.0 --port 8000 --workers 4
```
2. **Set up HTTPS**: Use a reverse proxy like nginx or Caddy
3. **Configure CORS**: Set `ALLOWED_ORIGINS` to specific domains
4. **Enable API key authentication**: Set `API_KEY` in environment
5. **Monitor logs**: Set `LOG_LEVEL=info` or `warning` in production
6. **Use environment secrets**: Don't commit the `.env` file; use a secrets manager
## Troubleshooting
### 401 Unauthorized
- Check that `IBM_CLOUD_API_KEY` is valid
- Verify your IBM Cloud account has watsonx.ai access
- Check server logs for token refresh errors
### Model Not Found
- Verify the model ID exists in watsonx.ai
- Check that your project has access to the model
- Use `/v1/models` endpoint to see available models
### Connection Errors
- Verify `WATSONX_CLUSTER` matches your project's region
- Check firewall/network settings
- Ensure watsonx.ai services are accessible
### Streaming Issues
- Some models may not support streaming
- Check client library supports SSE (Server-Sent Events)
- Verify network doesn't buffer streaming responses
## Development
### Running Tests
```bash
# Install dev dependencies
pip install pytest pytest-asyncio httpx
# Run tests
pytest tests/
```
### Code Structure
```
watsonx-openai-proxy/
├── app/
│ ├── main.py # FastAPI application
│ ├── config.py # Configuration management
│ ├── routers/ # API endpoint routers
│ │ ├── chat.py # Chat completions
│ │ ├── completions.py # Text completions
│ │ ├── embeddings.py # Embeddings
│ │ └── models.py # Model listing
│ ├── services/ # Business logic
│ │ └── watsonx_service.py # watsonx.ai API client
│ ├── models/ # Pydantic models
│ │ └── openai_models.py # OpenAI-compatible schemas
│ └── utils/ # Utilities
│ └── transformers.py # Request/response transformers
├── tests/ # Test files
├── requirements.txt # Python dependencies
├── .env.example # Environment template
└── README.md # This file
```
## Contributing
Contributions are welcome! Please:
1. Fork the repository
2. Create a feature branch
3. Make your changes
4. Add tests if applicable
5. Submit a pull request
## License
Apache 2.0 License - See LICENSE file for details.
## Related Projects
- [watsonx-unofficial-aisdk-provider](../wxai-provider/) - Vercel AI SDK provider for watsonx.ai
- [OpenCode watsonx plugin](../.opencode/plugins/) - Token management plugin for OpenCode
## Disclaimer
This is **not an official IBM product**. It's a community-maintained proxy for integrating watsonx.ai with OpenAI-compatible tools. watsonx.ai is a trademark of IBM.
## Support
For issues and questions:
- Check the [Troubleshooting](#troubleshooting) section
- Review server logs (`LOG_LEVEL=debug` for detailed logs)
- Open an issue in the repository
- Consult [IBM watsonx.ai documentation](https://www.ibm.com/docs/en/watsonx-as-a-service)

app/__init__.py

@@ -0,0 +1,3 @@
"""watsonx-openai-proxy application package."""
__version__ = "1.0.0"

app/config.py

@@ -0,0 +1,91 @@
"""Configuration management for watsonx-openai-proxy."""
import os
from typing import Dict, Optional
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
# IBM Cloud Configuration
ibm_cloud_api_key: str
watsonx_project_id: str
watsonx_cluster: str = "us-south"
# Server Configuration
host: str = "0.0.0.0"
port: int = 8000
log_level: str = "info"
# API Configuration
api_key: Optional[str] = None
allowed_origins: str = "*"
# Token Management
token_refresh_interval: int = 3000 # 50 minutes in seconds
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="allow", # Allow extra fields for model mapping
)
@property
def watsonx_base_url(self) -> str:
"""Construct the watsonx.ai base URL from cluster."""
return f"https://{self.watsonx_cluster}.ml.cloud.ibm.com/ml/v1"
@property
def cors_origins(self) -> list[str]:
"""Parse CORS origins from comma-separated string."""
if self.allowed_origins == "*":
return ["*"]
return [origin.strip() for origin in self.allowed_origins.split(",")]
def get_model_mapping(self) -> Dict[str, str]:
"""Extract model mappings from environment variables.
Looks for variables like MODEL_MAP_GPT4=ibm/granite-4-h-small
and creates a mapping dict.
"""
import re
mapping = {}
# Check os.environ first
for key, value in os.environ.items():
if key.startswith("MODEL_MAP_"):
model_name = key.replace("MODEL_MAP_", "")
model_name = re.sub(r'([A-Z]+)(\d+)', r'\1-\2', model_name)
openai_model = model_name.lower().replace("_", "-")
mapping[openai_model] = value
# Also check pydantic's extra fields (from .env file)
# These come in as lowercase: model_map_gpt4 instead of MODEL_MAP_GPT4
extra = getattr(self, '__pydantic_extra__', {}) or {}
for key, value in extra.items():
if key.startswith("model_map_"):
# Convert back to uppercase for processing
model_name = key.replace("model_map_", "").upper()
model_name = re.sub(r'([A-Z]+)(\d+)', r'\1-\2', model_name)
openai_model = model_name.lower().replace("_", "-")
mapping[openai_model] = value
return mapping
def map_model(self, openai_model: str) -> str:
"""Map an OpenAI model name to a watsonx model ID.
Args:
openai_model: OpenAI model name (e.g., "gpt-4", "gpt-3.5-turbo")
Returns:
Corresponding watsonx model ID, or the original name if no mapping exists
"""
model_map = self.get_model_mapping()
return model_map.get(openai_model, openai_model)
# Global settings instance
settings = Settings()

app/main.py

@@ -0,0 +1,157 @@
"""Main FastAPI application for watsonx-openai-proxy."""
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.config import settings
from app.routers import chat, completions, embeddings, models
from app.services.watsonx_service import watsonx_service
# Configure logging
logging.basicConfig(
level=getattr(logging, settings.log_level.upper()),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage application lifespan events."""
# Startup
logger.info("Starting watsonx-openai-proxy...")
logger.info(f"Cluster: {settings.watsonx_cluster}")
logger.info(f"Project ID: {settings.watsonx_project_id[:8]}...")
# Initialize token
try:
await watsonx_service._refresh_token()
logger.info("Initial bearer token obtained successfully")
except Exception as e:
logger.error(f"Failed to obtain initial bearer token: {e}")
raise
yield
# Shutdown
logger.info("Shutting down watsonx-openai-proxy...")
await watsonx_service.close()
# Create FastAPI app
app = FastAPI(
title="watsonx-openai-proxy",
description="OpenAI-compatible API proxy for IBM watsonx.ai",
version="1.0.0",
lifespan=lifespan,
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Optional API key authentication middleware
@app.middleware("http")
async def authenticate(request: Request, call_next):
"""Authenticate requests if API key is configured."""
if settings.api_key:
# Skip authentication for health check
if request.url.path == "/health":
return await call_next(request)
# Check Authorization header
auth_header = request.headers.get("Authorization")
if not auth_header:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"error": {
"message": "Missing Authorization header",
"type": "authentication_error",
"code": "missing_authorization",
}
},
)
# Validate API key
if not auth_header.startswith("Bearer "):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"error": {
"message": "Invalid Authorization header format",
"type": "authentication_error",
"code": "invalid_authorization_format",
}
},
)
token = auth_header[7:] # Remove "Bearer " prefix
if token != settings.api_key:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"error": {
"message": "Invalid API key",
"type": "authentication_error",
"code": "invalid_api_key",
}
},
)
return await call_next(request)
# Include routers
app.include_router(chat.router, tags=["Chat"])
app.include_router(completions.router, tags=["Completions"])
app.include_router(embeddings.router, tags=["Embeddings"])
app.include_router(models.router, tags=["Models"])
@app.get("/health")
async def health_check():
"""Health check endpoint."""
return {
"status": "healthy",
"service": "watsonx-openai-proxy",
"cluster": settings.watsonx_cluster,
}
@app.get("/")
async def root():
"""Root endpoint with API information."""
return {
"service": "watsonx-openai-proxy",
"description": "OpenAI-compatible API proxy for IBM watsonx.ai",
"version": "1.0.0",
"endpoints": {
"chat": "/v1/chat/completions",
"completions": "/v1/completions",
"embeddings": "/v1/embeddings",
"models": "/v1/models",
"health": "/health",
},
"documentation": "/docs",
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"app.main:app",
host=settings.host,
port=settings.port,
log_level=settings.log_level,
reload=False,
)

app/models/__init__.py

@@ -0,0 +1,31 @@
"""OpenAI-compatible data models."""
from app.models.openai_models import (
ChatMessage,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionChunk,
CompletionRequest,
CompletionResponse,
EmbeddingRequest,
EmbeddingResponse,
ModelsResponse,
ModelInfo,
ErrorResponse,
ErrorDetail,
)
__all__ = [
"ChatMessage",
"ChatCompletionRequest",
"ChatCompletionResponse",
"ChatCompletionChunk",
"CompletionRequest",
"CompletionResponse",
"EmbeddingRequest",
"EmbeddingResponse",
"ModelsResponse",
"ModelInfo",
"ErrorResponse",
"ErrorDetail",
]

app/models/openai_models.py

@@ -0,0 +1,213 @@
"""OpenAI-compatible request and response models."""
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
# ============================================================================
# Chat Completions Models
# ============================================================================
class ChatMessage(BaseModel):
"""A chat message in the conversation."""
role: Literal["system", "user", "assistant", "function", "tool"]
content: Optional[str] = None
name: Optional[str] = None
function_call: Optional[Dict[str, Any]] = None
tool_calls: Optional[List[Dict[str, Any]]] = None
class FunctionCall(BaseModel):
"""Function call specification."""
name: str
arguments: str
class ToolCall(BaseModel):
"""Tool call specification."""
id: str
type: Literal["function"]
function: FunctionCall
class ChatCompletionRequest(BaseModel):
"""OpenAI chat completion request."""
model: str
messages: List[ChatMessage]
temperature: Optional[float] = Field(default=1.0, ge=0, le=2)
top_p: Optional[float] = Field(default=1.0, ge=0, le=1)
n: Optional[int] = Field(default=1, ge=1)
stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = Field(default=None, ge=1)
presence_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
frequency_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
functions: Optional[List[Dict[str, Any]]] = None
function_call: Optional[Union[str, Dict[str, str]]] = None
tools: Optional[List[Dict[str, Any]]] = None
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
class ChatCompletionChoice(BaseModel):
"""A single chat completion choice."""
index: int
message: ChatMessage
finish_reason: Optional[str] = None
logprobs: Optional[Dict[str, Any]] = None
class ChatCompletionUsage(BaseModel):
"""Token usage information."""
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ChatCompletionResponse(BaseModel):
"""OpenAI chat completion response."""
id: str
object: Literal["chat.completion"] = "chat.completion"
created: int
model: str
choices: List[ChatCompletionChoice]
usage: ChatCompletionUsage
system_fingerprint: Optional[str] = None
class ChatCompletionChunkDelta(BaseModel):
"""Delta content in streaming response."""
role: Optional[str] = None
content: Optional[str] = None
function_call: Optional[Dict[str, Any]] = None
tool_calls: Optional[List[Dict[str, Any]]] = None
class ChatCompletionChunkChoice(BaseModel):
"""A single streaming chunk choice."""
index: int
delta: ChatCompletionChunkDelta
finish_reason: Optional[str] = None
logprobs: Optional[Dict[str, Any]] = None
class ChatCompletionChunk(BaseModel):
"""OpenAI streaming chat completion chunk."""
id: str
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int
model: str
choices: List[ChatCompletionChunkChoice]
system_fingerprint: Optional[str] = None
# ============================================================================
# Completions Models (Legacy)
# ============================================================================
class CompletionRequest(BaseModel):
"""OpenAI completion request (legacy)."""
model: str
prompt: Union[str, List[str], List[int], List[List[int]]]
suffix: Optional[str] = None
max_tokens: Optional[int] = Field(default=16, ge=1)
temperature: Optional[float] = Field(default=1.0, ge=0, le=2)
top_p: Optional[float] = Field(default=1.0, ge=0, le=1)
n: Optional[int] = Field(default=1, ge=1)
stream: Optional[bool] = False
logprobs: Optional[int] = Field(default=None, ge=0, le=5)
echo: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = None
presence_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
frequency_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
best_of: Optional[int] = Field(default=1, ge=1)
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
class CompletionChoice(BaseModel):
"""A single completion choice."""
text: str
index: int
logprobs: Optional[Dict[str, Any]] = None
finish_reason: Optional[str] = None
class CompletionResponse(BaseModel):
"""OpenAI completion response."""
id: str
object: Literal["text_completion"] = "text_completion"
created: int
model: str
choices: List[CompletionChoice]
usage: ChatCompletionUsage
# ============================================================================
# Embeddings Models
# ============================================================================
class EmbeddingRequest(BaseModel):
"""OpenAI embedding request."""
model: str
input: Union[str, List[str], List[int], List[List[int]]]
encoding_format: Optional[Literal["float", "base64"]] = "float"
dimensions: Optional[int] = None
user: Optional[str] = None
class EmbeddingData(BaseModel):
"""A single embedding result."""
object: Literal["embedding"] = "embedding"
embedding: List[float]
index: int
class EmbeddingUsage(BaseModel):
"""Token usage for embeddings."""
prompt_tokens: int
total_tokens: int
class EmbeddingResponse(BaseModel):
"""OpenAI embedding response."""
object: Literal["list"] = "list"
data: List[EmbeddingData]
model: str
usage: EmbeddingUsage
# ============================================================================
# Models List
# ============================================================================
class ModelInfo(BaseModel):
"""Information about a single model."""
id: str
object: Literal["model"] = "model"
created: int
owned_by: str
class ModelsResponse(BaseModel):
"""List of available models."""
object: Literal["list"] = "list"
data: List[ModelInfo]
# ============================================================================
# Error Models
# ============================================================================
class ErrorDetail(BaseModel):
"""Error detail information."""
message: str
type: str
param: Optional[str] = None
code: Optional[str] = None
class ErrorResponse(BaseModel):
"""OpenAI error response."""
error: ErrorDetail

app/routers/__init__.py

@@ -0,0 +1,5 @@
"""API routers for watsonx-openai-proxy."""
from app.routers import chat, completions, embeddings, models
__all__ = ["chat", "completions", "embeddings", "models"]

app/routers/chat.py

@@ -0,0 +1,156 @@
"""Chat completions endpoint router."""
import json
import uuid
from typing import Union
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from app.models.openai_models import (
ChatCompletionRequest,
ChatCompletionResponse,
ErrorResponse,
ErrorDetail,
)
from app.services.watsonx_service import watsonx_service
from app.utils.transformers import (
transform_messages_to_watsonx,
transform_tools_to_watsonx,
transform_watsonx_to_openai_chat,
transform_watsonx_to_openai_chat_chunk,
format_sse_event,
)
from app.config import settings
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post(
"/v1/chat/completions",
response_model=Union[ChatCompletionResponse, ErrorResponse],
responses={
200: {"model": ChatCompletionResponse},
400: {"model": ErrorResponse},
401: {"model": ErrorResponse},
500: {"model": ErrorResponse},
},
)
async def create_chat_completion(
request: ChatCompletionRequest,
http_request: Request,
):
"""Create a chat completion using OpenAI-compatible API.
This endpoint accepts OpenAI-formatted requests and translates them
to watsonx.ai API calls.
"""
try:
# Map model name if needed
watsonx_model = settings.map_model(request.model)
logger.info(f"Chat completion request: {request.model} -> {watsonx_model}")
# Transform messages
watsonx_messages = transform_messages_to_watsonx(request.messages)
# Transform tools if present
watsonx_tools = transform_tools_to_watsonx(request.tools)
# Handle streaming
if request.stream:
return StreamingResponse(
stream_chat_completion(
watsonx_model,
watsonx_messages,
request,
watsonx_tools,
),
media_type="text/event-stream",
)
# Non-streaming response
watsonx_response = await watsonx_service.chat_completion(
model_id=watsonx_model,
messages=watsonx_messages,
temperature=request.temperature or 1.0,
max_tokens=request.max_tokens,
top_p=request.top_p or 1.0,
stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
tools=watsonx_tools,
)
# Transform response
openai_response = transform_watsonx_to_openai_chat(
watsonx_response,
request.model,
)
return openai_response
except Exception as e:
logger.error(f"Error in chat completion: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
"error": {
"message": str(e),
"type": "internal_error",
"code": "internal_error",
}
},
)
async def stream_chat_completion(
watsonx_model: str,
watsonx_messages: list,
request: ChatCompletionRequest,
    watsonx_tools: Union[list, None] = None,
):
"""Stream chat completion responses.
Args:
watsonx_model: The watsonx model ID
watsonx_messages: Transformed messages
request: Original OpenAI request
watsonx_tools: Transformed tools
Yields:
Server-Sent Events with chat completion chunks
"""
request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
try:
async for chunk in watsonx_service.chat_completion_stream(
model_id=watsonx_model,
messages=watsonx_messages,
temperature=request.temperature or 1.0,
max_tokens=request.max_tokens,
top_p=request.top_p or 1.0,
stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
tools=watsonx_tools,
):
# Transform chunk to OpenAI format
openai_chunk = transform_watsonx_to_openai_chat_chunk(
chunk,
request.model,
request_id,
)
# Send as SSE
yield format_sse_event(openai_chunk.model_dump_json())
# Send [DONE] message
yield format_sse_event("[DONE]")
except Exception as e:
logger.error(f"Error in streaming chat completion: {str(e)}", exc_info=True)
error_response = ErrorResponse(
error=ErrorDetail(
message=str(e),
type="internal_error",
code="stream_error",
)
)
yield format_sse_event(error_response.model_dump_json())
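This router (and the completions router) normalize the `stop` parameter with the same inline ternary. The expression is equivalent to this small helper (a sketch for illustration; the repo inlines the expression rather than defining a function):

```python
from typing import List, Optional, Union

def normalize_stop(stop: Optional[Union[str, List[str]]]) -> Optional[List[str]]:
    """Pass lists through unchanged, wrap a single string in a list,
    and map falsy non-list values (None, "") to None."""
    if isinstance(stop, list):
        return stop
    return [stop] if stop else None
```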

app/routers/completions.py Normal file
@@ -0,0 +1,109 @@
"""Text completions endpoint router (legacy)."""
import uuid
from typing import Union
from fastapi import APIRouter, HTTPException, Request
from app.models.openai_models import (
CompletionRequest,
CompletionResponse,
ErrorResponse,
ErrorDetail,
)
from app.services.watsonx_service import watsonx_service
from app.utils.transformers import transform_watsonx_to_openai_completion
from app.config import settings
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post(
"/v1/completions",
response_model=Union[CompletionResponse, ErrorResponse],
responses={
200: {"model": CompletionResponse},
400: {"model": ErrorResponse},
401: {"model": ErrorResponse},
500: {"model": ErrorResponse},
},
)
async def create_completion(
request: CompletionRequest,
http_request: Request,
):
"""Create a text completion using OpenAI-compatible API (legacy).
This endpoint accepts OpenAI-formatted completion requests and translates
them to watsonx.ai text generation API calls.
"""
try:
# Map model name if needed
watsonx_model = settings.map_model(request.model)
logger.info(f"Completion request: {request.model} -> {watsonx_model}")
# Handle prompt (can be string or list)
if isinstance(request.prompt, list):
if len(request.prompt) == 0:
raise HTTPException(
status_code=400,
detail={
"error": {
"message": "Prompt cannot be empty",
"type": "invalid_request_error",
"code": "invalid_prompt",
}
},
)
# For now, just use the first prompt
# TODO: Handle multiple prompts with n parameter
prompt = request.prompt[0] if isinstance(request.prompt[0], str) else ""
else:
prompt = request.prompt
# Note: Streaming not implemented for completions yet
if request.stream:
raise HTTPException(
status_code=400,
detail={
"error": {
"message": "Streaming not supported for completions endpoint",
"type": "invalid_request_error",
"code": "streaming_not_supported",
}
},
)
# Call watsonx text generation
watsonx_response = await watsonx_service.text_generation(
model_id=watsonx_model,
prompt=prompt,
temperature=request.temperature or 1.0,
max_tokens=request.max_tokens,
top_p=request.top_p or 1.0,
stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
)
# Transform response
openai_response = transform_watsonx_to_openai_completion(
watsonx_response,
request.model,
)
return openai_response
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in completion: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
"error": {
"message": str(e),
"type": "internal_error",
"code": "internal_error",
}
},
)
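The prompt handling above can be read as a small pure function (a sketch mirroring the endpoint's branches, including the current TODO fallback that maps token-ID prompts to an empty string; the router raises HTTPException where this helper raises ValueError):

```python
from typing import List, Union

def select_prompt(prompt: Union[str, List]) -> str:
    """Strings pass through; a non-empty list uses its first element;
    a first element that is not a string falls back to ""."""
    if isinstance(prompt, list):
        if not prompt:
            raise ValueError("Prompt cannot be empty")
        return prompt[0] if isinstance(prompt[0], str) else ""
    return prompt
```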

app/routers/embeddings.py Normal file
@@ -0,0 +1,114 @@
"""Embeddings endpoint router."""
from typing import Union
from fastapi import APIRouter, HTTPException, Request
from app.models.openai_models import (
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
ErrorDetail,
)
from app.services.watsonx_service import watsonx_service
from app.utils.transformers import transform_watsonx_to_openai_embeddings
from app.config import settings
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post(
"/v1/embeddings",
response_model=Union[EmbeddingResponse, ErrorResponse],
responses={
200: {"model": EmbeddingResponse},
400: {"model": ErrorResponse},
401: {"model": ErrorResponse},
500: {"model": ErrorResponse},
},
)
async def create_embeddings(
request: EmbeddingRequest,
http_request: Request,
):
"""Create embeddings using OpenAI-compatible API.
This endpoint accepts OpenAI-formatted embedding requests and translates
them to watsonx.ai embeddings API calls.
"""
try:
# Map model name if needed
watsonx_model = settings.map_model(request.model)
logger.info(f"Embeddings request: {request.model} -> {watsonx_model}")
# Handle input (can be string or list)
if isinstance(request.input, str):
inputs = [request.input]
elif isinstance(request.input, list):
if len(request.input) == 0:
raise HTTPException(
status_code=400,
detail={
"error": {
"message": "Input cannot be empty",
"type": "invalid_request_error",
"code": "invalid_input",
}
},
)
# Handle list of strings or list of token IDs
if isinstance(request.input[0], str):
inputs = request.input
else:
# Token IDs not supported yet
raise HTTPException(
status_code=400,
detail={
"error": {
"message": "Token ID input not supported",
"type": "invalid_request_error",
"code": "unsupported_input_type",
}
},
)
else:
raise HTTPException(
status_code=400,
detail={
"error": {
"message": "Invalid input type",
"type": "invalid_request_error",
"code": "invalid_input_type",
}
},
)
# Call watsonx embeddings
watsonx_response = await watsonx_service.embeddings(
model_id=watsonx_model,
inputs=inputs,
)
# Transform response
openai_response = transform_watsonx_to_openai_embeddings(
watsonx_response,
request.model,
)
return openai_response
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in embeddings: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
"error": {
"message": str(e),
"type": "internal_error",
"code": "internal_error",
}
},
)
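The input-validation branches above amount to one normalization rule, shown here as a sketch (the router raises HTTPException with the matching error codes where this helper raises ValueError):

```python
from typing import List, Union

def normalize_embedding_input(value: Union[str, List]) -> List[str]:
    """A string becomes a one-element list; a non-empty list whose first
    element is a string passes through; empty lists, token-ID lists, and
    other types are rejected."""
    if isinstance(value, str):
        return [value]
    if isinstance(value, list):
        if not value:
            raise ValueError("Input cannot be empty")
        if isinstance(value[0], str):
            return value
        raise ValueError("Token ID input not supported")
    raise ValueError("Invalid input type")
```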

app/routers/models.py Normal file
@@ -0,0 +1,120 @@
"""Models endpoint router."""
import time
from fastapi import APIRouter
from app.models.openai_models import ModelsResponse, ModelInfo
from app.config import settings
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
# Predefined list of available models
# This can be extended or made dynamic based on watsonx.ai API
AVAILABLE_MODELS = [
# Granite Models
"ibm/granite-3-1-8b-base",
"ibm/granite-3-2-8b-instruct",
"ibm/granite-3-3-8b-instruct",
"ibm/granite-3-8b-instruct",
"ibm/granite-4-h-small",
"ibm/granite-8b-code-instruct",
# Llama Models
"meta-llama/llama-3-1-70b-gptq",
"meta-llama/llama-3-1-8b",
"meta-llama/llama-3-2-11b-vision-instruct",
"meta-llama/llama-3-2-90b-vision-instruct",
"meta-llama/llama-3-3-70b-instruct",
"meta-llama/llama-3-405b-instruct",
"meta-llama/llama-4-maverick-17b-128e-instruct-fp8",
# Mistral Models
"mistral-large-2512",
"mistralai/mistral-medium-2505",
"mistralai/mistral-small-3-1-24b-instruct-2503",
# Other Models
"openai/gpt-oss-120b",
# Embedding Models
"ibm/slate-125m-english-rtrvr",
"ibm/slate-30m-english-rtrvr",
]
@router.get(
"/v1/models",
response_model=ModelsResponse,
)
async def list_models():
"""List available models in OpenAI-compatible format.
Returns a list of models that can be used with the API.
Includes both the actual watsonx model IDs and any mapped names.
"""
created_time = int(time.time())
models = []
# Add all available watsonx models
for model_id in AVAILABLE_MODELS:
models.append(
ModelInfo(
id=model_id,
created=created_time,
owned_by="ibm-watsonx",
)
)
# Add mapped model names (e.g., gpt-4 -> ibm/granite-4-h-small)
model_mapping = settings.get_model_mapping()
for openai_name, watsonx_id in model_mapping.items():
if watsonx_id in AVAILABLE_MODELS:
models.append(
ModelInfo(
id=openai_name,
created=created_time,
owned_by="ibm-watsonx",
)
)
return ModelsResponse(data=models)
@router.get(
"/v1/models/{model_id}",
response_model=ModelInfo,
)
async def retrieve_model(model_id: str):
"""Retrieve information about a specific model.
Args:
model_id: The model ID to retrieve
Returns:
Model information
"""
# Map the model if needed
watsonx_model = settings.map_model(model_id)
# Check if model exists
if watsonx_model not in AVAILABLE_MODELS:
from fastapi import HTTPException
raise HTTPException(
status_code=404,
detail={
"error": {
"message": f"Model '{model_id}' not found",
"type": "invalid_request_error",
"code": "model_not_found",
}
},
)
return ModelInfo(
id=model_id,
created=int(time.time()),
owned_by="ibm-watsonx",
)
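Model aliasing itself lives in `app/config.py` (`settings.get_model_mapping()`, not reproduced in this listing). Based on the `MODEL_MAP_*` variables in `docker-compose.yml`, it plausibly reduces to something like this sketch; the exact env-var-to-name table is an assumption:

```python
import os

# Assumed correspondence between MODEL_MAP_* env vars and the OpenAI
# model names they alias (hypothetical; defined for real in app/config.py).
ENV_TO_OPENAI_NAME = {
    "MODEL_MAP_GPT4": "gpt-4",
    "MODEL_MAP_GPT35": "gpt-3.5-turbo",
    "MODEL_MAP_GPT4_TURBO": "gpt-4-turbo",
    "MODEL_MAP_TEXT_EMBEDDING_ADA_002": "text-embedding-ada-002",
}

def build_model_mapping(environ=None) -> dict:
    """Build {openai_name: watsonx_model_id} from the environment,
    skipping unset or empty variables."""
    environ = os.environ if environ is None else environ
    return {
        openai_name: environ[env_var]
        for env_var, openai_name in ENV_TO_OPENAI_NAME.items()
        if environ.get(env_var)
    }
```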

app/services/__init__.py Normal file
@@ -0,0 +1,5 @@
"""Services for watsonx-openai-proxy."""
from app.services.watsonx_service import watsonx_service, WatsonxService
__all__ = ["watsonx_service", "WatsonxService"]

app/services/watsonx_service.py Normal file
@@ -0,0 +1,316 @@
"""Service for interacting with IBM watsonx.ai APIs."""
import asyncio
import time
from typing import AsyncIterator, Dict, List, Optional
import httpx
from app.config import settings
import logging
logger = logging.getLogger(__name__)
class WatsonxService:
"""Service for managing watsonx.ai API interactions."""
def __init__(self):
self.base_url = settings.watsonx_base_url
self.project_id = settings.watsonx_project_id
self.api_key = settings.ibm_cloud_api_key
self._bearer_token: Optional[str] = None
self._token_expiry: Optional[float] = None
self._token_lock = asyncio.Lock()
self._client: Optional[httpx.AsyncClient] = None
async def _get_client(self) -> httpx.AsyncClient:
"""Get or create HTTP client."""
if self._client is None:
self._client = httpx.AsyncClient(timeout=300.0)
return self._client
async def close(self):
"""Close the HTTP client."""
if self._client:
await self._client.aclose()
self._client = None
async def _refresh_token(self) -> str:
"""Get a fresh bearer token from IBM Cloud IAM."""
async with self._token_lock:
# Check if token is still valid
if self._bearer_token and self._token_expiry:
if time.time() < self._token_expiry - 300: # 5 min buffer
return self._bearer_token
logger.info("Refreshing IBM Cloud bearer token...")
client = await self._get_client()
            response = await client.post(
                "https://iam.cloud.ibm.com/identity/token",
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                # Pass the fields as a dict so httpx form-encodes them,
                # escaping any special characters in the API key.
                data={
                    "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
                    "apikey": self.api_key,
                },
            )
if response.status_code != 200:
raise Exception(f"Failed to get bearer token: {response.text}")
data = response.json()
self._bearer_token = data["access_token"]
self._token_expiry = time.time() + data.get("expires_in", 3600)
logger.info(f"Bearer token refreshed. Expires in {data.get('expires_in', 3600)} seconds")
return self._bearer_token
async def _get_headers(self) -> Dict[str, str]:
"""Get headers with valid bearer token."""
token = await self._refresh_token()
return {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
"Accept": "application/json",
}
async def chat_completion(
self,
model_id: str,
messages: List[Dict],
temperature: float = 1.0,
max_tokens: Optional[int] = None,
top_p: float = 1.0,
stop: Optional[List[str]] = None,
stream: bool = False,
tools: Optional[List[Dict]] = None,
**kwargs,
) -> Dict:
"""Create a chat completion using watsonx.ai.
Args:
model_id: The watsonx model ID
messages: List of chat messages
temperature: Sampling temperature
max_tokens: Maximum tokens to generate
top_p: Nucleus sampling parameter
stop: Stop sequences
stream: Whether to stream the response
tools: Tool/function definitions
**kwargs: Additional parameters
Returns:
Chat completion response
"""
headers = await self._get_headers()
client = await self._get_client()
# Build watsonx request
payload = {
"model_id": model_id,
"project_id": self.project_id,
"messages": messages,
"parameters": {
"temperature": temperature,
"top_p": top_p,
},
}
if max_tokens:
payload["parameters"]["max_tokens"] = max_tokens
if stop:
payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]
if tools:
payload["tools"] = tools
url = f"{self.base_url}/text/chat"
params = {"version": "2024-02-13"}
if stream:
params["stream"] = "true"
response = await client.post(
url,
headers=headers,
json=payload,
params=params,
)
if response.status_code != 200:
raise Exception(f"watsonx API error: {response.status_code} - {response.text}")
return response.json()
async def chat_completion_stream(
self,
model_id: str,
messages: List[Dict],
temperature: float = 1.0,
max_tokens: Optional[int] = None,
top_p: float = 1.0,
stop: Optional[List[str]] = None,
tools: Optional[List[Dict]] = None,
**kwargs,
) -> AsyncIterator[Dict]:
"""Stream a chat completion using watsonx.ai.
Args:
model_id: The watsonx model ID
messages: List of chat messages
temperature: Sampling temperature
max_tokens: Maximum tokens to generate
top_p: Nucleus sampling parameter
stop: Stop sequences
tools: Tool/function definitions
**kwargs: Additional parameters
Yields:
Chat completion chunks
"""
headers = await self._get_headers()
client = await self._get_client()
# Build watsonx request
payload = {
"model_id": model_id,
"project_id": self.project_id,
"messages": messages,
"parameters": {
"temperature": temperature,
"top_p": top_p,
},
}
if max_tokens:
payload["parameters"]["max_tokens"] = max_tokens
if stop:
payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]
if tools:
payload["tools"] = tools
url = f"{self.base_url}/text/chat_stream"
params = {"version": "2024-02-13"}
async with client.stream(
"POST",
url,
headers=headers,
json=payload,
params=params,
) as response:
if response.status_code != 200:
text = await response.aread()
raise Exception(f"watsonx API error: {response.status_code} - {text.decode()}")
            import json  # imported once here, not on every chunk
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data = line[6:]  # Remove "data: " prefix
                    if data.strip() == "[DONE]":
                        break
                    try:
                        yield json.loads(data)
                    except json.JSONDecodeError:
                        continue
async def text_generation(
self,
model_id: str,
prompt: str,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
top_p: float = 1.0,
stop: Optional[List[str]] = None,
**kwargs,
) -> Dict:
"""Generate text completion using watsonx.ai.
Args:
model_id: The watsonx model ID
prompt: The input prompt
temperature: Sampling temperature
max_tokens: Maximum tokens to generate
top_p: Nucleus sampling parameter
stop: Stop sequences
**kwargs: Additional parameters
Returns:
Text generation response
"""
headers = await self._get_headers()
client = await self._get_client()
payload = {
"model_id": model_id,
"project_id": self.project_id,
"input": prompt,
"parameters": {
"temperature": temperature,
"top_p": top_p,
},
}
if max_tokens:
payload["parameters"]["max_new_tokens"] = max_tokens
if stop:
payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]
url = f"{self.base_url}/text/generation"
params = {"version": "2024-02-13"}
response = await client.post(
url,
headers=headers,
json=payload,
params=params,
)
if response.status_code != 200:
raise Exception(f"watsonx API error: {response.status_code} - {response.text}")
return response.json()
async def embeddings(
self,
model_id: str,
inputs: List[str],
**kwargs,
) -> Dict:
"""Generate embeddings using watsonx.ai.
Args:
model_id: The watsonx embedding model ID
inputs: List of texts to embed
**kwargs: Additional parameters
Returns:
Embeddings response
"""
headers = await self._get_headers()
client = await self._get_client()
payload = {
"model_id": model_id,
"project_id": self.project_id,
"inputs": inputs,
}
url = f"{self.base_url}/text/embeddings"
params = {"version": "2024-02-13"}
response = await client.post(
url,
headers=headers,
json=payload,
params=params,
)
if response.status_code != 200:
raise Exception(f"watsonx API error: {response.status_code} - {response.text}")
return response.json()
# Global service instance
watsonx_service = WatsonxService()
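The token caching in `_refresh_token` hinges on one comparison: reuse the cached token only while more than a 5-minute buffer remains before expiry. Isolated as a pure function (a sketch for illustration, not code from the repo):

```python
import time
from typing import Optional

TOKEN_REFRESH_BUFFER_SECONDS = 300  # refresh 5 minutes before expiry

def token_needs_refresh(expiry: Optional[float], now: Optional[float] = None) -> bool:
    """Mirror the service's check: no cached expiry means refresh;
    otherwise refresh once we are inside the buffer window."""
    if expiry is None:
        return True
    now = time.time() if now is None else now
    return now >= expiry - TOKEN_REFRESH_BUFFER_SECONDS
```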

app/utils/__init__.py Normal file
@@ -0,0 +1,21 @@
"""Utility functions for watsonx-openai-proxy."""
from app.utils.transformers import (
transform_messages_to_watsonx,
transform_tools_to_watsonx,
transform_watsonx_to_openai_chat,
transform_watsonx_to_openai_chat_chunk,
transform_watsonx_to_openai_completion,
transform_watsonx_to_openai_embeddings,
format_sse_event,
)
__all__ = [
"transform_messages_to_watsonx",
"transform_tools_to_watsonx",
"transform_watsonx_to_openai_chat",
"transform_watsonx_to_openai_chat_chunk",
"transform_watsonx_to_openai_completion",
"transform_watsonx_to_openai_embeddings",
"format_sse_event",
]

app/utils/transformers.py Normal file
@@ -0,0 +1,272 @@
"""Utilities for transforming between OpenAI and watsonx formats."""
import time
import uuid
from typing import Any, Dict, List, Optional
from app.models.openai_models import (
ChatMessage,
ChatCompletionChoice,
ChatCompletionUsage,
ChatCompletionResponse,
ChatCompletionChunk,
ChatCompletionChunkChoice,
ChatCompletionChunkDelta,
CompletionChoice,
CompletionResponse,
EmbeddingData,
EmbeddingUsage,
EmbeddingResponse,
)
def transform_messages_to_watsonx(messages: List[ChatMessage]) -> List[Dict[str, Any]]:
"""Transform OpenAI messages to watsonx format.
Args:
messages: List of OpenAI ChatMessage objects
Returns:
List of watsonx-compatible message dicts
"""
watsonx_messages = []
for msg in messages:
watsonx_msg = {
"role": msg.role,
}
if msg.content:
watsonx_msg["content"] = msg.content
if msg.name:
watsonx_msg["name"] = msg.name
if msg.tool_calls:
watsonx_msg["tool_calls"] = msg.tool_calls
if msg.function_call:
watsonx_msg["function_call"] = msg.function_call
watsonx_messages.append(watsonx_msg)
return watsonx_messages
def transform_tools_to_watsonx(tools: Optional[List[Dict]]) -> Optional[List[Dict]]:
"""Transform OpenAI tools to watsonx format.
Args:
tools: List of OpenAI tool definitions
Returns:
List of watsonx-compatible tool definitions
"""
if not tools:
return None
# watsonx uses similar format to OpenAI for tools
return tools
def transform_watsonx_to_openai_chat(
watsonx_response: Dict[str, Any],
model: str,
request_id: Optional[str] = None,
) -> ChatCompletionResponse:
"""Transform watsonx chat response to OpenAI format.
Args:
watsonx_response: Response from watsonx chat API
model: Model name to include in response
request_id: Optional request ID, generates one if not provided
Returns:
OpenAI-compatible ChatCompletionResponse
"""
response_id = request_id or f"chatcmpl-{uuid.uuid4().hex[:24]}"
created = int(time.time())
# Extract choices
choices = []
watsonx_choices = watsonx_response.get("choices", [])
for idx, choice in enumerate(watsonx_choices):
message_data = choice.get("message", {})
message = ChatMessage(
role=message_data.get("role", "assistant"),
content=message_data.get("content"),
tool_calls=message_data.get("tool_calls"),
function_call=message_data.get("function_call"),
)
choices.append(
ChatCompletionChoice(
index=idx,
message=message,
finish_reason=choice.get("finish_reason"),
)
)
# Extract usage
usage_data = watsonx_response.get("usage", {})
usage = ChatCompletionUsage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("completion_tokens", 0),
total_tokens=usage_data.get("total_tokens", 0),
)
return ChatCompletionResponse(
id=response_id,
created=created,
model=model,
choices=choices,
usage=usage,
)
def transform_watsonx_to_openai_chat_chunk(
watsonx_chunk: Dict[str, Any],
model: str,
request_id: str,
) -> ChatCompletionChunk:
"""Transform watsonx streaming chunk to OpenAI format.
Args:
watsonx_chunk: Streaming chunk from watsonx
model: Model name to include in response
request_id: Request ID for this stream
Returns:
OpenAI-compatible ChatCompletionChunk
"""
created = int(time.time())
# Extract choices
choices = []
watsonx_choices = watsonx_chunk.get("choices", [])
for idx, choice in enumerate(watsonx_choices):
delta_data = choice.get("delta", {})
delta = ChatCompletionChunkDelta(
role=delta_data.get("role"),
content=delta_data.get("content"),
tool_calls=delta_data.get("tool_calls"),
function_call=delta_data.get("function_call"),
)
choices.append(
ChatCompletionChunkChoice(
index=idx,
delta=delta,
finish_reason=choice.get("finish_reason"),
)
)
return ChatCompletionChunk(
id=request_id,
created=created,
model=model,
choices=choices,
)
def transform_watsonx_to_openai_completion(
watsonx_response: Dict[str, Any],
model: str,
request_id: Optional[str] = None,
) -> CompletionResponse:
"""Transform watsonx text generation response to OpenAI completion format.
Args:
watsonx_response: Response from watsonx text generation API
model: Model name to include in response
request_id: Optional request ID, generates one if not provided
Returns:
OpenAI-compatible CompletionResponse
"""
response_id = request_id or f"cmpl-{uuid.uuid4().hex[:24]}"
created = int(time.time())
# Extract results
results = watsonx_response.get("results", [])
choices = []
for idx, result in enumerate(results):
choices.append(
CompletionChoice(
text=result.get("generated_text", ""),
index=idx,
finish_reason=result.get("stop_reason"),
)
)
# Extract usage
usage_data = watsonx_response.get("usage", {})
usage = ChatCompletionUsage(
prompt_tokens=usage_data.get("prompt_tokens", 0),
completion_tokens=usage_data.get("generated_tokens", 0),
total_tokens=usage_data.get("prompt_tokens", 0) + usage_data.get("generated_tokens", 0),
)
return CompletionResponse(
id=response_id,
created=created,
model=model,
choices=choices,
usage=usage,
)
def transform_watsonx_to_openai_embeddings(
watsonx_response: Dict[str, Any],
model: str,
) -> EmbeddingResponse:
"""Transform watsonx embeddings response to OpenAI format.
Args:
watsonx_response: Response from watsonx embeddings API
model: Model name to include in response
Returns:
OpenAI-compatible EmbeddingResponse
"""
# Extract results
results = watsonx_response.get("results", [])
data = []
for idx, result in enumerate(results):
embedding = result.get("embedding", [])
data.append(
EmbeddingData(
embedding=embedding,
index=idx,
)
)
# Calculate usage
input_token_count = watsonx_response.get("input_token_count", 0)
usage = EmbeddingUsage(
prompt_tokens=input_token_count,
total_tokens=input_token_count,
)
return EmbeddingResponse(
data=data,
model=model,
usage=usage,
)
def format_sse_event(data: str) -> str:
"""Format data as Server-Sent Event.
Args:
data: JSON string to send
Returns:
Formatted SSE string
"""
return f"data: {data}\n\n"

docker-compose.yml Normal file
@@ -0,0 +1,27 @@
version: '3.8'
services:
watsonx-openai-proxy:
build: .
ports:
- "8000:8000"
environment:
- IBM_CLOUD_API_KEY=${IBM_CLOUD_API_KEY}
- WATSONX_PROJECT_ID=${WATSONX_PROJECT_ID}
- WATSONX_CLUSTER=${WATSONX_CLUSTER:-us-south}
- HOST=0.0.0.0
- PORT=8000
- LOG_LEVEL=${LOG_LEVEL:-info}
- API_KEY=${API_KEY:-}
- ALLOWED_ORIGINS=${ALLOWED_ORIGINS:-*}
- MODEL_MAP_GPT4=${MODEL_MAP_GPT4:-ibm/granite-4-h-small}
- MODEL_MAP_GPT35=${MODEL_MAP_GPT35:-ibm/granite-3-8b-instruct}
- MODEL_MAP_GPT4_TURBO=${MODEL_MAP_GPT4_TURBO:-meta-llama/llama-3-3-70b-instruct}
- MODEL_MAP_TEXT_EMBEDDING_ADA_002=${MODEL_MAP_TEXT_EMBEDDING_ADA_002:-ibm/slate-125m-english-rtrvr}
restart: unless-stopped
healthcheck:
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8000/health')"]
interval: 30s
timeout: 10s
retries: 3
start_period: 5s
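The compose file reads its secrets from the environment; a minimal `.env` next to `docker-compose.yml` might look like this (placeholder values; variable names taken from the file above):

```
IBM_CLOUD_API_KEY=your-ibm-cloud-api-key
WATSONX_PROJECT_ID=your-watsonx-project-id
WATSONX_CLUSTER=us-south
LOG_LEVEL=info
# Optional: require clients to send this key to the proxy
API_KEY=
```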

example_usage.py Normal file
@@ -0,0 +1,183 @@
"""Example usage of watsonx-openai-proxy with OpenAI Python SDK."""
import os
from openai import OpenAI
# Configure the client to use the proxy
client = OpenAI(
base_url="http://localhost:8000/v1",
api_key=os.getenv("API_KEY", "not-needed-if-proxy-has-no-auth"),
)
def example_chat_completion():
"""Example: Basic chat completion."""
print("\n=== Chat Completion Example ===")
response = client.chat.completions.create(
model="ibm/granite-3-8b-instruct",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the capital of France?"},
],
temperature=0.7,
max_tokens=100,
)
print(f"Response: {response.choices[0].message.content}")
print(f"Tokens used: {response.usage.total_tokens}")
def example_streaming_chat():
"""Example: Streaming chat completion."""
print("\n=== Streaming Chat Example ===")
stream = client.chat.completions.create(
model="ibm/granite-3-8b-instruct",
messages=[
{"role": "user", "content": "Tell me a short story about a robot."},
],
stream=True,
max_tokens=200,
)
print("Response: ", end="", flush=True)
for chunk in stream:
if chunk.choices[0].delta.content:
print(chunk.choices[0].delta.content, end="", flush=True)
print()
def example_with_model_mapping():
"""Example: Using mapped model names."""
print("\n=== Model Mapping Example ===")
# If you configured MODEL_MAP_GPT4=ibm/granite-4-h-small in .env
# you can use "gpt-4" and it will be mapped automatically
response = client.chat.completions.create(
model="gpt-4", # This gets mapped to ibm/granite-4-h-small
messages=[
{"role": "user", "content": "Explain quantum computing in one sentence."},
],
max_tokens=50,
)
print(f"Response: {response.choices[0].message.content}")
def example_embeddings():
"""Example: Generate embeddings."""
print("\n=== Embeddings Example ===")
response = client.embeddings.create(
model="ibm/slate-125m-english-rtrvr",
input=[
"The quick brown fox jumps over the lazy dog.",
"Machine learning is a subset of artificial intelligence.",
],
)
print(f"Generated {len(response.data)} embeddings")
print(f"Embedding dimension: {len(response.data[0].embedding)}")
print(f"First embedding (first 5 values): {response.data[0].embedding[:5]}")
def example_list_models():
"""Example: List available models."""
print("\n=== List Models Example ===")
models = client.models.list()
print(f"Available models: {len(models.data)}")
print("\nFirst 5 models:")
for model in models.data[:5]:
print(f" - {model.id}")
def example_completion_legacy():
"""Example: Legacy completion endpoint."""
print("\n=== Legacy Completion Example ===")
response = client.completions.create(
model="ibm/granite-3-8b-instruct",
prompt="Once upon a time, in a land far away,",
max_tokens=50,
temperature=0.8,
)
print(f"Completion: {response.choices[0].text}")
def example_with_functions():
"""Example: Function calling (if supported by model)."""
print("\n=== Function Calling Example ===")
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather in a location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
},
}
]
try:
response = client.chat.completions.create(
model="ibm/granite-3-8b-instruct",
messages=[
{"role": "user", "content": "What's the weather like in Boston?"},
],
tools=tools,
tool_choice="auto",
)
message = response.choices[0].message
if message.tool_calls:
print(f"Function called: {message.tool_calls[0].function.name}")
print(f"Arguments: {message.tool_calls[0].function.arguments}")
else:
print(f"Response: {message.content}")
except Exception as e:
print(f"Function calling may not be supported by this model: {e}")
if __name__ == "__main__":
print("watsonx-openai-proxy Usage Examples")
print("=" * 50)
try:
# Run examples
example_chat_completion()
example_streaming_chat()
example_embeddings()
example_list_models()
example_completion_legacy()
# Optional examples (may require specific configuration)
# example_with_model_mapping()
# example_with_functions()
print("\n" + "=" * 50)
print("All examples completed successfully!")
except Exception as e:
print(f"\nError: {e}")
print("\nMake sure:")
print("1. The proxy server is running (python -m app.main)")
print("2. Your .env file is configured correctly")
print("3. You have the OpenAI Python SDK installed (pip install openai)")

requirements.txt Normal file
@@ -0,0 +1,7 @@
fastapi==0.115.0
uvicorn[standard]==0.32.0
pydantic==2.9.2
pydantic-settings==2.6.0
httpx==0.27.2
python-dotenv==1.0.1
python-multipart==0.0.12

tests/test_basic.py Normal file
@@ -0,0 +1,87 @@
"""Basic tests for watsonx-openai-proxy."""
import pytest
from fastapi.testclient import TestClient
from app.main import app
client = TestClient(app)
def test_health_check():
"""Test the health check endpoint."""
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "cluster" in data
def test_root_endpoint():
"""Test the root endpoint."""
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert data["service"] == "watsonx-openai-proxy"
assert "endpoints" in data
def test_list_models():
"""Test listing available models."""
response = client.get("/v1/models")
assert response.status_code == 200
data = response.json()
assert data["object"] == "list"
assert len(data["data"]) > 0
assert all(model["object"] == "model" for model in data["data"])
def test_retrieve_model():
"""Test retrieving a specific model."""
response = client.get("/v1/models/ibm/granite-3-8b-instruct")
assert response.status_code == 200
data = response.json()
assert data["id"] == "ibm/granite-3-8b-instruct"
assert data["object"] == "model"
def test_retrieve_nonexistent_model():
"""Test retrieving a model that doesn't exist."""
response = client.get("/v1/models/nonexistent-model")
assert response.status_code == 404
# Note: The following tests require valid IBM Cloud credentials
# and should be run with pytest markers or in integration tests
@pytest.mark.skip(reason="Requires valid IBM Cloud credentials")
def test_chat_completion():
"""Test chat completion endpoint."""
response = client.post(
"/v1/chat/completions",
json={
"model": "ibm/granite-3-8b-instruct",
"messages": [
{"role": "user", "content": "Hello!"}
],
},
)
assert response.status_code == 200
data = response.json()
assert "choices" in data
assert len(data["choices"]) > 0
@pytest.mark.skip(reason="Requires valid IBM Cloud credentials")
def test_embeddings():
"""Test embeddings endpoint."""
response = client.post(
"/v1/embeddings",
json={
"model": "ibm/slate-125m-english-rtrvr",
"input": "Test text",
},
)
assert response.status_code == 200
data = response.json()
assert "data" in data
assert len(data["data"]) > 0