commit 2e2b81743527ee513b4cf9f293a76e5928972bc1 Author: Michael Date: Mon Feb 23 09:59:52 2026 -0500 Add AGENTS.md documentation for AI agent guidance diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..7f819b9 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,223 @@ +# AGENTS.md + +This file provides guidance to agents when working with code in this repository. + +## Project Overview + +**watsonx-openai-proxy** is an OpenAI-compatible API proxy for IBM watsonx.ai. It enables any tool or application that supports the OpenAI API format to seamlessly work with watsonx.ai models. + +### Core Purpose +- Provide drop-in replacement for OpenAI API endpoints +- Translate OpenAI API requests to watsonx.ai API calls +- Handle IBM Cloud authentication and token management automatically +- Support streaming responses via Server-Sent Events (SSE) + +### Technology Stack +- **Framework**: FastAPI (async web framework) +- **Language**: Python 3.9+ +- **HTTP Client**: httpx (async HTTP client) +- **Validation**: Pydantic v2 (data validation and settings) +- **Server**: uvicorn (ASGI server) + +### Architecture + +The codebase follows a clean, modular architecture: + +``` +app/ +├── main.py # FastAPI app initialization, middleware, lifespan management +├── config.py # Settings management, model mapping, environment variables +├── routers/ # API endpoint handlers (chat, completions, embeddings, models) +├── services/ # Business logic (watsonx_service for API interactions) +├── models/ # Pydantic models for OpenAI-compatible schemas +└── utils/ # Helper functions (request/response transformers) +``` + +**Key Design Patterns**: +- **Service Layer**: `watsonx_service.py` encapsulates all watsonx.ai API interactions +- **Transformer Pattern**: `transformers.py` handles bidirectional conversion between OpenAI and watsonx formats +- **Singleton Services**: Global service instances (`watsonx_service`, `settings`) for shared state +- **Async/Await**: All I/O operations are 
asynchronous for better performance +- **Middleware**: Custom authentication middleware for optional API key validation + +## Building and Running + +### Prerequisites +```bash +# Python 3.9 or higher required +python --version + +# IBM Cloud credentials needed: +# - IBM_CLOUD_API_KEY +# - WATSONX_PROJECT_ID +``` + +### Installation +```bash +# Install dependencies +pip install -r requirements.txt + +# Configure environment +cp .env.example .env +# Edit .env with your IBM Cloud credentials +``` + +### Running the Server +```bash +# Development (with auto-reload) +uvicorn app.main:app --reload --host 0.0.0.0 --port 8000 + +# Production (with workers) +uvicorn app.main:app --host 0.0.0.0 --port 8000 --workers 4 + +# Using Python module +python -m app.main +``` + +### Docker Deployment +```bash +# Build image +docker build -t watsonx-openai-proxy . + +# Run container +docker run -p 8000:8000 --env-file .env watsonx-openai-proxy + +# Using docker-compose +docker-compose up +``` + +### Testing +```bash +# Install test dependencies +pip install pytest pytest-asyncio httpx + +# Run tests +pytest tests/ + +# Run with coverage +pytest tests/ --cov=app +``` + +## Development Conventions + +### Code Style +- **Async First**: Use `async`/`await` for all I/O operations (HTTP requests, file operations) +- **Type Hints**: All functions should have type annotations for parameters and return values +- **Docstrings**: Use Google-style docstrings for functions and classes +- **Logging**: Use the `logging` module with appropriate log levels (info, warning, error) + +### Error Handling +- Catch exceptions at router level and return OpenAI-compatible error responses +- Use `HTTPException` with proper status codes and error details +- Log errors with full context using `logger.error(..., exc_info=True)` +- Return structured error responses matching OpenAI's error format + +### Configuration Management +- All configuration via environment variables (`.env` file) +- Use `pydantic-settings` 
for type-safe configuration +- Model mapping via `MODEL_MAP_*` environment variables +- Settings accessed through global `settings` instance + +### Token Management +- Bearer tokens automatically refreshed every 50 minutes (tokens expire after 60 minutes) +- Token refresh on 401 errors from watsonx.ai +- Thread-safe token refresh using `asyncio.Lock` +- Initial token obtained during application startup + +### API Compatibility +- Maintain strict OpenAI API compatibility in request/response formats +- Use Pydantic models from `openai_models.py` for validation +- Transform requests/responses using functions in `transformers.py` +- Support both streaming and non-streaming responses + +### Adding New Endpoints +1. Create router in `app/routers/` (e.g., `new_endpoint.py`) +2. Define Pydantic models in `app/models/openai_models.py` +3. Add transformation logic in `app/utils/transformers.py` +4. Add watsonx.ai API method in `app/services/watsonx_service.py` +5. Register router in `app/main.py` using `app.include_router()` + +### Streaming Responses +- Use `StreamingResponse` with `media_type="text/event-stream"` +- Format chunks as Server-Sent Events using `format_sse_event()` +- Always send the `[DONE]` message at the end of the stream +- Handle errors gracefully and send error events in SSE format + +### Model Mapping +- Map OpenAI model names to watsonx models via environment variables +- Format: `MODEL_MAP_<NAME>=<watsonx-model-id>` +- Example: `MODEL_MAP_GPT4=ibm/granite-4-h-small` +- Mapping applied in `settings.map_model()` before API calls + +### Security Considerations +- Optional API key authentication via `API_KEY` environment variable +- Middleware validates Bearer token in Authorization header +- IBM Cloud API key stored securely in environment variables +- CORS configured via `ALLOWED_ORIGINS` (default: `*`) + +### Logging Best Practices +- Use structured logging with context (model names, request IDs) +- Log level controlled by `LOG_LEVEL` environment variable +- Log token refresh events at INFO 
level +- Log API errors at ERROR level with full traceback +- Include request/response details for debugging + +### Dependencies +- Keep `requirements.txt` minimal and pinned to specific versions +- FastAPI and Pydantic are core dependencies - avoid breaking changes +- httpx for async HTTP - prefer over requests/aiohttp +- Use `uvicorn[standard]` for production-ready server + +## Important Implementation Notes + +### watsonx.ai API Specifics +- Base URL format: `https://{cluster}.ml.cloud.ibm.com/ml/v1` +- API version parameter: `version=2024-02-13` (required on all requests) +- Chat endpoint: `/text/chat` (non-streaming) or `/text/chat_stream` (streaming) +- Text generation: `/text/generation` +- Embeddings: `/text/embeddings` + +### Request/Response Transformation +- OpenAI messages → watsonx messages: Direct mapping with role/content +- watsonx responses → OpenAI format: Extract choices, usage, and metadata +- Streaming chunks: Parse SSE format, transform delta objects +- Generate unique IDs: `chatcmpl-{uuid}` for chat, `cmpl-{uuid}` for completions + +### Common Pitfalls +- Don't forget to refresh tokens before they expire (50-minute interval) +- Always close httpx client on shutdown (`await watsonx_service.close()`) +- Handle both string and list formats for `stop` parameter +- Validate model IDs exist in watsonx.ai before making requests +- Set appropriate timeouts for long-running generation requests (300s default) + +### Performance Optimization +- Reuse httpx client instance (don't create per request) +- Use connection pooling (httpx default behavior) +- Consider worker processes for production (`--workers 4`) +- Monitor token refresh to avoid rate limiting + +## Environment Variables Reference + +### Required +- `IBM_CLOUD_API_KEY`: IBM Cloud API key for authentication +- `WATSONX_PROJECT_ID`: watsonx.ai project ID + +### Optional +- `WATSONX_CLUSTER`: Region (default: `us-south`) +- `HOST`: Server host (default: `0.0.0.0`) +- `PORT`: Server port 
(default: `8000`) +- `LOG_LEVEL`: Logging level (default: `info`) +- `API_KEY`: Optional proxy authentication key +- `ALLOWED_ORIGINS`: CORS origins (default: `*`) +- `MODEL_MAP_*`: Model name mappings + +## API Endpoints + +- `GET /` - API information and available endpoints +- `GET /health` - Health check (bypasses authentication) +- `GET /docs` - Interactive Swagger UI documentation +- `POST /v1/chat/completions` - Chat completions (streaming supported) +- `POST /v1/completions` - Text completions (legacy) +- `POST /v1/embeddings` - Generate embeddings +- `GET /v1/models` - List available models +- `GET /v1/models/{model_id}` - Get specific model info diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..1876c37 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,20 @@ +FROM python:3.11-slim + +WORKDIR /app + +# Install dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application code +COPY app ./app + +# Expose port +EXPOSE 8000 + +# Health check +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD python -c "import httpx; httpx.get('http://localhost:8000/health')" + +# Run the application +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..8707495 --- /dev/null +++ b/README.md @@ -0,0 +1,353 @@ +# watsonx-openai-proxy + +OpenAI-compatible API proxy for IBM watsonx.ai. This proxy allows you to use watsonx.ai models with any tool or application that supports the OpenAI API format. 
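
Any HTTP client can talk to the proxy, not just the OpenAI SDKs. As a minimal sketch, the following stdlib-only snippet builds (without sending) an OpenAI-format chat request aimed at the proxy, assuming the default `localhost:8000` address; the `build_chat_request` helper name is illustrative, not part of this project:

```python
import json
import urllib.request

# Build an OpenAI-format chat request aimed at the proxy (not sent here).
# The proxy maps the model name, then forwards the call to watsonx.ai.
def build_chat_request(model: str, user_message: str,
                       base_url: str = "http://localhost:8000/v1") -> urllib.request.Request:
    payload = {
        "model": model,  # a watsonx model ID, or a mapped alias such as "gpt-4"
        "messages": [{"role": "user", "content": user_message}],
    }
    return urllib.request.Request(
        f"{base_url}/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

req = build_chat_request("ibm/granite-3-8b-instruct", "Hello!")
print(req.full_url)  # http://localhost:8000/v1/chat/completions
```

Sending it with `urllib.request.urlopen(req)` against a running proxy returns the usual OpenAI-style JSON body; the SDK examples in the Usage section do the same thing with less ceremony.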
+ +## Features + +- ✅ **Full OpenAI API Compatibility**: Drop-in replacement for OpenAI API +- ✅ **Chat Completions**: `/v1/chat/completions` with streaming support +- ✅ **Text Completions**: `/v1/completions` (legacy endpoint) +- ✅ **Embeddings**: `/v1/embeddings` for text embeddings +- ✅ **Model Listing**: `/v1/models` endpoint +- ✅ **Streaming Support**: Server-Sent Events (SSE) for real-time responses +- ✅ **Model Mapping**: Map OpenAI model names to watsonx models +- ✅ **Automatic Token Management**: Handles IBM Cloud authentication automatically +- ✅ **CORS Support**: Configurable cross-origin resource sharing +- ✅ **Optional API Key Authentication**: Secure your proxy with an API key + +## Quick Start + +### Prerequisites + +- Python 3.9 or higher +- IBM Cloud account with watsonx.ai access +- IBM Cloud API key +- watsonx.ai Project ID + +### Installation + +1. Clone or download this directory: + +```bash +cd watsonx-openai-proxy +``` + +2. Install dependencies: + +```bash +pip install -r requirements.txt +``` + +3. Configure environment variables: + +```bash +cp .env.example .env +# Edit .env with your credentials +``` + +4. 
Run the server: + +```bash +python -m app.main +``` + +Or with uvicorn: + +```bash +uvicorn app.main:app --host 0.0.0.0 --port 8000 +``` + +The server will start at `http://localhost:8000`. + +## Configuration + +### Environment Variables + +Create a `.env` file with the following variables: + +```bash +# Required: IBM Cloud Configuration +IBM_CLOUD_API_KEY=your_ibm_cloud_api_key_here +WATSONX_PROJECT_ID=your_watsonx_project_id_here +WATSONX_CLUSTER=us-south # Options: us-south, eu-de, eu-gb, jp-tok, au-syd, ca-tor + +# Optional: Server Configuration +HOST=0.0.0.0 +PORT=8000 +LOG_LEVEL=info + +# Optional: API Key for Proxy Authentication +API_KEY=your_optional_api_key_for_proxy_authentication + +# Optional: CORS Configuration +ALLOWED_ORIGINS=* # Comma-separated or * for all + +# Optional: Model Mapping +MODEL_MAP_GPT4=ibm/granite-4-h-small +MODEL_MAP_GPT35=ibm/granite-3-8b-instruct +MODEL_MAP_GPT4_TURBO=meta-llama/llama-3-3-70b-instruct +MODEL_MAP_TEXT_EMBEDDING_ADA_002=ibm/slate-125m-english-rtrvr +``` + +### Model Mapping + +You can map OpenAI model names to watsonx models using environment variables: + +```bash +MODEL_MAP_<NAME>=<watsonx-model-id> +``` + +For example: +- `MODEL_MAP_GPT4=ibm/granite-4-h-small` maps `gpt-4` to `ibm/granite-4-h-small` +- `MODEL_MAP_GPT35_TURBO=ibm/granite-3-8b-instruct` maps `gpt-35-turbo` to `ibm/granite-3-8b-instruct` + +## Usage + +### With OpenAI Python SDK + +```python +from openai import OpenAI + +# Point to your proxy +client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="your-proxy-api-key" # Optional, if you set API_KEY in .env +) + +# Use as normal +response = client.chat.completions.create( + model="ibm/granite-3-8b-instruct", # Or use mapped name like "gpt-4" + messages=[ + {"role": "user", "content": "Hello, how are you?"} + ] +) + +print(response.choices[0].message.content) +``` + +### With Streaming + +```python +stream = client.chat.completions.create( + model="ibm/granite-3-8b-instruct", + messages=[{"role": "user", 
"content": "Tell me a story"}], + stream=True +) + +for chunk in stream: + if chunk.choices[0].delta.content: + print(chunk.choices[0].delta.content, end="") +``` + +### With cURL + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer your-proxy-api-key" \ + -d '{ + "model": "ibm/granite-3-8b-instruct", + "messages": [ + {"role": "user", "content": "Hello!"} + ] + }' +``` + +### Embeddings + +```python +response = client.embeddings.create( + model="ibm/slate-125m-english-rtrvr", + input="Your text to embed" +) + +print(response.data[0].embedding) +``` + +## Available Endpoints + +- `GET /` - API information +- `GET /health` - Health check +- `GET /docs` - Interactive API documentation (Swagger UI) +- `POST /v1/chat/completions` - Chat completions +- `POST /v1/completions` - Text completions (legacy) +- `POST /v1/embeddings` - Generate embeddings +- `GET /v1/models` - List available models +- `GET /v1/models/{model_id}` - Get model information + +## Supported Models + +The proxy supports all watsonx.ai models available in your project, including: + +### Chat Models +- IBM Granite models (3.x, 4.x series) +- Meta Llama models (3.x, 4.x series) +- Mistral models +- Other models available on watsonx.ai + +### Embedding Models +- `ibm/slate-125m-english-rtrvr` +- `ibm/slate-30m-english-rtrvr` + +See `/v1/models` endpoint for the complete list. + +## Authentication + +### Proxy Authentication (Optional) + +If you set `API_KEY` in your `.env` file, clients must provide it: + +```python +client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="your-proxy-api-key" +) +``` + +### IBM Cloud Authentication + +The proxy handles IBM Cloud authentication automatically using your `IBM_CLOUD_API_KEY`. 
Bearer tokens are: +- Automatically obtained on startup +- Refreshed every 50 minutes (tokens expire after 60 minutes) +- Refreshed on 401 errors + +## Deployment + +### Docker (Recommended) + +The repository includes a `Dockerfile`: + +```dockerfile +FROM python:3.11-slim + +WORKDIR /app + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY app ./app + +EXPOSE 8000 + +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD python -c "import httpx; httpx.get('http://localhost:8000/health')" + +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] +``` + +Credentials are passed at runtime via `--env-file` rather than baked into the image. Build and run: + +```bash +docker build -t watsonx-openai-proxy . +docker run -p 8000:8000 --env-file .env watsonx-openai-proxy +``` + +### Production Deployment + +For production, consider: + +1. **Use a production ASGI server**: The included uvicorn is suitable, but configure workers: + ```bash + uvicorn app.main:app --host 0.0.0.0 --port 8000 --workers 4 + ``` + +2. **Set up HTTPS**: Use a reverse proxy like nginx or Caddy + +3. **Configure CORS**: Set `ALLOWED_ORIGINS` to specific domains + +4. **Enable API key authentication**: Set `API_KEY` in environment + +5. **Monitor logs**: Set `LOG_LEVEL=info` or `warning` in production + +6. 
**Use environment secrets**: Don't commit `.env` file, use secret management + +## Troubleshooting + +### 401 Unauthorized + +- Check that `IBM_CLOUD_API_KEY` is valid +- Verify your IBM Cloud account has watsonx.ai access +- Check server logs for token refresh errors + +### Model Not Found + +- Verify the model ID exists in watsonx.ai +- Check that your project has access to the model +- Use `/v1/models` endpoint to see available models + +### Connection Errors + +- Verify `WATSONX_CLUSTER` matches your project's region +- Check firewall/network settings +- Ensure watsonx.ai services are accessible + +### Streaming Issues + +- Some models may not support streaming +- Check client library supports SSE (Server-Sent Events) +- Verify network doesn't buffer streaming responses + +## Development + +### Running Tests + +```bash +# Install dev dependencies +pip install pytest pytest-asyncio httpx + +# Run tests +pytest tests/ +``` + +### Code Structure + +``` +watsonx-openai-proxy/ +├── app/ +│ ├── main.py # FastAPI application +│ ├── config.py # Configuration management +│ ├── routers/ # API endpoint routers +│ │ ├── chat.py # Chat completions +│ │ ├── completions.py # Text completions +│ │ ├── embeddings.py # Embeddings +│ │ └── models.py # Model listing +│ ├── services/ # Business logic +│ │ └── watsonx_service.py # watsonx.ai API client +│ ├── models/ # Pydantic models +│ │ └── openai_models.py # OpenAI-compatible schemas +│ └── utils/ # Utilities +│ └── transformers.py # Request/response transformers +├── tests/ # Test files +├── requirements.txt # Python dependencies +├── .env.example # Environment template +└── README.md # This file +``` + +## Contributing + +Contributions are welcome! Please: + +1. Fork the repository +2. Create a feature branch +3. Make your changes +4. Add tests if applicable +5. Submit a pull request + +## License + +Apache 2.0 License - See LICENSE file for details. 
+ +## Related Projects + +- [watsonx-unofficial-aisdk-provider](../wxai-provider/) - Vercel AI SDK provider for watsonx.ai +- [OpenCode watsonx plugin](../.opencode/plugins/) - Token management plugin for OpenCode + +## Disclaimer + +This is **not an official IBM product**. It's a community-maintained proxy for integrating watsonx.ai with OpenAI-compatible tools. watsonx.ai is a trademark of IBM. + +## Support + +For issues and questions: +- Check the [Troubleshooting](#troubleshooting) section +- Review server logs (`LOG_LEVEL=debug` for detailed logs) +- Open an issue in the repository +- Consult [IBM watsonx.ai documentation](https://www.ibm.com/docs/en/watsonx-as-a-service) diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..989dfbc --- /dev/null +++ b/app/__init__.py @@ -0,0 +1,3 @@ +"""watsonx-openai-proxy application package.""" + +__version__ = "1.0.0" diff --git a/app/config.py b/app/config.py new file mode 100644 index 0000000..527e373 --- /dev/null +++ b/app/config.py @@ -0,0 +1,91 @@ +"""Configuration management for watsonx-openai-proxy.""" + +import os +from typing import Dict, Optional +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + """Application settings loaded from environment variables.""" + + # IBM Cloud Configuration + ibm_cloud_api_key: str + watsonx_project_id: str + watsonx_cluster: str = "us-south" + + # Server Configuration + host: str = "0.0.0.0" + port: int = 8000 + log_level: str = "info" + + # API Configuration + api_key: Optional[str] = None + allowed_origins: str = "*" + + # Token Management + token_refresh_interval: int = 3000 # 50 minutes in seconds + + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + case_sensitive=False, + extra="allow", # Allow extra fields for model mapping + ) + + @property + def watsonx_base_url(self) -> str: + """Construct the watsonx.ai base URL from cluster.""" + return 
f"https://{self.watsonx_cluster}.ml.cloud.ibm.com/ml/v1" + + @property + def cors_origins(self) -> list[str]: + """Parse CORS origins from comma-separated string.""" + if self.allowed_origins == "*": + return ["*"] + return [origin.strip() for origin in self.allowed_origins.split(",")] + + def get_model_mapping(self) -> Dict[str, str]: + """Extract model mappings from environment variables. + + Looks for variables like MODEL_MAP_GPT4=ibm/granite-4-h-small + and creates a mapping dict. + """ + import re + mapping = {} + + # Check os.environ first + for key, value in os.environ.items(): + if key.startswith("MODEL_MAP_"): + model_name = key.replace("MODEL_MAP_", "") + model_name = re.sub(r'([A-Z]+)(\d+)', r'\1-\2', model_name) + openai_model = model_name.lower().replace("_", "-") + mapping[openai_model] = value + + # Also check pydantic's extra fields (from .env file) + # These come in as lowercase: model_map_gpt4 instead of MODEL_MAP_GPT4 + extra = getattr(self, '__pydantic_extra__', {}) or {} + for key, value in extra.items(): + if key.startswith("model_map_"): + # Convert back to uppercase for processing + model_name = key.replace("model_map_", "").upper() + model_name = re.sub(r'([A-Z]+)(\d+)', r'\1-\2', model_name) + openai_model = model_name.lower().replace("_", "-") + mapping[openai_model] = value + + return mapping + + def map_model(self, openai_model: str) -> str: + """Map an OpenAI model name to a watsonx model ID. 
+ + Args: + openai_model: OpenAI model name (e.g., "gpt-4", "gpt-3.5-turbo") + + Returns: + Corresponding watsonx model ID, or the original name if no mapping exists + """ + model_map = self.get_model_mapping() + return model_map.get(openai_model, openai_model) + + +# Global settings instance +settings = Settings() diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..0932036 --- /dev/null +++ b/app/main.py @@ -0,0 +1,157 @@ +"""Main FastAPI application for watsonx-openai-proxy.""" + +import logging +from contextlib import asynccontextmanager +from fastapi import FastAPI, Request, status +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from app.config import settings +from app.routers import chat, completions, embeddings, models +from app.services.watsonx_service import watsonx_service + +# Configure logging +logging.basicConfig( + level=getattr(logging, settings.log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Manage application lifespan events.""" + # Startup + logger.info("Starting watsonx-openai-proxy...") + logger.info(f"Cluster: {settings.watsonx_cluster}") + logger.info(f"Project ID: {settings.watsonx_project_id[:8]}...") + + # Initialize token + try: + await watsonx_service._refresh_token() + logger.info("Initial bearer token obtained successfully") + except Exception as e: + logger.error(f"Failed to obtain initial bearer token: {e}") + raise + + yield + + # Shutdown + logger.info("Shutting down watsonx-openai-proxy...") + await watsonx_service.close() + + +# Create FastAPI app +app = FastAPI( + title="watsonx-openai-proxy", + description="OpenAI-compatible API proxy for IBM watsonx.ai", + version="1.0.0", + lifespan=lifespan, +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_origins, + 
allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# Optional API key authentication middleware +@app.middleware("http") +async def authenticate(request: Request, call_next): + """Authenticate requests if API key is configured.""" + if settings.api_key: + # Skip authentication for health check + if request.url.path == "/health": + return await call_next(request) + + # Check Authorization header + auth_header = request.headers.get("Authorization") + if not auth_header: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "error": { + "message": "Missing Authorization header", + "type": "authentication_error", + "code": "missing_authorization", + } + }, + ) + + # Validate API key + if not auth_header.startswith("Bearer "): + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "error": { + "message": "Invalid Authorization header format", + "type": "authentication_error", + "code": "invalid_authorization_format", + } + }, + ) + + token = auth_header[7:] # Remove "Bearer " prefix + if token != settings.api_key: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "error": { + "message": "Invalid API key", + "type": "authentication_error", + "code": "invalid_api_key", + } + }, + ) + + return await call_next(request) + + +# Include routers +app.include_router(chat.router, tags=["Chat"]) +app.include_router(completions.router, tags=["Completions"]) +app.include_router(embeddings.router, tags=["Embeddings"]) +app.include_router(models.router, tags=["Models"]) + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return { + "status": "healthy", + "service": "watsonx-openai-proxy", + "cluster": settings.watsonx_cluster, + } + + +@app.get("/") +async def root(): + """Root endpoint with API information.""" + return { + "service": "watsonx-openai-proxy", + "description": "OpenAI-compatible API proxy for IBM watsonx.ai", + "version": 
"1.0.0", + "endpoints": { + "chat": "/v1/chat/completions", + "completions": "/v1/completions", + "embeddings": "/v1/embeddings", + "models": "/v1/models", + "health": "/health", + }, + "documentation": "/docs", + } + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run( + "app.main:app", + host=settings.host, + port=settings.port, + log_level=settings.log_level, + reload=False, + ) diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 0000000..a558324 --- /dev/null +++ b/app/models/__init__.py @@ -0,0 +1,31 @@ +"""OpenAI-compatible data models.""" + +from app.models.openai_models import ( + ChatMessage, + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionChunk, + CompletionRequest, + CompletionResponse, + EmbeddingRequest, + EmbeddingResponse, + ModelsResponse, + ModelInfo, + ErrorResponse, + ErrorDetail, +) + +__all__ = [ + "ChatMessage", + "ChatCompletionRequest", + "ChatCompletionResponse", + "ChatCompletionChunk", + "CompletionRequest", + "CompletionResponse", + "EmbeddingRequest", + "EmbeddingResponse", + "ModelsResponse", + "ModelInfo", + "ErrorResponse", + "ErrorDetail", +] diff --git a/app/models/openai_models.py b/app/models/openai_models.py new file mode 100644 index 0000000..71534ba --- /dev/null +++ b/app/models/openai_models.py @@ -0,0 +1,213 @@ +"""OpenAI-compatible request and response models.""" + +from typing import Any, Dict, List, Literal, Optional, Union +from pydantic import BaseModel, Field + + +# ============================================================================ +# Chat Completions Models +# ============================================================================ + +class ChatMessage(BaseModel): + """A chat message in the conversation.""" + role: Literal["system", "user", "assistant", "function", "tool"] + content: Optional[str] = None + name: Optional[str] = None + function_call: Optional[Dict[str, Any]] = None + tool_calls: Optional[List[Dict[str, Any]]] = None + + 
+class FunctionCall(BaseModel): + """Function call specification.""" + name: str + arguments: str + + +class ToolCall(BaseModel): + """Tool call specification.""" + id: str + type: Literal["function"] + function: FunctionCall + + +class ChatCompletionRequest(BaseModel): + """OpenAI chat completion request.""" + model: str + messages: List[ChatMessage] + temperature: Optional[float] = Field(default=1.0, ge=0, le=2) + top_p: Optional[float] = Field(default=1.0, ge=0, le=1) + n: Optional[int] = Field(default=1, ge=1) + stream: Optional[bool] = False + stop: Optional[Union[str, List[str]]] = None + max_tokens: Optional[int] = Field(default=None, ge=1) + presence_penalty: Optional[float] = Field(default=0, ge=-2, le=2) + frequency_penalty: Optional[float] = Field(default=0, ge=-2, le=2) + logit_bias: Optional[Dict[str, float]] = None + user: Optional[str] = None + functions: Optional[List[Dict[str, Any]]] = None + function_call: Optional[Union[str, Dict[str, str]]] = None + tools: Optional[List[Dict[str, Any]]] = None + tool_choice: Optional[Union[str, Dict[str, Any]]] = None + + +class ChatCompletionChoice(BaseModel): + """A single chat completion choice.""" + index: int + message: ChatMessage + finish_reason: Optional[str] = None + logprobs: Optional[Dict[str, Any]] = None + + +class ChatCompletionUsage(BaseModel): + """Token usage information.""" + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class ChatCompletionResponse(BaseModel): + """OpenAI chat completion response.""" + id: str + object: Literal["chat.completion"] = "chat.completion" + created: int + model: str + choices: List[ChatCompletionChoice] + usage: ChatCompletionUsage + system_fingerprint: Optional[str] = None + + +class ChatCompletionChunkDelta(BaseModel): + """Delta content in streaming response.""" + role: Optional[str] = None + content: Optional[str] = None + function_call: Optional[Dict[str, Any]] = None + tool_calls: Optional[List[Dict[str, Any]]] = None + + +class 
ChatCompletionChunkChoice(BaseModel): + """A single streaming chunk choice.""" + index: int + delta: ChatCompletionChunkDelta + finish_reason: Optional[str] = None + logprobs: Optional[Dict[str, Any]] = None + + +class ChatCompletionChunk(BaseModel): + """OpenAI streaming chat completion chunk.""" + id: str + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + created: int + model: str + choices: List[ChatCompletionChunkChoice] + system_fingerprint: Optional[str] = None + + +# ============================================================================ +# Completions Models (Legacy) +# ============================================================================ + +class CompletionRequest(BaseModel): + """OpenAI completion request (legacy).""" + model: str + prompt: Union[str, List[str], List[int], List[List[int]]] + suffix: Optional[str] = None + max_tokens: Optional[int] = Field(default=16, ge=1) + temperature: Optional[float] = Field(default=1.0, ge=0, le=2) + top_p: Optional[float] = Field(default=1.0, ge=0, le=1) + n: Optional[int] = Field(default=1, ge=1) + stream: Optional[bool] = False + logprobs: Optional[int] = Field(default=None, ge=0, le=5) + echo: Optional[bool] = False + stop: Optional[Union[str, List[str]]] = None + presence_penalty: Optional[float] = Field(default=0, ge=-2, le=2) + frequency_penalty: Optional[float] = Field(default=0, ge=-2, le=2) + best_of: Optional[int] = Field(default=1, ge=1) + logit_bias: Optional[Dict[str, float]] = None + user: Optional[str] = None + + +class CompletionChoice(BaseModel): + """A single completion choice.""" + text: str + index: int + logprobs: Optional[Dict[str, Any]] = None + finish_reason: Optional[str] = None + + +class CompletionResponse(BaseModel): + """OpenAI completion response.""" + id: str + object: Literal["text_completion"] = "text_completion" + created: int + model: str + choices: List[CompletionChoice] + usage: ChatCompletionUsage + + +# 
============================================================================ +# Embeddings Models +# ============================================================================ + +class EmbeddingRequest(BaseModel): + """OpenAI embedding request.""" + model: str + input: Union[str, List[str], List[int], List[List[int]]] + encoding_format: Optional[Literal["float", "base64"]] = "float" + dimensions: Optional[int] = None + user: Optional[str] = None + + +class EmbeddingData(BaseModel): + """A single embedding result.""" + object: Literal["embedding"] = "embedding" + embedding: List[float] + index: int + + +class EmbeddingUsage(BaseModel): + """Token usage for embeddings.""" + prompt_tokens: int + total_tokens: int + + +class EmbeddingResponse(BaseModel): + """OpenAI embedding response.""" + object: Literal["list"] = "list" + data: List[EmbeddingData] + model: str + usage: EmbeddingUsage + + +# ============================================================================ +# Models List +# ============================================================================ + +class ModelInfo(BaseModel): + """Information about a single model.""" + id: str + object: Literal["model"] = "model" + created: int + owned_by: str + + +class ModelsResponse(BaseModel): + """List of available models.""" + object: Literal["list"] = "list" + data: List[ModelInfo] + + +# ============================================================================ +# Error Models +# ============================================================================ + +class ErrorDetail(BaseModel): + """Error detail information.""" + message: str + type: str + param: Optional[str] = None + code: Optional[str] = None + + +class ErrorResponse(BaseModel): + """OpenAI error response.""" + error: ErrorDetail diff --git a/app/routers/__init__.py b/app/routers/__init__.py new file mode 100644 index 0000000..bca9aee --- /dev/null +++ b/app/routers/__init__.py @@ -0,0 +1,5 @@ +"""API routers for watsonx-openai-proxy.""" + 
+from app.routers import chat, completions, embeddings, models + +__all__ = ["chat", "completions", "embeddings", "models"] diff --git a/app/routers/chat.py b/app/routers/chat.py new file mode 100644 index 0000000..f03fc1f --- /dev/null +++ b/app/routers/chat.py @@ -0,0 +1,156 @@ +"""Chat completions endpoint router.""" + +import json +import uuid +from typing import Union +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import StreamingResponse +from app.models.openai_models import ( + ChatCompletionRequest, + ChatCompletionResponse, + ErrorResponse, + ErrorDetail, +) +from app.services.watsonx_service import watsonx_service +from app.utils.transformers import ( + transform_messages_to_watsonx, + transform_tools_to_watsonx, + transform_watsonx_to_openai_chat, + transform_watsonx_to_openai_chat_chunk, + format_sse_event, +) +from app.config import settings +import logging + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post( + "/v1/chat/completions", + response_model=Union[ChatCompletionResponse, ErrorResponse], + responses={ + 200: {"model": ChatCompletionResponse}, + 400: {"model": ErrorResponse}, + 401: {"model": ErrorResponse}, + 500: {"model": ErrorResponse}, + }, +) +async def create_chat_completion( + request: ChatCompletionRequest, + http_request: Request, +): + """Create a chat completion using OpenAI-compatible API. + + This endpoint accepts OpenAI-formatted requests and translates them + to watsonx.ai API calls. 
+    """
+    try:
+        # Map model name if needed
+        watsonx_model = settings.map_model(request.model)
+        logger.info(f"Chat completion request: {request.model} -> {watsonx_model}")
+
+        # Transform messages
+        watsonx_messages = transform_messages_to_watsonx(request.messages)
+
+        # Transform tools if present
+        watsonx_tools = transform_tools_to_watsonx(request.tools)
+
+        # Handle streaming
+        if request.stream:
+            return StreamingResponse(
+                stream_chat_completion(
+                    watsonx_model,
+                    watsonx_messages,
+                    request,
+                    watsonx_tools,
+                ),
+                media_type="text/event-stream",
+            )
+
+        # Non-streaming response
+        watsonx_response = await watsonx_service.chat_completion(
+            model_id=watsonx_model,
+            messages=watsonx_messages,
+            temperature=request.temperature if request.temperature is not None else 1.0,
+            max_tokens=request.max_tokens,
+            top_p=request.top_p if request.top_p is not None else 1.0,
+            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
+            tools=watsonx_tools,
+        )
+
+        # Transform response
+        openai_response = transform_watsonx_to_openai_chat(
+            watsonx_response,
+            request.model,
+        )
+
+        return openai_response
+
+    except Exception as e:
+        logger.error(f"Error in chat completion: {str(e)}", exc_info=True)
+        raise HTTPException(
+            status_code=500,
+            detail={
+                "error": {
+                    "message": str(e),
+                    "type": "internal_error",
+                    "code": "internal_error",
+                }
+            },
+        )
+
+
+async def stream_chat_completion(
+    watsonx_model: str,
+    watsonx_messages: list,
+    request: ChatCompletionRequest,
+    watsonx_tools: list = None,
+):
+    """Stream chat completion responses.
+
+    Args:
+        watsonx_model: The watsonx model ID
+        watsonx_messages: Transformed messages
+        request: Original OpenAI request
+        watsonx_tools: Transformed tools
+
+    Yields:
+        Server-Sent Events with chat completion chunks
+    """
+    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
+
+    try:
+        async for chunk in watsonx_service.chat_completion_stream(
+            model_id=watsonx_model,
+            messages=watsonx_messages,
+            temperature=request.temperature if request.temperature is not None else 1.0,
+            max_tokens=request.max_tokens,
+            top_p=request.top_p if request.top_p is not None else 1.0,
+            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
+            tools=watsonx_tools,
+        ):
+            # Transform chunk to OpenAI format
+            openai_chunk = transform_watsonx_to_openai_chat_chunk(
+                chunk,
+                request.model,
+                request_id,
+            )
+
+            # Send as SSE
+            yield format_sse_event(openai_chunk.model_dump_json())
+
+        # Send [DONE] message
+        yield format_sse_event("[DONE]")
+
+    except Exception as e:
+        logger.error(f"Error in streaming chat completion: {str(e)}", exc_info=True)
+        error_response = ErrorResponse(
+            error=ErrorDetail(
+                message=str(e),
+                type="internal_error",
+                code="stream_error",
+            )
+        )
+        yield format_sse_event(error_response.model_dump_json())
diff --git a/app/routers/completions.py b/app/routers/completions.py
new file mode 100644
index 0000000..4fc39a8
--- /dev/null
+++ b/app/routers/completions.py
@@ -0,0 +1,109 @@
+"""Text completions endpoint router (legacy)."""
+
+import uuid
+from typing import Union
+from fastapi import APIRouter, HTTPException, Request
+from app.models.openai_models import (
+    CompletionRequest,
+    CompletionResponse,
+    ErrorResponse,
+    ErrorDetail,
+)
+from app.services.watsonx_service import watsonx_service
+from app.utils.transformers import transform_watsonx_to_openai_completion
+from app.config import settings
+import logging
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+
+@router.post(
+    "/v1/completions",
+    response_model=Union[CompletionResponse, ErrorResponse],
+
responses={
+        200: {"model": CompletionResponse},
+        400: {"model": ErrorResponse},
+        401: {"model": ErrorResponse},
+        500: {"model": ErrorResponse},
+    },
+)
+async def create_completion(
+    request: CompletionRequest,
+    http_request: Request,
+):
+    """Create a text completion using OpenAI-compatible API (legacy).
+
+    This endpoint accepts OpenAI-formatted completion requests and translates
+    them to watsonx.ai text generation API calls.
+    """
+    try:
+        # Map model name if needed
+        watsonx_model = settings.map_model(request.model)
+        logger.info(f"Completion request: {request.model} -> {watsonx_model}")
+
+        # Handle prompt (can be string or list)
+        if isinstance(request.prompt, list):
+            if len(request.prompt) == 0:
+                raise HTTPException(
+                    status_code=400,
+                    detail={
+                        "error": {
+                            "message": "Prompt cannot be empty",
+                            "type": "invalid_request_error",
+                            "code": "invalid_prompt",
+                        }
+                    },
+                )
+            # Token ID prompts are not supported; reject them explicitly
+            # rather than silently sending an empty prompt to watsonx.
+            if not isinstance(request.prompt[0], str):
+                raise HTTPException(
+                    status_code=400,
+                    detail={
+                        "error": {
+                            "message": "Token ID prompts are not supported",
+                            "type": "invalid_request_error",
+                            "code": "unsupported_prompt_type",
+                        }
+                    },
+                )
+            # For now, just use the first prompt
+            # TODO: Handle multiple prompts with n parameter
+            prompt = request.prompt[0]
+        else:
+            prompt = request.prompt
+
+        # Note: Streaming not implemented for completions yet
+        if request.stream:
+            raise HTTPException(
+                status_code=400,
+                detail={
+                    "error": {
+                        "message": "Streaming not supported for completions endpoint",
+                        "type": "invalid_request_error",
+                        "code": "streaming_not_supported",
+                    }
+                },
+            )
+
+        # Call watsonx text generation
+        watsonx_response = await watsonx_service.text_generation(
+            model_id=watsonx_model,
+            prompt=prompt,
+            temperature=request.temperature if request.temperature is not None else 1.0,
+            max_tokens=request.max_tokens,
+            top_p=request.top_p if request.top_p is not None else 1.0,
+            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
+        )
+
+        # Transform response
+        openai_response = transform_watsonx_to_openai_completion(
+            watsonx_response,
+            request.model,
+        )
+
+        return openai_response
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Error in completion: 
{str(e)}", exc_info=True) + raise HTTPException( + status_code=500, + detail={ + "error": { + "message": str(e), + "type": "internal_error", + "code": "internal_error", + } + }, + ) diff --git a/app/routers/embeddings.py b/app/routers/embeddings.py new file mode 100644 index 0000000..3491ac0 --- /dev/null +++ b/app/routers/embeddings.py @@ -0,0 +1,114 @@ +"""Embeddings endpoint router.""" + +from typing import Union +from fastapi import APIRouter, HTTPException, Request +from app.models.openai_models import ( + EmbeddingRequest, + EmbeddingResponse, + ErrorResponse, + ErrorDetail, +) +from app.services.watsonx_service import watsonx_service +from app.utils.transformers import transform_watsonx_to_openai_embeddings +from app.config import settings +import logging + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.post( + "/v1/embeddings", + response_model=Union[EmbeddingResponse, ErrorResponse], + responses={ + 200: {"model": EmbeddingResponse}, + 400: {"model": ErrorResponse}, + 401: {"model": ErrorResponse}, + 500: {"model": ErrorResponse}, + }, +) +async def create_embeddings( + request: EmbeddingRequest, + http_request: Request, +): + """Create embeddings using OpenAI-compatible API. + + This endpoint accepts OpenAI-formatted embedding requests and translates + them to watsonx.ai embeddings API calls. 
+ """ + try: + # Map model name if needed + watsonx_model = settings.map_model(request.model) + logger.info(f"Embeddings request: {request.model} -> {watsonx_model}") + + # Handle input (can be string or list) + if isinstance(request.input, str): + inputs = [request.input] + elif isinstance(request.input, list): + if len(request.input) == 0: + raise HTTPException( + status_code=400, + detail={ + "error": { + "message": "Input cannot be empty", + "type": "invalid_request_error", + "code": "invalid_input", + } + }, + ) + # Handle list of strings or list of token IDs + if isinstance(request.input[0], str): + inputs = request.input + else: + # Token IDs not supported yet + raise HTTPException( + status_code=400, + detail={ + "error": { + "message": "Token ID input not supported", + "type": "invalid_request_error", + "code": "unsupported_input_type", + } + }, + ) + else: + raise HTTPException( + status_code=400, + detail={ + "error": { + "message": "Invalid input type", + "type": "invalid_request_error", + "code": "invalid_input_type", + } + }, + ) + + # Call watsonx embeddings + watsonx_response = await watsonx_service.embeddings( + model_id=watsonx_model, + inputs=inputs, + ) + + # Transform response + openai_response = transform_watsonx_to_openai_embeddings( + watsonx_response, + request.model, + ) + + return openai_response + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error in embeddings: {str(e)}", exc_info=True) + raise HTTPException( + status_code=500, + detail={ + "error": { + "message": str(e), + "type": "internal_error", + "code": "internal_error", + } + }, + ) diff --git a/app/routers/models.py b/app/routers/models.py new file mode 100644 index 0000000..65d0557 --- /dev/null +++ b/app/routers/models.py @@ -0,0 +1,120 @@ +"""Models endpoint router.""" + +import time +from fastapi import APIRouter +from app.models.openai_models import ModelsResponse, ModelInfo +from app.config import settings +import logging + +logger = 
logging.getLogger(__name__) + +router = APIRouter() + + +# Predefined list of available models +# This can be extended or made dynamic based on watsonx.ai API +AVAILABLE_MODELS = [ + # Granite Models + "ibm/granite-3-1-8b-base", + "ibm/granite-3-2-8b-instruct", + "ibm/granite-3-3-8b-instruct", + "ibm/granite-3-8b-instruct", + "ibm/granite-4-h-small", + "ibm/granite-8b-code-instruct", + + # Llama Models + "meta-llama/llama-3-1-70b-gptq", + "meta-llama/llama-3-1-8b", + "meta-llama/llama-3-2-11b-vision-instruct", + "meta-llama/llama-3-2-90b-vision-instruct", + "meta-llama/llama-3-3-70b-instruct", + "meta-llama/llama-3-405b-instruct", + "meta-llama/llama-4-maverick-17b-128e-instruct-fp8", + + # Mistral Models + "mistral-large-2512", + "mistralai/mistral-medium-2505", + "mistralai/mistral-small-3-1-24b-instruct-2503", + + # Other Models + "openai/gpt-oss-120b", + + # Embedding Models + "ibm/slate-125m-english-rtrvr", + "ibm/slate-30m-english-rtrvr", +] + + +@router.get( + "/v1/models", + response_model=ModelsResponse, +) +async def list_models(): + """List available models in OpenAI-compatible format. + + Returns a list of models that can be used with the API. + Includes both the actual watsonx model IDs and any mapped names. + """ + created_time = int(time.time()) + models = [] + + # Add all available watsonx models + for model_id in AVAILABLE_MODELS: + models.append( + ModelInfo( + id=model_id, + created=created_time, + owned_by="ibm-watsonx", + ) + ) + + # Add mapped model names (e.g., gpt-4 -> ibm/granite-4-h-small) + model_mapping = settings.get_model_mapping() + for openai_name, watsonx_id in model_mapping.items(): + if watsonx_id in AVAILABLE_MODELS: + models.append( + ModelInfo( + id=openai_name, + created=created_time, + owned_by="ibm-watsonx", + ) + ) + + return ModelsResponse(data=models) + + +@router.get( + "/v1/models/{model_id}", + response_model=ModelInfo, +) +async def retrieve_model(model_id: str): + """Retrieve information about a specific model. 
+ + Args: + model_id: The model ID to retrieve + + Returns: + Model information + """ + # Map the model if needed + watsonx_model = settings.map_model(model_id) + + # Check if model exists + if watsonx_model not in AVAILABLE_MODELS: + from fastapi import HTTPException + raise HTTPException( + status_code=404, + detail={ + "error": { + "message": f"Model '{model_id}' not found", + "type": "invalid_request_error", + "code": "model_not_found", + } + }, + ) + + return ModelInfo( + id=model_id, + created=int(time.time()), + owned_by="ibm-watsonx", + ) diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 0000000..9816be7 --- /dev/null +++ b/app/services/__init__.py @@ -0,0 +1,5 @@ +"""Services for watsonx-openai-proxy.""" + +from app.services.watsonx_service import watsonx_service, WatsonxService + +__all__ = ["watsonx_service", "WatsonxService"] diff --git a/app/services/watsonx_service.py b/app/services/watsonx_service.py new file mode 100644 index 0000000..cdf9079 --- /dev/null +++ b/app/services/watsonx_service.py @@ -0,0 +1,316 @@ +"""Service for interacting with IBM watsonx.ai APIs.""" + +import asyncio +import time +from typing import AsyncIterator, Dict, List, Optional +import httpx +from app.config import settings +import logging + +logger = logging.getLogger(__name__) + + +class WatsonxService: + """Service for managing watsonx.ai API interactions.""" + + def __init__(self): + self.base_url = settings.watsonx_base_url + self.project_id = settings.watsonx_project_id + self.api_key = settings.ibm_cloud_api_key + self._bearer_token: Optional[str] = None + self._token_expiry: Optional[float] = None + self._token_lock = asyncio.Lock() + self._client: Optional[httpx.AsyncClient] = None + + async def _get_client(self) -> httpx.AsyncClient: + """Get or create HTTP client.""" + if self._client is None: + self._client = httpx.AsyncClient(timeout=300.0) + return self._client + + async def close(self): + """Close the HTTP client.""" + 
if self._client:
+            await self._client.aclose()
+            self._client = None
+
+    async def _refresh_token(self) -> str:
+        """Get a fresh bearer token from IBM Cloud IAM."""
+        async with self._token_lock:
+            # Check if token is still valid
+            if self._bearer_token and self._token_expiry:
+                if time.time() < self._token_expiry - 300:  # 5 min buffer
+                    return self._bearer_token
+
+            logger.info("Refreshing IBM Cloud bearer token...")
+
+            client = await self._get_client()
+            response = await client.post(
+                "https://iam.cloud.ibm.com/identity/token",
+                headers={"Content-Type": "application/x-www-form-urlencoded"},
+                data={"grant_type": "urn:ibm:params:oauth:grant-type:apikey", "apikey": self.api_key},
+            )
+
+            if response.status_code != 200:
+                raise Exception(f"Failed to get bearer token: {response.text}")
+
+            data = response.json()
+            self._bearer_token = data["access_token"]
+            self._token_expiry = time.time() + data.get("expires_in", 3600)
+
+            logger.info(f"Bearer token refreshed. Expires in {data.get('expires_in', 3600)} seconds")
+            return self._bearer_token
+
+    async def _get_headers(self) -> Dict[str, str]:
+        """Get headers with valid bearer token."""
+        token = await self._refresh_token()
+        return {
+            "Authorization": f"Bearer {token}",
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
+
+    async def chat_completion(
+        self,
+        model_id: str,
+        messages: List[Dict],
+        temperature: float = 1.0,
+        max_tokens: Optional[int] = None,
+        top_p: float = 1.0,
+        stop: Optional[List[str]] = None,
+        stream: bool = False,
+        tools: Optional[List[Dict]] = None,
+        **kwargs,
+    ) -> Dict:
+        """Create a chat completion using watsonx.ai.
+ + Args: + model_id: The watsonx model ID + messages: List of chat messages + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + top_p: Nucleus sampling parameter + stop: Stop sequences + stream: Whether to stream the response + tools: Tool/function definitions + **kwargs: Additional parameters + + Returns: + Chat completion response + """ + headers = await self._get_headers() + client = await self._get_client() + + # Build watsonx request + payload = { + "model_id": model_id, + "project_id": self.project_id, + "messages": messages, + "parameters": { + "temperature": temperature, + "top_p": top_p, + }, + } + + if max_tokens: + payload["parameters"]["max_tokens"] = max_tokens + + if stop: + payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop] + + if tools: + payload["tools"] = tools + + url = f"{self.base_url}/text/chat" + params = {"version": "2024-02-13"} + + if stream: + params["stream"] = "true" + + response = await client.post( + url, + headers=headers, + json=payload, + params=params, + ) + + if response.status_code != 200: + raise Exception(f"watsonx API error: {response.status_code} - {response.text}") + + return response.json() + + async def chat_completion_stream( + self, + model_id: str, + messages: List[Dict], + temperature: float = 1.0, + max_tokens: Optional[int] = None, + top_p: float = 1.0, + stop: Optional[List[str]] = None, + tools: Optional[List[Dict]] = None, + **kwargs, + ) -> AsyncIterator[Dict]: + """Stream a chat completion using watsonx.ai. 
+ + Args: + model_id: The watsonx model ID + messages: List of chat messages + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + top_p: Nucleus sampling parameter + stop: Stop sequences + tools: Tool/function definitions + **kwargs: Additional parameters + + Yields: + Chat completion chunks + """ + headers = await self._get_headers() + client = await self._get_client() + + # Build watsonx request + payload = { + "model_id": model_id, + "project_id": self.project_id, + "messages": messages, + "parameters": { + "temperature": temperature, + "top_p": top_p, + }, + } + + if max_tokens: + payload["parameters"]["max_tokens"] = max_tokens + + if stop: + payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop] + + if tools: + payload["tools"] = tools + + url = f"{self.base_url}/text/chat_stream" + params = {"version": "2024-02-13"} + + async with client.stream( + "POST", + url, + headers=headers, + json=payload, + params=params, + ) as response: + if response.status_code != 200: + text = await response.aread() + raise Exception(f"watsonx API error: {response.status_code} - {text.decode()}") + + async for line in response.aiter_lines(): + if line.startswith("data: "): + data = line[6:] # Remove "data: " prefix + if data.strip() == "[DONE]": + break + try: + import json + yield json.loads(data) + except json.JSONDecodeError: + continue + + async def text_generation( + self, + model_id: str, + prompt: str, + temperature: float = 1.0, + max_tokens: Optional[int] = None, + top_p: float = 1.0, + stop: Optional[List[str]] = None, + **kwargs, + ) -> Dict: + """Generate text completion using watsonx.ai. 
+ + Args: + model_id: The watsonx model ID + prompt: The input prompt + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + top_p: Nucleus sampling parameter + stop: Stop sequences + **kwargs: Additional parameters + + Returns: + Text generation response + """ + headers = await self._get_headers() + client = await self._get_client() + + payload = { + "model_id": model_id, + "project_id": self.project_id, + "input": prompt, + "parameters": { + "temperature": temperature, + "top_p": top_p, + }, + } + + if max_tokens: + payload["parameters"]["max_new_tokens"] = max_tokens + + if stop: + payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop] + + url = f"{self.base_url}/text/generation" + params = {"version": "2024-02-13"} + + response = await client.post( + url, + headers=headers, + json=payload, + params=params, + ) + + if response.status_code != 200: + raise Exception(f"watsonx API error: {response.status_code} - {response.text}") + + return response.json() + + async def embeddings( + self, + model_id: str, + inputs: List[str], + **kwargs, + ) -> Dict: + """Generate embeddings using watsonx.ai. 
+ + Args: + model_id: The watsonx embedding model ID + inputs: List of texts to embed + **kwargs: Additional parameters + + Returns: + Embeddings response + """ + headers = await self._get_headers() + client = await self._get_client() + + payload = { + "model_id": model_id, + "project_id": self.project_id, + "inputs": inputs, + } + + url = f"{self.base_url}/text/embeddings" + params = {"version": "2024-02-13"} + + response = await client.post( + url, + headers=headers, + json=payload, + params=params, + ) + + if response.status_code != 200: + raise Exception(f"watsonx API error: {response.status_code} - {response.text}") + + return response.json() + + +# Global service instance +watsonx_service = WatsonxService() diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000..a84683e --- /dev/null +++ b/app/utils/__init__.py @@ -0,0 +1,21 @@ +"""Utility functions for watsonx-openai-proxy.""" + +from app.utils.transformers import ( + transform_messages_to_watsonx, + transform_tools_to_watsonx, + transform_watsonx_to_openai_chat, + transform_watsonx_to_openai_chat_chunk, + transform_watsonx_to_openai_completion, + transform_watsonx_to_openai_embeddings, + format_sse_event, +) + +__all__ = [ + "transform_messages_to_watsonx", + "transform_tools_to_watsonx", + "transform_watsonx_to_openai_chat", + "transform_watsonx_to_openai_chat_chunk", + "transform_watsonx_to_openai_completion", + "transform_watsonx_to_openai_embeddings", + "format_sse_event", +] diff --git a/app/utils/transformers.py b/app/utils/transformers.py new file mode 100644 index 0000000..b0dd182 --- /dev/null +++ b/app/utils/transformers.py @@ -0,0 +1,272 @@ +"""Utilities for transforming between OpenAI and watsonx formats.""" + +import time +import uuid +from typing import Any, Dict, List, Optional +from app.models.openai_models import ( + ChatMessage, + ChatCompletionChoice, + ChatCompletionUsage, + ChatCompletionResponse, + ChatCompletionChunk, + ChatCompletionChunkChoice, 
+ ChatCompletionChunkDelta, + CompletionChoice, + CompletionResponse, + EmbeddingData, + EmbeddingUsage, + EmbeddingResponse, +) + + +def transform_messages_to_watsonx(messages: List[ChatMessage]) -> List[Dict[str, Any]]: + """Transform OpenAI messages to watsonx format. + + Args: + messages: List of OpenAI ChatMessage objects + + Returns: + List of watsonx-compatible message dicts + """ + watsonx_messages = [] + + for msg in messages: + watsonx_msg = { + "role": msg.role, + } + + if msg.content: + watsonx_msg["content"] = msg.content + + if msg.name: + watsonx_msg["name"] = msg.name + + if msg.tool_calls: + watsonx_msg["tool_calls"] = msg.tool_calls + + if msg.function_call: + watsonx_msg["function_call"] = msg.function_call + + watsonx_messages.append(watsonx_msg) + + return watsonx_messages + + +def transform_tools_to_watsonx(tools: Optional[List[Dict]]) -> Optional[List[Dict]]: + """Transform OpenAI tools to watsonx format. + + Args: + tools: List of OpenAI tool definitions + + Returns: + List of watsonx-compatible tool definitions + """ + if not tools: + return None + + # watsonx uses similar format to OpenAI for tools + return tools + + +def transform_watsonx_to_openai_chat( + watsonx_response: Dict[str, Any], + model: str, + request_id: Optional[str] = None, +) -> ChatCompletionResponse: + """Transform watsonx chat response to OpenAI format. 
+ + Args: + watsonx_response: Response from watsonx chat API + model: Model name to include in response + request_id: Optional request ID, generates one if not provided + + Returns: + OpenAI-compatible ChatCompletionResponse + """ + response_id = request_id or f"chatcmpl-{uuid.uuid4().hex[:24]}" + created = int(time.time()) + + # Extract choices + choices = [] + watsonx_choices = watsonx_response.get("choices", []) + + for idx, choice in enumerate(watsonx_choices): + message_data = choice.get("message", {}) + + message = ChatMessage( + role=message_data.get("role", "assistant"), + content=message_data.get("content"), + tool_calls=message_data.get("tool_calls"), + function_call=message_data.get("function_call"), + ) + + choices.append( + ChatCompletionChoice( + index=idx, + message=message, + finish_reason=choice.get("finish_reason"), + ) + ) + + # Extract usage + usage_data = watsonx_response.get("usage", {}) + usage = ChatCompletionUsage( + prompt_tokens=usage_data.get("prompt_tokens", 0), + completion_tokens=usage_data.get("completion_tokens", 0), + total_tokens=usage_data.get("total_tokens", 0), + ) + + return ChatCompletionResponse( + id=response_id, + created=created, + model=model, + choices=choices, + usage=usage, + ) + + +def transform_watsonx_to_openai_chat_chunk( + watsonx_chunk: Dict[str, Any], + model: str, + request_id: str, +) -> ChatCompletionChunk: + """Transform watsonx streaming chunk to OpenAI format. 
+ + Args: + watsonx_chunk: Streaming chunk from watsonx + model: Model name to include in response + request_id: Request ID for this stream + + Returns: + OpenAI-compatible ChatCompletionChunk + """ + created = int(time.time()) + + # Extract choices + choices = [] + watsonx_choices = watsonx_chunk.get("choices", []) + + for idx, choice in enumerate(watsonx_choices): + delta_data = choice.get("delta", {}) + + delta = ChatCompletionChunkDelta( + role=delta_data.get("role"), + content=delta_data.get("content"), + tool_calls=delta_data.get("tool_calls"), + function_call=delta_data.get("function_call"), + ) + + choices.append( + ChatCompletionChunkChoice( + index=idx, + delta=delta, + finish_reason=choice.get("finish_reason"), + ) + ) + + return ChatCompletionChunk( + id=request_id, + created=created, + model=model, + choices=choices, + ) + + +def transform_watsonx_to_openai_completion( + watsonx_response: Dict[str, Any], + model: str, + request_id: Optional[str] = None, +) -> CompletionResponse: + """Transform watsonx text generation response to OpenAI completion format. 
+ + Args: + watsonx_response: Response from watsonx text generation API + model: Model name to include in response + request_id: Optional request ID, generates one if not provided + + Returns: + OpenAI-compatible CompletionResponse + """ + response_id = request_id or f"cmpl-{uuid.uuid4().hex[:24]}" + created = int(time.time()) + + # Extract results + results = watsonx_response.get("results", []) + choices = [] + + for idx, result in enumerate(results): + choices.append( + CompletionChoice( + text=result.get("generated_text", ""), + index=idx, + finish_reason=result.get("stop_reason"), + ) + ) + + # Extract usage + usage_data = watsonx_response.get("usage", {}) + usage = ChatCompletionUsage( + prompt_tokens=usage_data.get("prompt_tokens", 0), + completion_tokens=usage_data.get("generated_tokens", 0), + total_tokens=usage_data.get("prompt_tokens", 0) + usage_data.get("generated_tokens", 0), + ) + + return CompletionResponse( + id=response_id, + created=created, + model=model, + choices=choices, + usage=usage, + ) + + +def transform_watsonx_to_openai_embeddings( + watsonx_response: Dict[str, Any], + model: str, +) -> EmbeddingResponse: + """Transform watsonx embeddings response to OpenAI format. + + Args: + watsonx_response: Response from watsonx embeddings API + model: Model name to include in response + + Returns: + OpenAI-compatible EmbeddingResponse + """ + # Extract results + results = watsonx_response.get("results", []) + data = [] + + for idx, result in enumerate(results): + embedding = result.get("embedding", []) + data.append( + EmbeddingData( + embedding=embedding, + index=idx, + ) + ) + + # Calculate usage + input_token_count = watsonx_response.get("input_token_count", 0) + usage = EmbeddingUsage( + prompt_tokens=input_token_count, + total_tokens=input_token_count, + ) + + return EmbeddingResponse( + data=data, + model=model, + usage=usage, + ) + + +def format_sse_event(data: str) -> str: + """Format data as Server-Sent Event. 
+ + Args: + data: JSON string to send + + Returns: + Formatted SSE string + """ + return f"data: {data}\n\n" diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..49ef418 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,27 @@ +version: '3.8' + +services: + watsonx-openai-proxy: + build: . + ports: + - "8000:8000" + environment: + - IBM_CLOUD_API_KEY=${IBM_CLOUD_API_KEY} + - WATSONX_PROJECT_ID=${WATSONX_PROJECT_ID} + - WATSONX_CLUSTER=${WATSONX_CLUSTER:-us-south} + - HOST=0.0.0.0 + - PORT=8000 + - LOG_LEVEL=${LOG_LEVEL:-info} + - API_KEY=${API_KEY:-} + - ALLOWED_ORIGINS=${ALLOWED_ORIGINS:-*} + - MODEL_MAP_GPT4=${MODEL_MAP_GPT4:-ibm/granite-4-h-small} + - MODEL_MAP_GPT35=${MODEL_MAP_GPT35:-ibm/granite-3-8b-instruct} + - MODEL_MAP_GPT4_TURBO=${MODEL_MAP_GPT4_TURBO:-meta-llama/llama-3-3-70b-instruct} + - MODEL_MAP_TEXT_EMBEDDING_ADA_002=${MODEL_MAP_TEXT_EMBEDDING_ADA_002:-ibm/slate-125m-english-rtrvr} + restart: unless-stopped + healthcheck: + test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8000/health')"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 5s diff --git a/example_usage.py b/example_usage.py new file mode 100644 index 0000000..a594756 --- /dev/null +++ b/example_usage.py @@ -0,0 +1,183 @@ +"""Example usage of watsonx-openai-proxy with OpenAI Python SDK.""" + +import os +from openai import OpenAI + +# Configure the client to use the proxy +client = OpenAI( + base_url="http://localhost:8000/v1", + api_key=os.getenv("API_KEY", "not-needed-if-proxy-has-no-auth"), +) + + +def example_chat_completion(): + """Example: Basic chat completion.""" + print("\n=== Chat Completion Example ===") + + response = client.chat.completions.create( + model="ibm/granite-3-8b-instruct", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + ], + temperature=0.7, + max_tokens=100, + ) + + print(f"Response: 
{response.choices[0].message.content}") + print(f"Tokens used: {response.usage.total_tokens}") + + +def example_streaming_chat(): + """Example: Streaming chat completion.""" + print("\n=== Streaming Chat Example ===") + + stream = client.chat.completions.create( + model="ibm/granite-3-8b-instruct", + messages=[ + {"role": "user", "content": "Tell me a short story about a robot."}, + ], + stream=True, + max_tokens=200, + ) + + print("Response: ", end="", flush=True) + for chunk in stream: + if chunk.choices[0].delta.content: + print(chunk.choices[0].delta.content, end="", flush=True) + print() + + +def example_with_model_mapping(): + """Example: Using mapped model names.""" + print("\n=== Model Mapping Example ===") + + # If you configured MODEL_MAP_GPT4=ibm/granite-4-h-small in .env + # you can use "gpt-4" and it will be mapped automatically + response = client.chat.completions.create( + model="gpt-4", # This gets mapped to ibm/granite-4-h-small + messages=[ + {"role": "user", "content": "Explain quantum computing in one sentence."}, + ], + max_tokens=50, + ) + + print(f"Response: {response.choices[0].message.content}") + + +def example_embeddings(): + """Example: Generate embeddings.""" + print("\n=== Embeddings Example ===") + + response = client.embeddings.create( + model="ibm/slate-125m-english-rtrvr", + input=[ + "The quick brown fox jumps over the lazy dog.", + "Machine learning is a subset of artificial intelligence.", + ], + ) + + print(f"Generated {len(response.data)} embeddings") + print(f"Embedding dimension: {len(response.data[0].embedding)}") + print(f"First embedding (first 5 values): {response.data[0].embedding[:5]}") + + +def example_list_models(): + """Example: List available models.""" + print("\n=== List Models Example ===") + + models = client.models.list() + + print(f"Available models: {len(models.data)}") + print("\nFirst 5 models:") + for model in models.data[:5]: + print(f" - {model.id}") + + +def example_completion_legacy(): + """Example: 
Legacy completion endpoint.""" + print("\n=== Legacy Completion Example ===") + + response = client.completions.create( + model="ibm/granite-3-8b-instruct", + prompt="Once upon a time, in a land far away,", + max_tokens=50, + temperature=0.8, + ) + + print(f"Completion: {response.choices[0].text}") + + +def example_with_functions(): + """Example: Function calling (if supported by model).""" + print("\n=== Function Calling Example ===") + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + }, + } + ] + + try: + response = client.chat.completions.create( + model="ibm/granite-3-8b-instruct", + messages=[ + {"role": "user", "content": "What's the weather like in Boston?"}, + ], + tools=tools, + tool_choice="auto", + ) + + message = response.choices[0].message + if message.tool_calls: + print(f"Function called: {message.tool_calls[0].function.name}") + print(f"Arguments: {message.tool_calls[0].function.arguments}") + else: + print(f"Response: {message.content}") + except Exception as e: + print(f"Function calling may not be supported by this model: {e}") + + +if __name__ == "__main__": + print("watsonx-openai-proxy Usage Examples") + print("=" * 50) + + try: + # Run examples + example_chat_completion() + example_streaming_chat() + example_embeddings() + example_list_models() + example_completion_legacy() + + # Optional examples (may require specific configuration) + # example_with_model_mapping() + # example_with_functions() + + print("\n" + "=" * 50) + print("All examples completed successfully!") + + except Exception as e: + print(f"\nError: {e}") + print("\nMake sure:") + print("1. 
The proxy server is running (python -m app.main)") + print("2. Your .env file is configured correctly") + print("3. You have the OpenAI Python SDK installed (pip install openai)") diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1af72dc --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +fastapi==0.115.0 +uvicorn[standard]==0.32.0 +pydantic==2.9.2 +pydantic-settings==2.6.0 +httpx==0.27.2 +python-dotenv==1.0.1 +python-multipart==0.0.12 diff --git a/tests/test_basic.py b/tests/test_basic.py new file mode 100644 index 0000000..5fdfc49 --- /dev/null +++ b/tests/test_basic.py @@ -0,0 +1,87 @@ +"""Basic tests for watsonx-openai-proxy.""" + +import pytest +from fastapi.testclient import TestClient +from app.main import app + +client = TestClient(app) + + +def test_health_check(): + """Test the health check endpoint.""" + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "cluster" in data + + +def test_root_endpoint(): + """Test the root endpoint.""" + response = client.get("/") + assert response.status_code == 200 + data = response.json() + assert data["service"] == "watsonx-openai-proxy" + assert "endpoints" in data + + +def test_list_models(): + """Test listing available models.""" + response = client.get("/v1/models") + assert response.status_code == 200 + data = response.json() + assert data["object"] == "list" + assert len(data["data"]) > 0 + assert all(model["object"] == "model" for model in data["data"]) + + +def test_retrieve_model(): + """Test retrieving a specific model.""" + response = client.get("/v1/models/ibm/granite-3-8b-instruct") + assert response.status_code == 200 + data = response.json() + assert data["id"] == "ibm/granite-3-8b-instruct" + assert data["object"] == "model" + + +def test_retrieve_nonexistent_model(): + """Test retrieving a model that doesn't exist.""" + response = client.get("/v1/models/nonexistent-model") + 
assert response.status_code == 404 + + +# Note: The following tests require valid IBM Cloud credentials +# and should be run with pytest markers or in integration tests + +@pytest.mark.skip(reason="Requires valid IBM Cloud credentials") +def test_chat_completion(): + """Test chat completion endpoint.""" + response = client.post( + "/v1/chat/completions", + json={ + "model": "ibm/granite-3-8b-instruct", + "messages": [ + {"role": "user", "content": "Hello!"} + ], + }, + ) + assert response.status_code == 200 + data = response.json() + assert "choices" in data + assert len(data["choices"]) > 0 + + +@pytest.mark.skip(reason="Requires valid IBM Cloud credentials") +def test_embeddings(): + """Test embeddings endpoint.""" + response = client.post( + "/v1/embeddings", + json={ + "model": "ibm/slate-125m-english-rtrvr", + "input": "Test text", + }, + ) + assert response.status_code == 200 + data = response.json() + assert "data" in data + assert len(data["data"]) > 0
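
The note above says the credential-requiring tests "should be run with pytest markers or in integration tests," but `@pytest.mark.skip` is unconditional and cannot be selectively enabled. One way to reconcile this — a minimal sketch, not part of the diff above, and the `missing_credentials` helper name is hypothetical — is to gate the skip on whether the live credentials are actually present:

```python
import os


def missing_credentials() -> bool:
    """Return True when the live IBM Cloud credentials are not both set.

    Used to decide whether credential-requiring tests should be skipped,
    so they run automatically in environments where IBM_CLOUD_API_KEY and
    WATSONX_PROJECT_ID are configured.
    """
    return not (
        os.environ.get("IBM_CLOUD_API_KEY")
        and os.environ.get("WATSONX_PROJECT_ID")
    )


# In tests/test_basic.py, this could replace the unconditional skip, e.g.:
#
#   @pytest.mark.skipif(
#       missing_credentials(),
#       reason="Requires valid IBM Cloud credentials",
#   )
#   def test_chat_completion():
#       ...
```

With this approach, `pytest tests/` exercises the live endpoints whenever the `.env` credentials are exported, and skips them otherwise, with no separate command needed.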