Add AGENTS.md documentation for AI agent guidance
223
AGENTS.md
Normal file
@@ -0,0 +1,223 @@
# AGENTS.md

This file provides guidance to agents when working with code in this repository.

## Project Overview

**watsonx-openai-proxy** is an OpenAI-compatible API proxy for IBM watsonx.ai. It enables any tool or application that supports the OpenAI API format to work seamlessly with watsonx.ai models.

### Core Purpose
- Provide a drop-in replacement for OpenAI API endpoints
- Translate OpenAI API requests to watsonx.ai API calls
- Handle IBM Cloud authentication and token management automatically
- Support streaming responses via Server-Sent Events (SSE)

### Technology Stack
- **Framework**: FastAPI (async web framework)
- **Language**: Python 3.9+
- **HTTP Client**: httpx (async HTTP client)
- **Validation**: Pydantic v2 (data validation and settings)
- **Server**: uvicorn (ASGI server)

### Architecture

The codebase follows a clean, modular architecture:

```
app/
├── main.py      # FastAPI app initialization, middleware, lifespan management
├── config.py    # Settings management, model mapping, environment variables
├── routers/     # API endpoint handlers (chat, completions, embeddings, models)
├── services/    # Business logic (watsonx_service for API interactions)
├── models/      # Pydantic models for OpenAI-compatible schemas
└── utils/       # Helper functions (request/response transformers)
```

**Key Design Patterns**:
- **Service Layer**: `watsonx_service.py` encapsulates all watsonx.ai API interactions
- **Transformer Pattern**: `transformers.py` handles bidirectional conversion between OpenAI and watsonx formats
- **Singleton Services**: Global service instances (`watsonx_service`, `settings`) for shared state
- **Async/Await**: All I/O operations are asynchronous for better performance
- **Middleware**: Custom authentication middleware for optional API key validation

## Building and Running

### Prerequisites
```bash
# Python 3.9 or higher required
python --version

# IBM Cloud credentials needed:
# - IBM_CLOUD_API_KEY
# - WATSONX_PROJECT_ID
```

### Installation
```bash
# Install dependencies
pip install -r requirements.txt

# Configure environment
cp .env.example .env
# Edit .env with your IBM Cloud credentials
```

### Running the Server
```bash
# Development (with auto-reload)
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000

# Production (with workers)
uvicorn app.main:app --host 0.0.0.0 --port 8000 --workers 4

# Using Python module
python -m app.main
```

### Docker Deployment
```bash
# Build image
docker build -t watsonx-openai-proxy .

# Run container
docker run -p 8000:8000 --env-file .env watsonx-openai-proxy

# Using docker-compose
docker-compose up
```

### Testing
```bash
# Install test dependencies
pip install pytest pytest-asyncio httpx

# Run tests
pytest tests/

# Run with coverage
pytest tests/ --cov=app
```

## Development Conventions

### Code Style
- **Async First**: Use `async`/`await` for all I/O operations (HTTP requests, file operations)
- **Type Hints**: All functions should have type annotations for parameters and return values
- **Docstrings**: Use Google-style docstrings for functions and classes
- **Logging**: Use the `logging` module with appropriate log levels (info, warning, error)

### Error Handling
- Catch exceptions at the router level and return OpenAI-compatible error responses
- Use `HTTPException` with proper status codes and error details
- Log errors with full context using `logger.error(..., exc_info=True)`
- Return structured error responses matching OpenAI's error format
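The error shape described above can be sketched as a small helper. This is illustrative only; the names here are assumptions, not the repo's actual helpers:

```python
from typing import Any, Dict, Optional

def openai_error(message: str, err_type: str = "api_error",
                 code: Optional[str] = None) -> Dict[str, Any]:
    """Build a response body matching OpenAI's error format (sketch)."""
    return {"error": {"message": message, "type": err_type, "code": code}}
```

A router would return this dict as the JSON body alongside the matching HTTP status code.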

### Configuration Management
- All configuration via environment variables (`.env` file)
- Use `pydantic-settings` for type-safe configuration
- Model mapping via `MODEL_MAP_*` environment variables
- Settings accessed through global `settings` instance

### Token Management
- Bearer tokens automatically refreshed every 50 minutes (expire at 60 minutes)
- Token refresh on 401 errors from watsonx.ai
- Thread-safe token refresh using `asyncio.Lock`
- Initial token obtained during application startup
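The lock-guarded refresh described above can be sketched as follows. This is a minimal illustration, not the repo's `watsonx_service` implementation; `_fetch_token` stands in for the real IBM Cloud IAM request:

```python
import asyncio
import time

class TokenManager:
    """Sketch of thread-safe bearer-token caching (illustrative names)."""

    def __init__(self, ttl: float = 3000.0) -> None:
        self._token = None
        self._expires_at = 0.0
        self._ttl = ttl  # refresh after 50 minutes; tokens live 60
        self._lock = asyncio.Lock()

    async def _fetch_token(self) -> str:
        # Placeholder for the real IAM token request.
        return "new-token"

    async def get_token(self) -> str:
        async with self._lock:  # only one coroutine refreshes at a time
            if self._token is None or time.monotonic() >= self._expires_at:
                self._token = await self._fetch_token()
                self._expires_at = time.monotonic() + self._ttl
            return self._token
```

Callers always go through `get_token()`, so a refresh triggered by one request cannot race a refresh triggered by another.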

### API Compatibility
- Maintain strict OpenAI API compatibility in request/response formats
- Use Pydantic models from `openai_models.py` for validation
- Transform requests/responses using functions in `transformers.py`
- Support both streaming and non-streaming responses

### Adding New Endpoints
1. Create router in `app/routers/` (e.g., `new_endpoint.py`)
2. Define Pydantic models in `app/models/openai_models.py`
3. Add transformation logic in `app/utils/transformers.py`
4. Add watsonx.ai API method in `app/services/watsonx_service.py`
5. Register router in `app/main.py` using `app.include_router()`

### Streaming Responses
- Use `StreamingResponse` with `media_type="text/event-stream"`
- Format chunks as Server-Sent Events using `format_sse_event()`
- Always send `[DONE]` message at the end of stream
- Handle errors gracefully and send error events in SSE format
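The chunk formatting above can be sketched like this; note the repo's actual `format_sse_event()` may differ in signature, this is only an assumed shape:

```python
import json
from typing import Any, Dict

def format_sse_event(data: Dict[str, Any]) -> str:
    """Serialize one streaming chunk as a Server-Sent Event (sketch)."""
    return f"data: {json.dumps(data)}\n\n"

# Terminal message every stream must end with:
DONE_EVENT = "data: [DONE]\n\n"
```

Each event is a `data:` line terminated by a blank line, which is what SSE clients split on.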

### Model Mapping
- Map OpenAI model names to watsonx models via environment variables
- Format: `MODEL_MAP_<OPENAI_MODEL>=<WATSONX_MODEL_ID>`
- Example: `MODEL_MAP_GPT4=ibm/granite-4-h-small`
- Mapping applied in `settings.map_model()` before API calls
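The env-var key is normalized into an OpenAI model name before lookup. The sketch below mirrors the normalization in `app/config.py` (extracted here as a standalone function for illustration):

```python
import re

def env_key_to_model(key: str) -> str:
    """Normalize a MODEL_MAP_* env-var name to an OpenAI model name,
    mirroring app/config.py: MODEL_MAP_GPT4 -> 'gpt-4'."""
    name = key.replace("MODEL_MAP_", "")
    name = re.sub(r'([A-Z]+)(\d+)', r'\1-\2', name)  # GPT4 -> GPT-4
    return name.lower().replace("_", "-")
```

So `MODEL_MAP_GPT4_TURBO` becomes the key `gpt-4-turbo` in the mapping dict.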

### Security Considerations
- Optional API key authentication via `API_KEY` environment variable
- Middleware validates Bearer token in Authorization header
- IBM Cloud API key stored securely in environment variables
- CORS configured via `ALLOWED_ORIGINS` (default: `*`)

### Logging Best Practices
- Use structured logging with context (model names, request IDs)
- Log level controlled by `LOG_LEVEL` environment variable
- Log token refresh events at INFO level
- Log API errors at ERROR level with full traceback
- Include request/response details for debugging

### Dependencies
- Keep `requirements.txt` minimal and pinned to specific versions
- FastAPI and Pydantic are core dependencies - avoid breaking changes
- httpx for async HTTP - prefer over requests/aiohttp
- Use `uvicorn[standard]` for production-ready server

## Important Implementation Notes

### watsonx.ai API Specifics
- Base URL format: `https://{cluster}.ml.cloud.ibm.com/ml/v1`
- API version parameter: `version=2024-02-13` (required on all requests)
- Chat endpoint: `/text/chat` (non-streaming) or `/text/chat_stream` (streaming)
- Text generation: `/text/generation`
- Embeddings: `/text/embeddings`

### Request/Response Transformation
- OpenAI messages → watsonx messages: direct mapping with role/content
- watsonx responses → OpenAI format: extract choices, usage, and metadata
- Streaming chunks: parse SSE format, transform delta objects
- Generate unique IDs: `chatcmpl-{uuid}` for chat, `cmpl-{uuid}` for completions
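The ID scheme above can be sketched in two lines (helper names are illustrative, not the repo's):

```python
import uuid

def new_chat_id() -> str:
    """OpenAI-style id for chat completions, e.g. 'chatcmpl-<32 hex chars>'."""
    return f"chatcmpl-{uuid.uuid4().hex}"

def new_completion_id() -> str:
    """OpenAI-style id for legacy completions."""
    return f"cmpl-{uuid.uuid4().hex}"
```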

### Common Pitfalls
- Don't forget to refresh tokens before they expire (50-minute interval)
- Always close httpx client on shutdown (`await watsonx_service.close()`)
- Handle both string and list formats for `stop` parameter
- Validate model IDs exist in watsonx.ai before making requests
- Set appropriate timeouts for long-running generation requests (300s default)
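The `stop`-parameter pitfall above is the classic string-or-list normalization; a minimal sketch (function name is illustrative):

```python
from typing import List, Optional, Union

def normalize_stop(stop: Optional[Union[str, List[str]]]) -> List[str]:
    """Accept both the string and list forms of OpenAI's `stop` parameter."""
    if stop is None:
        return []
    return [stop] if isinstance(stop, str) else list(stop)
```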

### Performance Optimization
- Reuse httpx client instance (don't create per request)
- Use connection pooling (httpx default behavior)
- Consider worker processes for production (`--workers 4`)
- Monitor token refresh to avoid rate limiting

## Environment Variables Reference

### Required
- `IBM_CLOUD_API_KEY`: IBM Cloud API key for authentication
- `WATSONX_PROJECT_ID`: watsonx.ai project ID

### Optional
- `WATSONX_CLUSTER`: Region (default: `us-south`)
- `HOST`: Server host (default: `0.0.0.0`)
- `PORT`: Server port (default: `8000`)
- `LOG_LEVEL`: Logging level (default: `info`)
- `API_KEY`: Optional proxy authentication key
- `ALLOWED_ORIGINS`: CORS origins (default: `*`)
- `MODEL_MAP_*`: Model name mappings

## API Endpoints

- `GET /` - API information and available endpoints
- `GET /health` - Health check (bypasses authentication)
- `GET /docs` - Interactive Swagger UI documentation
- `POST /v1/chat/completions` - Chat completions (streaming supported)
- `POST /v1/completions` - Text completions (legacy)
- `POST /v1/embeddings` - Generate embeddings
- `GET /v1/models` - List available models
- `GET /v1/models/{model_id}` - Get specific model info
20
Dockerfile
Normal file
@@ -0,0 +1,20 @@
FROM python:3.11-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY app ./app

# Expose port
EXPOSE 8000

# Health check (raise_for_status makes non-2xx responses fail the check)
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import httpx; httpx.get('http://localhost:8000/health').raise_for_status()"

# Run the application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
353
README.md
Normal file
@@ -0,0 +1,353 @@
# watsonx-openai-proxy

OpenAI-compatible API proxy for IBM watsonx.ai. This proxy allows you to use watsonx.ai models with any tool or application that supports the OpenAI API format.

## Features

- ✅ **Full OpenAI API Compatibility**: Drop-in replacement for OpenAI API
- ✅ **Chat Completions**: `/v1/chat/completions` with streaming support
- ✅ **Text Completions**: `/v1/completions` (legacy endpoint)
- ✅ **Embeddings**: `/v1/embeddings` for text embeddings
- ✅ **Model Listing**: `/v1/models` endpoint
- ✅ **Streaming Support**: Server-Sent Events (SSE) for real-time responses
- ✅ **Model Mapping**: Map OpenAI model names to watsonx models
- ✅ **Automatic Token Management**: Handles IBM Cloud authentication automatically
- ✅ **CORS Support**: Configurable cross-origin resource sharing
- ✅ **Optional API Key Authentication**: Secure your proxy with an API key

## Quick Start

### Prerequisites

- Python 3.9 or higher
- IBM Cloud account with watsonx.ai access
- IBM Cloud API key
- watsonx.ai Project ID

### Installation

1. Clone or download this directory:

   ```bash
   cd watsonx-openai-proxy
   ```

2. Install dependencies:

   ```bash
   pip install -r requirements.txt
   ```

3. Configure environment variables:

   ```bash
   cp .env.example .env
   # Edit .env with your credentials
   ```

4. Run the server:

   ```bash
   python -m app.main
   ```

   Or with uvicorn:

   ```bash
   uvicorn app.main:app --host 0.0.0.0 --port 8000
   ```

The server will start at `http://localhost:8000`.

## Configuration

### Environment Variables

Create a `.env` file with the following variables:

```bash
# Required: IBM Cloud Configuration
IBM_CLOUD_API_KEY=your_ibm_cloud_api_key_here
WATSONX_PROJECT_ID=your_watsonx_project_id_here
WATSONX_CLUSTER=us-south  # Options: us-south, eu-de, eu-gb, jp-tok, au-syd, ca-tor

# Optional: Server Configuration
HOST=0.0.0.0
PORT=8000
LOG_LEVEL=info

# Optional: API Key for Proxy Authentication
API_KEY=your_optional_api_key_for_proxy_authentication

# Optional: CORS Configuration
ALLOWED_ORIGINS=*  # Comma-separated or * for all

# Optional: Model Mapping
MODEL_MAP_GPT4=ibm/granite-4-h-small
MODEL_MAP_GPT35=ibm/granite-3-8b-instruct
MODEL_MAP_GPT4_TURBO=meta-llama/llama-3-3-70b-instruct
MODEL_MAP_TEXT_EMBEDDING_ADA_002=ibm/slate-125m-english-rtrvr
```

### Model Mapping

You can map OpenAI model names to watsonx models using environment variables:

```bash
MODEL_MAP_<OPENAI_MODEL_NAME>=<WATSONX_MODEL_ID>
```

For example:
- `MODEL_MAP_GPT4=ibm/granite-4-h-small` maps `gpt-4` to `ibm/granite-4-h-small`
- `MODEL_MAP_GPT35_TURBO=ibm/granite-3-8b-instruct` maps `gpt-3.5-turbo` to `ibm/granite-3-8b-instruct`

## Usage

### With OpenAI Python SDK

```python
from openai import OpenAI

# Point to your proxy
client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="your-proxy-api-key"  # Optional, if you set API_KEY in .env
)

# Use as normal
response = client.chat.completions.create(
    model="ibm/granite-3-8b-instruct",  # Or use a mapped name like "gpt-4"
    messages=[
        {"role": "user", "content": "Hello, how are you?"}
    ]
)

print(response.choices[0].message.content)
```

### With Streaming

```python
stream = client.chat.completions.create(
    model="ibm/granite-3-8b-instruct",
    messages=[{"role": "user", "content": "Tell me a story"}],
    stream=True
)

for chunk in stream:
    if chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
```

### With cURL

```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer your-proxy-api-key" \
  -d '{
    "model": "ibm/granite-3-8b-instruct",
    "messages": [
      {"role": "user", "content": "Hello!"}
    ]
  }'
```

### Embeddings

```python
response = client.embeddings.create(
    model="ibm/slate-125m-english-rtrvr",
    input="Your text to embed"
)

print(response.data[0].embedding)
```

## Available Endpoints

- `GET /` - API information
- `GET /health` - Health check
- `GET /docs` - Interactive API documentation (Swagger UI)
- `POST /v1/chat/completions` - Chat completions
- `POST /v1/completions` - Text completions (legacy)
- `POST /v1/embeddings` - Generate embeddings
- `GET /v1/models` - List available models
- `GET /v1/models/{model_id}` - Get model information

## Supported Models

The proxy supports all watsonx.ai models available in your project, including:

### Chat Models
- IBM Granite models (3.x, 4.x series)
- Meta Llama models (3.x, 4.x series)
- Mistral models
- Other models available on watsonx.ai

### Embedding Models
- `ibm/slate-125m-english-rtrvr`
- `ibm/slate-30m-english-rtrvr`

See the `/v1/models` endpoint for the complete list.

## Authentication

### Proxy Authentication (Optional)

If you set `API_KEY` in your `.env` file, clients must provide it:

```python
client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="your-proxy-api-key"
)
```

### IBM Cloud Authentication

The proxy handles IBM Cloud authentication automatically using your `IBM_CLOUD_API_KEY`. Bearer tokens are:
- Automatically obtained on startup
- Refreshed every 50 minutes (tokens expire after 60 minutes)
- Refreshed on 401 errors

## Deployment

### Docker (Recommended)

Create a `Dockerfile`:

```dockerfile
FROM python:3.11-slim

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY app ./app
COPY .env .

EXPOSE 8000

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
```

Build and run:

```bash
docker build -t watsonx-openai-proxy .
docker run -p 8000:8000 --env-file .env watsonx-openai-proxy
```

### Production Deployment

For production, consider:

1. **Use a production ASGI server**: The included uvicorn is suitable, but configure workers:

   ```bash
   uvicorn app.main:app --host 0.0.0.0 --port 8000 --workers 4
   ```

2. **Set up HTTPS**: Use a reverse proxy like nginx or Caddy

3. **Configure CORS**: Set `ALLOWED_ORIGINS` to specific domains

4. **Enable API key authentication**: Set `API_KEY` in the environment

5. **Monitor logs**: Set `LOG_LEVEL=info` or `warning` in production

6. **Use environment secrets**: Don't commit the `.env` file; use secret management

## Troubleshooting

### 401 Unauthorized

- Check that `IBM_CLOUD_API_KEY` is valid
- Verify your IBM Cloud account has watsonx.ai access
- Check server logs for token refresh errors

### Model Not Found

- Verify the model ID exists in watsonx.ai
- Check that your project has access to the model
- Use the `/v1/models` endpoint to see available models

### Connection Errors

- Verify `WATSONX_CLUSTER` matches your project's region
- Check firewall/network settings
- Ensure watsonx.ai services are accessible

### Streaming Issues

- Some models may not support streaming
- Check that the client library supports SSE (Server-Sent Events)
- Verify the network doesn't buffer streaming responses

## Development

### Running Tests

```bash
# Install dev dependencies
pip install pytest pytest-asyncio httpx

# Run tests
pytest tests/
```

### Code Structure

```
watsonx-openai-proxy/
├── app/
│   ├── main.py                  # FastAPI application
│   ├── config.py                # Configuration management
│   ├── routers/                 # API endpoint routers
│   │   ├── chat.py              # Chat completions
│   │   ├── completions.py       # Text completions
│   │   ├── embeddings.py        # Embeddings
│   │   └── models.py            # Model listing
│   ├── services/                # Business logic
│   │   └── watsonx_service.py   # watsonx.ai API client
│   ├── models/                  # Pydantic models
│   │   └── openai_models.py     # OpenAI-compatible schemas
│   └── utils/                   # Utilities
│       └── transformers.py      # Request/response transformers
├── tests/                       # Test files
├── requirements.txt             # Python dependencies
├── .env.example                 # Environment template
└── README.md                    # This file
```

## Contributing

Contributions are welcome! Please:

1. Fork the repository
2. Create a feature branch
3. Make your changes
4. Add tests if applicable
5. Submit a pull request

## License

Apache 2.0 License - See LICENSE file for details.

## Related Projects

- [watsonx-unofficial-aisdk-provider](../wxai-provider/) - Vercel AI SDK provider for watsonx.ai
- [OpenCode watsonx plugin](../.opencode/plugins/) - Token management plugin for OpenCode

## Disclaimer

This is **not an official IBM product**. It's a community-maintained proxy for integrating watsonx.ai with OpenAI-compatible tools. watsonx.ai is a trademark of IBM.

## Support

For issues and questions:
- Check the [Troubleshooting](#troubleshooting) section
- Review server logs (`LOG_LEVEL=debug` for detailed logs)
- Open an issue in the repository
- Consult [IBM watsonx.ai documentation](https://www.ibm.com/docs/en/watsonx-as-a-service)
3
app/__init__.py
Normal file
@@ -0,0 +1,3 @@
"""watsonx-openai-proxy application package."""

__version__ = "1.0.0"
91
app/config.py
Normal file
@@ -0,0 +1,91 @@
"""Configuration management for watsonx-openai-proxy."""

import os
from typing import Dict, Optional
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    # IBM Cloud Configuration
    ibm_cloud_api_key: str
    watsonx_project_id: str
    watsonx_cluster: str = "us-south"

    # Server Configuration
    host: str = "0.0.0.0"
    port: int = 8000
    log_level: str = "info"

    # API Configuration
    api_key: Optional[str] = None
    allowed_origins: str = "*"

    # Token Management
    token_refresh_interval: int = 3000  # 50 minutes in seconds

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="allow",  # Allow extra fields for model mapping
    )

    @property
    def watsonx_base_url(self) -> str:
        """Construct the watsonx.ai base URL from cluster."""
        return f"https://{self.watsonx_cluster}.ml.cloud.ibm.com/ml/v1"

    @property
    def cors_origins(self) -> list[str]:
        """Parse CORS origins from comma-separated string."""
        if self.allowed_origins == "*":
            return ["*"]
        return [origin.strip() for origin in self.allowed_origins.split(",")]

    def get_model_mapping(self) -> Dict[str, str]:
        """Extract model mappings from environment variables.

        Looks for variables like MODEL_MAP_GPT4=ibm/granite-4-h-small
        and creates a mapping dict.
        """
        import re
        mapping = {}

        # Check os.environ first
        for key, value in os.environ.items():
            if key.startswith("MODEL_MAP_"):
                model_name = key.replace("MODEL_MAP_", "")
                model_name = re.sub(r'([A-Z]+)(\d+)', r'\1-\2', model_name)
                openai_model = model_name.lower().replace("_", "-")
                mapping[openai_model] = value

        # Also check pydantic's extra fields (from .env file)
        # These come in as lowercase: model_map_gpt4 instead of MODEL_MAP_GPT4
        extra = getattr(self, '__pydantic_extra__', {}) or {}
        for key, value in extra.items():
            if key.startswith("model_map_"):
                # Convert back to uppercase for processing
                model_name = key.replace("model_map_", "").upper()
                model_name = re.sub(r'([A-Z]+)(\d+)', r'\1-\2', model_name)
                openai_model = model_name.lower().replace("_", "-")
                mapping[openai_model] = value

        return mapping

    def map_model(self, openai_model: str) -> str:
        """Map an OpenAI model name to a watsonx model ID.

        Args:
            openai_model: OpenAI model name (e.g., "gpt-4", "gpt-3.5-turbo")

        Returns:
            Corresponding watsonx model ID, or the original name if no mapping exists
        """
        model_map = self.get_model_mapping()
        return model_map.get(openai_model, openai_model)


# Global settings instance
settings = Settings()
157
app/main.py
Normal file
@@ -0,0 +1,157 @@
"""Main FastAPI application for watsonx-openai-proxy."""

import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.config import settings
from app.routers import chat, completions, embeddings, models
from app.services.watsonx_service import watsonx_service

# Configure logging
logging.basicConfig(
    level=getattr(logging, settings.log_level.upper()),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan events."""
    # Startup
    logger.info("Starting watsonx-openai-proxy...")
    logger.info(f"Cluster: {settings.watsonx_cluster}")
    logger.info(f"Project ID: {settings.watsonx_project_id[:8]}...")

    # Initialize token
    try:
        await watsonx_service._refresh_token()
        logger.info("Initial bearer token obtained successfully")
    except Exception as e:
        logger.error(f"Failed to obtain initial bearer token: {e}")
        raise

    yield

    # Shutdown
    logger.info("Shutting down watsonx-openai-proxy...")
    await watsonx_service.close()


# Create FastAPI app
app = FastAPI(
    title="watsonx-openai-proxy",
    description="OpenAI-compatible API proxy for IBM watsonx.ai",
    version="1.0.0",
    lifespan=lifespan,
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Optional API key authentication middleware
@app.middleware("http")
async def authenticate(request: Request, call_next):
    """Authenticate requests if an API key is configured."""
    if settings.api_key:
        # Skip authentication for health check
        if request.url.path == "/health":
            return await call_next(request)

        # Check Authorization header
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={
                    "error": {
                        "message": "Missing Authorization header",
                        "type": "authentication_error",
                        "code": "missing_authorization",
                    }
                },
            )

        # Validate API key
        if not auth_header.startswith("Bearer "):
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={
                    "error": {
                        "message": "Invalid Authorization header format",
                        "type": "authentication_error",
                        "code": "invalid_authorization_format",
                    }
                },
            )

        token = auth_header[7:]  # Remove "Bearer " prefix
        if token != settings.api_key:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={
                    "error": {
                        "message": "Invalid API key",
                        "type": "authentication_error",
                        "code": "invalid_api_key",
                    }
                },
            )

    return await call_next(request)


# Include routers
app.include_router(chat.router, tags=["Chat"])
app.include_router(completions.router, tags=["Completions"])
app.include_router(embeddings.router, tags=["Embeddings"])
app.include_router(models.router, tags=["Models"])


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "service": "watsonx-openai-proxy",
        "cluster": settings.watsonx_cluster,
    }


@app.get("/")
async def root():
    """Root endpoint with API information."""
    return {
        "service": "watsonx-openai-proxy",
        "description": "OpenAI-compatible API proxy for IBM watsonx.ai",
        "version": "1.0.0",
        "endpoints": {
            "chat": "/v1/chat/completions",
            "completions": "/v1/completions",
            "embeddings": "/v1/embeddings",
            "models": "/v1/models",
            "health": "/health",
        },
        "documentation": "/docs",
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        log_level=settings.log_level,
        reload=False,
    )
31
app/models/__init__.py
Normal file
@@ -0,0 +1,31 @@
"""OpenAI-compatible data models."""

from app.models.openai_models import (
    ChatMessage,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionChunk,
    CompletionRequest,
    CompletionResponse,
    EmbeddingRequest,
    EmbeddingResponse,
    ModelsResponse,
    ModelInfo,
    ErrorResponse,
    ErrorDetail,
)

__all__ = [
    "ChatMessage",
    "ChatCompletionRequest",
    "ChatCompletionResponse",
    "ChatCompletionChunk",
    "CompletionRequest",
    "CompletionResponse",
    "EmbeddingRequest",
    "EmbeddingResponse",
    "ModelsResponse",
    "ModelInfo",
    "ErrorResponse",
    "ErrorDetail",
]
213
app/models/openai_models.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""OpenAI-compatible request and response models."""

from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field


# ============================================================================
# Chat Completions Models
# ============================================================================

class ChatMessage(BaseModel):
    """A chat message in the conversation."""
    role: Literal["system", "user", "assistant", "function", "tool"]
    content: Optional[str] = None
    name: Optional[str] = None
    function_call: Optional[Dict[str, Any]] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None


class FunctionCall(BaseModel):
    """Function call specification."""
    name: str
    arguments: str


class ToolCall(BaseModel):
    """Tool call specification."""
    id: str
    type: Literal["function"]
    function: FunctionCall


class ChatCompletionRequest(BaseModel):
    """OpenAI chat completion request."""
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(default=1.0, ge=0, le=2)
    top_p: Optional[float] = Field(default=1.0, ge=0, le=1)
    n: Optional[int] = Field(default=1, ge=1)
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = Field(default=None, ge=1)
    presence_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    frequency_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    functions: Optional[List[Dict[str, Any]]] = None
    function_call: Optional[Union[str, Dict[str, str]]] = None
    tools: Optional[List[Dict[str, Any]]] = None
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None


class ChatCompletionChoice(BaseModel):
    """A single chat completion choice."""
    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None
    logprobs: Optional[Dict[str, Any]] = None


class ChatCompletionUsage(BaseModel):
    """Token usage information."""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    """OpenAI chat completion response."""
    id: str
    object: Literal["chat.completion"] = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: ChatCompletionUsage
    system_fingerprint: Optional[str] = None


class ChatCompletionChunkDelta(BaseModel):
    """Delta content in streaming response."""
    role: Optional[str] = None
    content: Optional[str] = None
    function_call: Optional[Dict[str, Any]] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None


class ChatCompletionChunkChoice(BaseModel):
    """A single streaming chunk choice."""
    index: int
    delta: ChatCompletionChunkDelta
    finish_reason: Optional[str] = None
    logprobs: Optional[Dict[str, Any]] = None


class ChatCompletionChunk(BaseModel):
    """OpenAI streaming chat completion chunk."""
    id: str
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int
    model: str
    choices: List[ChatCompletionChunkChoice]
    system_fingerprint: Optional[str] = None


# ============================================================================
# Completions Models (Legacy)
# ============================================================================

class CompletionRequest(BaseModel):
    """OpenAI completion request (legacy)."""
    model: str
    prompt: Union[str, List[str], List[int], List[List[int]]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = Field(default=16, ge=1)
    temperature: Optional[float] = Field(default=1.0, ge=0, le=2)
    top_p: Optional[float] = Field(default=1.0, ge=0, le=1)
    n: Optional[int] = Field(default=1, ge=1)
    stream: Optional[bool] = False
    logprobs: Optional[int] = Field(default=None, ge=0, le=5)
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    frequency_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    best_of: Optional[int] = Field(default=1, ge=1)
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None


class CompletionChoice(BaseModel):
    """A single completion choice."""
    text: str
    index: int
    logprobs: Optional[Dict[str, Any]] = None
    finish_reason: Optional[str] = None


class CompletionResponse(BaseModel):
    """OpenAI completion response."""
    id: str
    object: Literal["text_completion"] = "text_completion"
    created: int
    model: str
    choices: List[CompletionChoice]
    usage: ChatCompletionUsage


# ============================================================================
# Embeddings Models
# ============================================================================

class EmbeddingRequest(BaseModel):
    """OpenAI embedding request."""
    model: str
    input: Union[str, List[str], List[int], List[List[int]]]
    encoding_format: Optional[Literal["float", "base64"]] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None


class EmbeddingData(BaseModel):
    """A single embedding result."""
    object: Literal["embedding"] = "embedding"
    embedding: List[float]
    index: int


class EmbeddingUsage(BaseModel):
    """Token usage for embeddings."""
    prompt_tokens: int
    total_tokens: int


class EmbeddingResponse(BaseModel):
    """OpenAI embedding response."""
    object: Literal["list"] = "list"
    data: List[EmbeddingData]
    model: str
    usage: EmbeddingUsage


# ============================================================================
# Models List
# ============================================================================

class ModelInfo(BaseModel):
    """Information about a single model."""
    id: str
    object: Literal["model"] = "model"
    created: int
    owned_by: str


class ModelsResponse(BaseModel):
    """List of available models."""
    object: Literal["list"] = "list"
    data: List[ModelInfo]


# ============================================================================
# Error Models
# ============================================================================

class ErrorDetail(BaseModel):
    """Error detail information."""
    message: str
    type: str
    param: Optional[str] = None
    code: Optional[str] = None


class ErrorResponse(BaseModel):
    """OpenAI error response."""
    error: ErrorDetail
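The response schemas above pin down the wire shape. As a plain-dict illustration of what a conforming `chat.completion` body looks like when serialized (the id, model, and content values here are invented for illustration):

```python
import json
import time

# A minimal payload matching the ChatCompletionResponse schema above
# (id/model/content values are made up; structure mirrors the model fields).
response = {
    "id": "chatcmpl-abc123",
    "object": "chat.completion",
    "created": int(time.time()),
    "model": "ibm/granite-4-h-small",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello!"},
            "finish_reason": "stop",
            "logprobs": None,
        }
    ],
    "usage": {"prompt_tokens": 9, "completion_tokens": 3, "total_tokens": 12},
}

encoded = json.dumps(response)
```

In the real service the Pydantic models above produce this shape via `model_dump_json()`, with validation enforcing the field types and `Literal` discriminators.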
5
app/routers/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""API routers for watsonx-openai-proxy."""

from app.routers import chat, completions, embeddings, models

__all__ = ["chat", "completions", "embeddings", "models"]
156
app/routers/chat.py
Normal file
@@ -0,0 +1,156 @@
"""Chat completions endpoint router."""

import json
import uuid
from typing import Optional, Union
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from app.models.openai_models import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ErrorResponse,
    ErrorDetail,
)
from app.services.watsonx_service import watsonx_service
from app.utils.transformers import (
    transform_messages_to_watsonx,
    transform_tools_to_watsonx,
    transform_watsonx_to_openai_chat,
    transform_watsonx_to_openai_chat_chunk,
    format_sse_event,
)
from app.config import settings
import logging

logger = logging.getLogger(__name__)

router = APIRouter()


@router.post(
    "/v1/chat/completions",
    response_model=Union[ChatCompletionResponse, ErrorResponse],
    responses={
        200: {"model": ChatCompletionResponse},
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def create_chat_completion(
    request: ChatCompletionRequest,
    http_request: Request,
):
    """Create a chat completion using the OpenAI-compatible API.

    This endpoint accepts OpenAI-formatted requests and translates them
    to watsonx.ai API calls.
    """
    try:
        # Map model name if needed
        watsonx_model = settings.map_model(request.model)
        logger.info(f"Chat completion request: {request.model} -> {watsonx_model}")

        # Transform messages
        watsonx_messages = transform_messages_to_watsonx(request.messages)

        # Transform tools if present
        watsonx_tools = transform_tools_to_watsonx(request.tools)

        # Handle streaming
        if request.stream:
            return StreamingResponse(
                stream_chat_completion(
                    watsonx_model,
                    watsonx_messages,
                    request,
                    watsonx_tools,
                ),
                media_type="text/event-stream",
            )

        # Non-streaming response
        watsonx_response = await watsonx_service.chat_completion(
            model_id=watsonx_model,
            messages=watsonx_messages,
            temperature=request.temperature or 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p or 1.0,
            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
            tools=watsonx_tools,
        )

        # Transform response
        openai_response = transform_watsonx_to_openai_chat(
            watsonx_response,
            request.model,
        )

        return openai_response

    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": "internal_error",
                }
            },
        )


async def stream_chat_completion(
    watsonx_model: str,
    watsonx_messages: list,
    request: ChatCompletionRequest,
    watsonx_tools: Optional[list] = None,
):
    """Stream chat completion responses.

    Args:
        watsonx_model: The watsonx model ID
        watsonx_messages: Transformed messages
        request: Original OpenAI request
        watsonx_tools: Transformed tools

    Yields:
        Server-Sent Events with chat completion chunks
    """
    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"

    try:
        async for chunk in watsonx_service.chat_completion_stream(
            model_id=watsonx_model,
            messages=watsonx_messages,
            temperature=request.temperature or 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p or 1.0,
            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
            tools=watsonx_tools,
        ):
            # Transform chunk to OpenAI format
            openai_chunk = transform_watsonx_to_openai_chat_chunk(
                chunk,
                request.model,
                request_id,
            )

            # Send as SSE
            yield format_sse_event(openai_chunk.model_dump_json())

        # Send the [DONE] message
        yield format_sse_event("[DONE]")

    except Exception as e:
        logger.error(f"Error in streaming chat completion: {str(e)}", exc_info=True)
        error_response = ErrorResponse(
            error=ErrorDetail(
                message=str(e),
                type="internal_error",
                code="stream_error",
            )
        )
        yield format_sse_event(error_response.model_dump_json())
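Both handlers normalize the `stop` parameter with the same inline conditional. It can be read as a small pure function; a sketch under that assumption (the helper name `normalize_stop` is mine, not from the codebase):

```python
from typing import List, Optional, Union

def normalize_stop(stop: Optional[Union[str, List[str]]]) -> Optional[List[str]]:
    """Mirror the router's expression: a falsy value (None or "") becomes
    None, a bare string becomes a one-element list, a list passes through."""
    if not stop:
        return None
    return stop if isinstance(stop, list) else [stop]
```

This matches the truthiness check in the routers, where an empty string is treated the same as an absent `stop`.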
109
app/routers/completions.py
Normal file
@@ -0,0 +1,109 @@
"""Text completions endpoint router (legacy)."""

import uuid
from typing import Union
from fastapi import APIRouter, HTTPException, Request
from app.models.openai_models import (
    CompletionRequest,
    CompletionResponse,
    ErrorResponse,
    ErrorDetail,
)
from app.services.watsonx_service import watsonx_service
from app.utils.transformers import transform_watsonx_to_openai_completion
from app.config import settings
import logging

logger = logging.getLogger(__name__)

router = APIRouter()


@router.post(
    "/v1/completions",
    response_model=Union[CompletionResponse, ErrorResponse],
    responses={
        200: {"model": CompletionResponse},
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def create_completion(
    request: CompletionRequest,
    http_request: Request,
):
    """Create a text completion using the OpenAI-compatible API (legacy).

    This endpoint accepts OpenAI-formatted completion requests and translates
    them to watsonx.ai text generation API calls.
    """
    try:
        # Map model name if needed
        watsonx_model = settings.map_model(request.model)
        logger.info(f"Completion request: {request.model} -> {watsonx_model}")

        # Handle prompt (can be a string or a list)
        if isinstance(request.prompt, list):
            if len(request.prompt) == 0:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": {
                            "message": "Prompt cannot be empty",
                            "type": "invalid_request_error",
                            "code": "invalid_prompt",
                        }
                    },
                )
            # For now, just use the first prompt
            # TODO: Handle multiple prompts with the n parameter
            prompt = request.prompt[0] if isinstance(request.prompt[0], str) else ""
        else:
            prompt = request.prompt

        # Note: Streaming not implemented for completions yet
        if request.stream:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "Streaming not supported for completions endpoint",
                        "type": "invalid_request_error",
                        "code": "streaming_not_supported",
                    }
                },
            )

        # Call watsonx text generation
        watsonx_response = await watsonx_service.text_generation(
            model_id=watsonx_model,
            prompt=prompt,
            temperature=request.temperature or 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p or 1.0,
            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
        )

        # Transform response
        openai_response = transform_watsonx_to_openai_completion(
            watsonx_response,
            request.model,
        )

        return openai_response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in completion: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": "internal_error",
                }
            },
        )
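All three routers raise the same OpenAI-style error envelope by hand. Factoring it once makes the shape easy to see; a sketch (the helper name `openai_error` is my invention, not part of the codebase):

```python
from typing import Any, Dict, Optional

def openai_error(message: str, err_type: str, code: str,
                 param: Optional[str] = None) -> Dict[str, Any]:
    """Build the {"error": {...}} payload the routers pass as
    HTTPException(detail=...), matching the ErrorDetail fields."""
    detail: Dict[str, Any] = {
        "error": {
            "message": message,
            "type": err_type,
            "code": code,
        }
    }
    if param is not None:
        detail["error"]["param"] = param
    return detail

payload = openai_error("Prompt cannot be empty", "invalid_request_error", "invalid_prompt")
```

OpenAI clients generally surface `error.message` and branch on `error.type`/`error.code`, so keeping this envelope consistent across endpoints matters for compatibility.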
114
app/routers/embeddings.py
Normal file
@@ -0,0 +1,114 @@
"""Embeddings endpoint router."""

from typing import Union
from fastapi import APIRouter, HTTPException, Request
from app.models.openai_models import (
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorResponse,
    ErrorDetail,
)
from app.services.watsonx_service import watsonx_service
from app.utils.transformers import transform_watsonx_to_openai_embeddings
from app.config import settings
import logging

logger = logging.getLogger(__name__)

router = APIRouter()


@router.post(
    "/v1/embeddings",
    response_model=Union[EmbeddingResponse, ErrorResponse],
    responses={
        200: {"model": EmbeddingResponse},
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def create_embeddings(
    request: EmbeddingRequest,
    http_request: Request,
):
    """Create embeddings using the OpenAI-compatible API.

    This endpoint accepts OpenAI-formatted embedding requests and translates
    them to watsonx.ai embeddings API calls.
    """
    try:
        # Map model name if needed
        watsonx_model = settings.map_model(request.model)
        logger.info(f"Embeddings request: {request.model} -> {watsonx_model}")

        # Handle input (can be a string or a list)
        if isinstance(request.input, str):
            inputs = [request.input]
        elif isinstance(request.input, list):
            if len(request.input) == 0:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": {
                            "message": "Input cannot be empty",
                            "type": "invalid_request_error",
                            "code": "invalid_input",
                        }
                    },
                )
            # Handle a list of strings or a list of token IDs
            if isinstance(request.input[0], str):
                inputs = request.input
            else:
                # Token IDs not supported yet
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": {
                            "message": "Token ID input not supported",
                            "type": "invalid_request_error",
                            "code": "unsupported_input_type",
                        }
                    },
                )
        else:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "Invalid input type",
                        "type": "invalid_request_error",
                        "code": "invalid_input_type",
                    }
                },
            )

        # Call watsonx embeddings
        watsonx_response = await watsonx_service.embeddings(
            model_id=watsonx_model,
            inputs=inputs,
        )

        # Transform response
        openai_response = transform_watsonx_to_openai_embeddings(
            watsonx_response,
            request.model,
        )

        return openai_response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in embeddings: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": "internal_error",
                }
            },
        )
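The branching above reduces to: wrap a bare string, accept a non-empty list of strings, reject everything else. A standalone sketch of that rule, slightly stricter than the endpoint (which only inspects the first element); `ValueError` stands in for the router's `HTTPException`, and the function name is mine:

```python
from typing import List, Union

def normalize_embedding_input(value: Union[str, list]) -> List[str]:
    """Mirror the endpoint's input handling: a bare string is wrapped,
    a non-empty list of strings passes through, anything else is rejected."""
    if isinstance(value, str):
        return [value]
    if isinstance(value, list):
        if not value:
            raise ValueError("Input cannot be empty")
        if all(isinstance(item, str) for item in value):
            return value
        raise ValueError("Token ID input not supported")
    raise ValueError("Invalid input type")
```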
120
app/routers/models.py
Normal file
@@ -0,0 +1,120 @@
"""Models endpoint router."""

import time
from fastapi import APIRouter, HTTPException
from app.models.openai_models import ModelsResponse, ModelInfo
from app.config import settings
import logging

logger = logging.getLogger(__name__)

router = APIRouter()


# Predefined list of available models
# This can be extended or made dynamic based on the watsonx.ai API
AVAILABLE_MODELS = [
    # Granite Models
    "ibm/granite-3-1-8b-base",
    "ibm/granite-3-2-8b-instruct",
    "ibm/granite-3-3-8b-instruct",
    "ibm/granite-3-8b-instruct",
    "ibm/granite-4-h-small",
    "ibm/granite-8b-code-instruct",

    # Llama Models
    "meta-llama/llama-3-1-70b-gptq",
    "meta-llama/llama-3-1-8b",
    "meta-llama/llama-3-2-11b-vision-instruct",
    "meta-llama/llama-3-2-90b-vision-instruct",
    "meta-llama/llama-3-3-70b-instruct",
    "meta-llama/llama-3-405b-instruct",
    "meta-llama/llama-4-maverick-17b-128e-instruct-fp8",

    # Mistral Models
    "mistral-large-2512",
    "mistralai/mistral-medium-2505",
    "mistralai/mistral-small-3-1-24b-instruct-2503",

    # Other Models
    "openai/gpt-oss-120b",

    # Embedding Models
    "ibm/slate-125m-english-rtrvr",
    "ibm/slate-30m-english-rtrvr",
]


@router.get(
    "/v1/models",
    response_model=ModelsResponse,
)
async def list_models():
    """List available models in OpenAI-compatible format.

    Returns a list of models that can be used with the API.
    Includes both the actual watsonx model IDs and any mapped names.
    """
    created_time = int(time.time())
    models = []

    # Add all available watsonx models
    for model_id in AVAILABLE_MODELS:
        models.append(
            ModelInfo(
                id=model_id,
                created=created_time,
                owned_by="ibm-watsonx",
            )
        )

    # Add mapped model names (e.g., gpt-4 -> ibm/granite-4-h-small)
    model_mapping = settings.get_model_mapping()
    for openai_name, watsonx_id in model_mapping.items():
        if watsonx_id in AVAILABLE_MODELS:
            models.append(
                ModelInfo(
                    id=openai_name,
                    created=created_time,
                    owned_by="ibm-watsonx",
                )
            )

    return ModelsResponse(data=models)


@router.get(
    "/v1/models/{model_id}",
    response_model=ModelInfo,
)
async def retrieve_model(model_id: str):
    """Retrieve information about a specific model.

    Args:
        model_id: The model ID to retrieve

    Returns:
        Model information
    """
    # Map the model if needed
    watsonx_model = settings.map_model(model_id)

    # Check whether the model exists
    if watsonx_model not in AVAILABLE_MODELS:
        raise HTTPException(
            status_code=404,
            detail={
                "error": {
                    "message": f"Model '{model_id}' not found",
                    "type": "invalid_request_error",
                    "code": "model_not_found",
                }
            },
        )

    return ModelInfo(
        id=model_id,
        created=int(time.time()),
        owned_by="ibm-watsonx",
    )
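The list-building in `list_models` is essentially a union of the catalogue with any alias whose watsonx target actually exists. Sketched with plain lists and dicts (the helper name and example mapping are hypothetical, not from the codebase or its configuration):

```python
def build_model_ids(available, mapping):
    """Return catalogue IDs plus aliases whose watsonx target is in the
    catalogue, mirroring the merge performed in list_models."""
    ids = list(available)
    for alias, target in mapping.items():
        if target in available:
            ids.append(alias)
    return ids

# Hypothetical mapping: one alias resolves, one points at an absent model.
ids = build_model_ids(
    ["ibm/granite-4-h-small"],
    {"gpt-4": "ibm/granite-4-h-small", "gpt-3.5-turbo": "not-deployed"},
)
```

Aliases whose target is missing from `AVAILABLE_MODELS` are silently dropped, which is why the endpoint filters on `watsonx_id in AVAILABLE_MODELS` before appending.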
5
app/services/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""Services for watsonx-openai-proxy."""

from app.services.watsonx_service import watsonx_service, WatsonxService

__all__ = ["watsonx_service", "WatsonxService"]
316
app/services/watsonx_service.py
Normal file
@@ -0,0 +1,316 @@
"""Service for interacting with IBM watsonx.ai APIs."""

import asyncio
import json
import time
from typing import AsyncIterator, Dict, List, Optional
import httpx
from app.config import settings
import logging

logger = logging.getLogger(__name__)


class WatsonxService:
    """Service for managing watsonx.ai API interactions."""

    def __init__(self):
        self.base_url = settings.watsonx_base_url
        self.project_id = settings.watsonx_project_id
        self.api_key = settings.ibm_cloud_api_key
        self._bearer_token: Optional[str] = None
        self._token_expiry: Optional[float] = None
        self._token_lock = asyncio.Lock()
        self._client: Optional[httpx.AsyncClient] = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or create the HTTP client."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=300.0)
        return self._client

    async def close(self):
        """Close the HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None

    async def _refresh_token(self) -> str:
        """Get a fresh bearer token from IBM Cloud IAM."""
        async with self._token_lock:
            # Check whether the cached token is still valid
            if self._bearer_token and self._token_expiry:
                if time.time() < self._token_expiry - 300:  # 5 min buffer
                    return self._bearer_token

            logger.info("Refreshing IBM Cloud bearer token...")

            client = await self._get_client()
            response = await client.post(
                "https://iam.cloud.ibm.com/identity/token",
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                data=f"grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey={self.api_key}",
            )

            if response.status_code != 200:
                raise Exception(f"Failed to get bearer token: {response.text}")

            data = response.json()
            self._bearer_token = data["access_token"]
            self._token_expiry = time.time() + data.get("expires_in", 3600)

            logger.info(f"Bearer token refreshed. Expires in {data.get('expires_in', 3600)} seconds")
            return self._bearer_token

    async def _get_headers(self) -> Dict[str, str]:
        """Get request headers with a valid bearer token."""
        token = await self._refresh_token()
        return {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Dict],
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        stream: bool = False,
        tools: Optional[List[Dict]] = None,
        **kwargs,
    ) -> Dict:
        """Create a chat completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID
            messages: List of chat messages
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            top_p: Nucleus sampling parameter
            stop: Stop sequences
            stream: Whether to stream the response
            tools: Tool/function definitions
            **kwargs: Additional parameters

        Returns:
            Chat completion response
        """
        headers = await self._get_headers()
        client = await self._get_client()

        # Build the watsonx request
        payload = {
            "model_id": model_id,
            "project_id": self.project_id,
            "messages": messages,
            "parameters": {
                "temperature": temperature,
                "top_p": top_p,
            },
        }

        if max_tokens:
            payload["parameters"]["max_tokens"] = max_tokens

        if stop:
            payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]

        if tools:
            payload["tools"] = tools

        url = f"{self.base_url}/text/chat"
        params = {"version": "2024-02-13"}

        if stream:
            params["stream"] = "true"

        response = await client.post(
            url,
            headers=headers,
            json=payload,
            params=params,
        )

        if response.status_code != 200:
            raise Exception(f"watsonx API error: {response.status_code} - {response.text}")

        return response.json()

    async def chat_completion_stream(
        self,
        model_id: str,
        messages: List[Dict],
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        tools: Optional[List[Dict]] = None,
        **kwargs,
    ) -> AsyncIterator[Dict]:
        """Stream a chat completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID
            messages: List of chat messages
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            top_p: Nucleus sampling parameter
            stop: Stop sequences
            tools: Tool/function definitions
            **kwargs: Additional parameters

        Yields:
            Chat completion chunks
        """
        headers = await self._get_headers()
        client = await self._get_client()

        # Build the watsonx request
        payload = {
            "model_id": model_id,
            "project_id": self.project_id,
            "messages": messages,
            "parameters": {
                "temperature": temperature,
                "top_p": top_p,
            },
        }

        if max_tokens:
            payload["parameters"]["max_tokens"] = max_tokens

        if stop:
            payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]

        if tools:
            payload["tools"] = tools

        url = f"{self.base_url}/text/chat_stream"
        params = {"version": "2024-02-13"}

        async with client.stream(
            "POST",
            url,
            headers=headers,
            json=payload,
            params=params,
        ) as response:
            if response.status_code != 200:
                text = await response.aread()
                raise Exception(f"watsonx API error: {response.status_code} - {text.decode()}")

            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data = line[6:]  # Remove the "data: " prefix
                    if data.strip() == "[DONE]":
                        break
                    try:
                        yield json.loads(data)
                    except json.JSONDecodeError:
                        continue

    async def text_generation(
        self,
        model_id: str,
        prompt: str,
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> Dict:
        """Generate a text completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID
            prompt: The input prompt
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            top_p: Nucleus sampling parameter
            stop: Stop sequences
            **kwargs: Additional parameters

        Returns:
            Text generation response
        """
        headers = await self._get_headers()
        client = await self._get_client()

        payload = {
            "model_id": model_id,
            "project_id": self.project_id,
            "input": prompt,
            "parameters": {
                "temperature": temperature,
                "top_p": top_p,
            },
        }

        if max_tokens:
            payload["parameters"]["max_new_tokens"] = max_tokens

        if stop:
            payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]

        url = f"{self.base_url}/text/generation"
        params = {"version": "2024-02-13"}

        response = await client.post(
            url,
            headers=headers,
            json=payload,
            params=params,
        )

        if response.status_code != 200:
            raise Exception(f"watsonx API error: {response.status_code} - {response.text}")

        return response.json()

    async def embeddings(
        self,
        model_id: str,
        inputs: List[str],
        **kwargs,
    ) -> Dict:
        """Generate embeddings using watsonx.ai.

        Args:
            model_id: The watsonx embedding model ID
            inputs: List of texts to embed
            **kwargs: Additional parameters

        Returns:
            Embeddings response
        """
        headers = await self._get_headers()
        client = await self._get_client()

        payload = {
            "model_id": model_id,
            "project_id": self.project_id,
            "inputs": inputs,
        }

        url = f"{self.base_url}/text/embeddings"
        params = {"version": "2024-02-13"}

        response = await client.post(
            url,
            headers=headers,
            json=payload,
            params=params,
        )

        if response.status_code != 200:
            raise Exception(f"watsonx API error: {response.status_code} - {response.text}")

        return response.json()


# Global service instance
watsonx_service = WatsonxService()
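The stream-reading loop in `chat_completion_stream` can be isolated into a pure line parser, which makes its three outcomes explicit; a standalone sketch (the function name and sentinel are mine):

```python
import json

DONE = object()  # sentinel standing in for the terminating "[DONE]" event

def parse_sse_line(line: str):
    """Mirror the loop above: return a parsed chunk dict, the DONE sentinel,
    or None for keep-alives, blank lines, and malformed JSON."""
    if not line.startswith("data: "):
        return None
    data = line[len("data: "):]
    if data.strip() == "[DONE]":
        return DONE
    try:
        return json.loads(data)
    except json.JSONDecodeError:
        return None
```

In the service itself, `None` outcomes are simply skipped and the sentinel breaks the `async for` loop over `response.aiter_lines()`.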
21
app/utils/__init__.py
Normal file
@@ -0,0 +1,21 @@
"""Utility functions for watsonx-openai-proxy."""

from app.utils.transformers import (
    transform_messages_to_watsonx,
    transform_tools_to_watsonx,
    transform_watsonx_to_openai_chat,
    transform_watsonx_to_openai_chat_chunk,
    transform_watsonx_to_openai_completion,
    transform_watsonx_to_openai_embeddings,
    format_sse_event,
)

__all__ = [
    "transform_messages_to_watsonx",
    "transform_tools_to_watsonx",
    "transform_watsonx_to_openai_chat",
    "transform_watsonx_to_openai_chat_chunk",
    "transform_watsonx_to_openai_completion",
    "transform_watsonx_to_openai_embeddings",
    "format_sse_event",
]
272
app/utils/transformers.py
Normal file
@@ -0,0 +1,272 @@
"""Utilities for transforming between OpenAI and watsonx formats."""

import time
import uuid
from typing import Any, Dict, List, Optional

from app.models.openai_models import (
    ChatMessage,
    ChatCompletionChoice,
    ChatCompletionUsage,
    ChatCompletionResponse,
    ChatCompletionChunk,
    ChatCompletionChunkChoice,
    ChatCompletionChunkDelta,
    CompletionChoice,
    CompletionResponse,
    EmbeddingData,
    EmbeddingUsage,
    EmbeddingResponse,
)


def transform_messages_to_watsonx(messages: List[ChatMessage]) -> List[Dict[str, Any]]:
    """Transform OpenAI messages to watsonx format.

    Args:
        messages: List of OpenAI ChatMessage objects

    Returns:
        List of watsonx-compatible message dicts
    """
    watsonx_messages = []

    for msg in messages:
        watsonx_msg = {
            "role": msg.role,
        }

        if msg.content:
            watsonx_msg["content"] = msg.content

        if msg.name:
            watsonx_msg["name"] = msg.name

        if msg.tool_calls:
            watsonx_msg["tool_calls"] = msg.tool_calls

        if msg.function_call:
            watsonx_msg["function_call"] = msg.function_call

        watsonx_messages.append(watsonx_msg)

    return watsonx_messages
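Since the transformer above only carries over the optional fields that are actually set, its behavior can be sketched with plain dicts, without the Pydantic models (`to_watsonx` is a hypothetical standalone name, not part of the repo):

```python
def to_watsonx(messages: list) -> list:
    # Plain-dict sketch of transform_messages_to_watsonx: always copy the role,
    # then carry over only the optional fields that are truthy
    out = []
    for m in messages:
        w = {"role": m["role"]}
        for key in ("content", "name", "tool_calls", "function_call"):
            if m.get(key):
                w[key] = m[key]
        out.append(w)
    return out


msgs = [{"role": "user", "content": "Hello!", "name": None}]
print(to_watsonx(msgs))  # [{'role': 'user', 'content': 'Hello!'}]
```

Note that unset fields (here `name=None`) are dropped entirely rather than serialized as nulls.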
def transform_tools_to_watsonx(tools: Optional[List[Dict]]) -> Optional[List[Dict]]:
    """Transform OpenAI tools to watsonx format.

    Args:
        tools: List of OpenAI tool definitions

    Returns:
        List of watsonx-compatible tool definitions
    """
    if not tools:
        return None

    # watsonx uses a similar format to OpenAI for tools
    return tools


def transform_watsonx_to_openai_chat(
    watsonx_response: Dict[str, Any],
    model: str,
    request_id: Optional[str] = None,
) -> ChatCompletionResponse:
    """Transform watsonx chat response to OpenAI format.

    Args:
        watsonx_response: Response from watsonx chat API
        model: Model name to include in response
        request_id: Optional request ID; one is generated if not provided

    Returns:
        OpenAI-compatible ChatCompletionResponse
    """
    response_id = request_id or f"chatcmpl-{uuid.uuid4().hex[:24]}"
    created = int(time.time())

    # Extract choices
    choices = []
    watsonx_choices = watsonx_response.get("choices", [])

    for idx, choice in enumerate(watsonx_choices):
        message_data = choice.get("message", {})

        message = ChatMessage(
            role=message_data.get("role", "assistant"),
            content=message_data.get("content"),
            tool_calls=message_data.get("tool_calls"),
            function_call=message_data.get("function_call"),
        )

        choices.append(
            ChatCompletionChoice(
                index=idx,
                message=message,
                finish_reason=choice.get("finish_reason"),
            )
        )

    # Extract usage
    usage_data = watsonx_response.get("usage", {})
    usage = ChatCompletionUsage(
        prompt_tokens=usage_data.get("prompt_tokens", 0),
        completion_tokens=usage_data.get("completion_tokens", 0),
        total_tokens=usage_data.get("total_tokens", 0),
    )

    return ChatCompletionResponse(
        id=response_id,
        created=created,
        model=model,
        choices=choices,
        usage=usage,
    )
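The extraction logic above degrades gracefully on missing fields because every lookup goes through `.get` with a default. A standalone walk over a minimal fabricated watsonx-style response (field names as used above; the values are made up):

```python
# Minimal fabricated watsonx-style chat response
watsonx_response = {
    "choices": [
        {"message": {"content": "Paris."}, "finish_reason": "stop"},
    ],
    "usage": {"prompt_tokens": 9, "completion_tokens": 3, "total_tokens": 12},
}

for idx, choice in enumerate(watsonx_response.get("choices", [])):
    message_data = choice.get("message", {})
    # role falls back to "assistant" when the API omits it
    role = message_data.get("role", "assistant")
    print(idx, role, message_data.get("content"), choice.get("finish_reason"))
# 0 assistant Paris. stop
```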
def transform_watsonx_to_openai_chat_chunk(
    watsonx_chunk: Dict[str, Any],
    model: str,
    request_id: str,
) -> ChatCompletionChunk:
    """Transform watsonx streaming chunk to OpenAI format.

    Args:
        watsonx_chunk: Streaming chunk from watsonx
        model: Model name to include in response
        request_id: Request ID for this stream

    Returns:
        OpenAI-compatible ChatCompletionChunk
    """
    created = int(time.time())

    # Extract choices
    choices = []
    watsonx_choices = watsonx_chunk.get("choices", [])

    for idx, choice in enumerate(watsonx_choices):
        delta_data = choice.get("delta", {})

        delta = ChatCompletionChunkDelta(
            role=delta_data.get("role"),
            content=delta_data.get("content"),
            tool_calls=delta_data.get("tool_calls"),
            function_call=delta_data.get("function_call"),
        )

        choices.append(
            ChatCompletionChunkChoice(
                index=idx,
                delta=delta,
                finish_reason=choice.get("finish_reason"),
            )
        )

    return ChatCompletionChunk(
        id=request_id,
        created=created,
        model=model,
        choices=choices,
    )


def transform_watsonx_to_openai_completion(
    watsonx_response: Dict[str, Any],
    model: str,
    request_id: Optional[str] = None,
) -> CompletionResponse:
    """Transform watsonx text generation response to OpenAI completion format.

    Args:
        watsonx_response: Response from watsonx text generation API
        model: Model name to include in response
        request_id: Optional request ID; one is generated if not provided

    Returns:
        OpenAI-compatible CompletionResponse
    """
    response_id = request_id or f"cmpl-{uuid.uuid4().hex[:24]}"
    created = int(time.time())

    # Extract results
    results = watsonx_response.get("results", [])
    choices = []

    for idx, result in enumerate(results):
        choices.append(
            CompletionChoice(
                text=result.get("generated_text", ""),
                index=idx,
                finish_reason=result.get("stop_reason"),
            )
        )

    # Extract usage
    usage_data = watsonx_response.get("usage", {})
    usage = ChatCompletionUsage(
        prompt_tokens=usage_data.get("prompt_tokens", 0),
        completion_tokens=usage_data.get("generated_tokens", 0),
        total_tokens=usage_data.get("prompt_tokens", 0) + usage_data.get("generated_tokens", 0),
    )

    return CompletionResponse(
        id=response_id,
        created=created,
        model=model,
        choices=choices,
        usage=usage,
    )
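Unlike the chat API, the text-generation usage above reports `generated_tokens` rather than `completion_tokens`, so the total has to be computed instead of read. In miniature (the counts here are fabricated):

```python
usage_data = {"prompt_tokens": 12, "generated_tokens": 30}  # fabricated counts

# completion_tokens comes from the watsonx "generated_tokens" field;
# total_tokens is derived as their sum, mirroring the code above
completion_tokens = usage_data.get("generated_tokens", 0)
total_tokens = usage_data.get("prompt_tokens", 0) + usage_data.get("generated_tokens", 0)

print(completion_tokens, total_tokens)  # 30 42
```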
def transform_watsonx_to_openai_embeddings(
    watsonx_response: Dict[str, Any],
    model: str,
) -> EmbeddingResponse:
    """Transform watsonx embeddings response to OpenAI format.

    Args:
        watsonx_response: Response from watsonx embeddings API
        model: Model name to include in response

    Returns:
        OpenAI-compatible EmbeddingResponse
    """
    # Extract results
    results = watsonx_response.get("results", [])
    data = []

    for idx, result in enumerate(results):
        embedding = result.get("embedding", [])
        data.append(
            EmbeddingData(
                embedding=embedding,
                index=idx,
            )
        )

    # Calculate usage
    input_token_count = watsonx_response.get("input_token_count", 0)
    usage = EmbeddingUsage(
        prompt_tokens=input_token_count,
        total_tokens=input_token_count,
    )

    return EmbeddingResponse(
        data=data,
        model=model,
        usage=usage,
    )


def format_sse_event(data: str) -> str:
    """Format data as Server-Sent Event.

    Args:
        data: JSON string to send

    Returns:
        Formatted SSE string
    """
    return f"data: {data}\n\n"
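Each streamed chunk is serialized to JSON and framed as a single `data:` line terminated by a blank line, which is all `format_sse_event` does. A standalone sketch (the chunk payload is fabricated):

```python
import json


def format_sse_event(data: str) -> str:
    # Same framing as the helper above: one "data:" line, blank-line terminated
    return f"data: {data}\n\n"


chunk = json.dumps({"id": "chatcmpl-abc123", "choices": []})  # fabricated chunk
event = format_sse_event(chunk)
print(event.startswith("data: "), event.endswith("\n\n"))  # True True
```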
27
docker-compose.yml
Normal file
@@ -0,0 +1,27 @@
version: '3.8'

services:
  watsonx-openai-proxy:
    build: .
    ports:
      - "8000:8000"
    environment:
      - IBM_CLOUD_API_KEY=${IBM_CLOUD_API_KEY}
      - WATSONX_PROJECT_ID=${WATSONX_PROJECT_ID}
      - WATSONX_CLUSTER=${WATSONX_CLUSTER:-us-south}
      - HOST=0.0.0.0
      - PORT=8000
      - LOG_LEVEL=${LOG_LEVEL:-info}
      - API_KEY=${API_KEY:-}
      - ALLOWED_ORIGINS=${ALLOWED_ORIGINS:-*}
      - MODEL_MAP_GPT4=${MODEL_MAP_GPT4:-ibm/granite-4-h-small}
      - MODEL_MAP_GPT35=${MODEL_MAP_GPT35:-ibm/granite-3-8b-instruct}
      - MODEL_MAP_GPT4_TURBO=${MODEL_MAP_GPT4_TURBO:-meta-llama/llama-3-3-70b-instruct}
      - MODEL_MAP_TEXT_EMBEDDING_ADA_002=${MODEL_MAP_TEXT_EMBEDDING_ADA_002:-ibm/slate-125m-english-rtrvr}
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8000/health')"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 5s
183
example_usage.py
Normal file
@@ -0,0 +1,183 @@
"""Example usage of watsonx-openai-proxy with the OpenAI Python SDK."""

import os

from openai import OpenAI

# Configure the client to use the proxy
client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key=os.getenv("API_KEY", "not-needed-if-proxy-has-no-auth"),
)


def example_chat_completion():
    """Example: Basic chat completion."""
    print("\n=== Chat Completion Example ===")

    response = client.chat.completions.create(
        model="ibm/granite-3-8b-instruct",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the capital of France?"},
        ],
        temperature=0.7,
        max_tokens=100,
    )

    print(f"Response: {response.choices[0].message.content}")
    print(f"Tokens used: {response.usage.total_tokens}")


def example_streaming_chat():
    """Example: Streaming chat completion."""
    print("\n=== Streaming Chat Example ===")

    stream = client.chat.completions.create(
        model="ibm/granite-3-8b-instruct",
        messages=[
            {"role": "user", "content": "Tell me a short story about a robot."},
        ],
        stream=True,
        max_tokens=200,
    )

    print("Response: ", end="", flush=True)
    for chunk in stream:
        if chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)
    print()


def example_with_model_mapping():
    """Example: Using mapped model names."""
    print("\n=== Model Mapping Example ===")

    # If you configured MODEL_MAP_GPT4=ibm/granite-4-h-small in .env,
    # you can use "gpt-4" and it will be mapped automatically
    response = client.chat.completions.create(
        model="gpt-4",  # This gets mapped to ibm/granite-4-h-small
        messages=[
            {"role": "user", "content": "Explain quantum computing in one sentence."},
        ],
        max_tokens=50,
    )

    print(f"Response: {response.choices[0].message.content}")


def example_embeddings():
    """Example: Generate embeddings."""
    print("\n=== Embeddings Example ===")

    response = client.embeddings.create(
        model="ibm/slate-125m-english-rtrvr",
        input=[
            "The quick brown fox jumps over the lazy dog.",
            "Machine learning is a subset of artificial intelligence.",
        ],
    )

    print(f"Generated {len(response.data)} embeddings")
    print(f"Embedding dimension: {len(response.data[0].embedding)}")
    print(f"First embedding (first 5 values): {response.data[0].embedding[:5]}")


def example_list_models():
    """Example: List available models."""
    print("\n=== List Models Example ===")

    models = client.models.list()

    print(f"Available models: {len(models.data)}")
    print("\nFirst 5 models:")
    for model in models.data[:5]:
        print(f"  - {model.id}")


def example_completion_legacy():
    """Example: Legacy completion endpoint."""
    print("\n=== Legacy Completion Example ===")

    response = client.completions.create(
        model="ibm/granite-3-8b-instruct",
        prompt="Once upon a time, in a land far away,",
        max_tokens=50,
        temperature=0.8,
    )

    print(f"Completion: {response.choices[0].text}")


def example_with_functions():
    """Example: Function calling (if supported by the model)."""
    print("\n=== Function Calling Example ===")

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                        },
                    },
                    "required": ["location"],
                },
            },
        }
    ]

    try:
        response = client.chat.completions.create(
            model="ibm/granite-3-8b-instruct",
            messages=[
                {"role": "user", "content": "What's the weather like in Boston?"},
            ],
            tools=tools,
            tool_choice="auto",
        )

        message = response.choices[0].message
        if message.tool_calls:
            print(f"Function called: {message.tool_calls[0].function.name}")
            print(f"Arguments: {message.tool_calls[0].function.arguments}")
        else:
            print(f"Response: {message.content}")
    except Exception as e:
        print(f"Function calling may not be supported by this model: {e}")


if __name__ == "__main__":
    print("watsonx-openai-proxy Usage Examples")
    print("=" * 50)

    try:
        # Run examples
        example_chat_completion()
        example_streaming_chat()
        example_embeddings()
        example_list_models()
        example_completion_legacy()

        # Optional examples (may require specific configuration)
        # example_with_model_mapping()
        # example_with_functions()

        print("\n" + "=" * 50)
        print("All examples completed successfully!")

    except Exception as e:
        print(f"\nError: {e}")
        print("\nMake sure:")
        print("1. The proxy server is running (python -m app.main)")
        print("2. Your .env file is configured correctly")
        print("3. You have the OpenAI Python SDK installed (pip install openai)")
7
requirements.txt
Normal file
@@ -0,0 +1,7 @@
fastapi==0.115.0
uvicorn[standard]==0.32.0
pydantic==2.9.2
pydantic-settings==2.6.0
httpx==0.27.2
python-dotenv==1.0.1
python-multipart==0.0.12
87
tests/test_basic.py
Normal file
@@ -0,0 +1,87 @@
"""Basic tests for watsonx-openai-proxy."""

import pytest
from fastapi.testclient import TestClient

from app.main import app

client = TestClient(app)


def test_health_check():
    """Test the health check endpoint."""
    response = client.get("/health")
    assert response.status_code == 200
    data = response.json()
    assert data["status"] == "healthy"
    assert "cluster" in data


def test_root_endpoint():
    """Test the root endpoint."""
    response = client.get("/")
    assert response.status_code == 200
    data = response.json()
    assert data["service"] == "watsonx-openai-proxy"
    assert "endpoints" in data


def test_list_models():
    """Test listing available models."""
    response = client.get("/v1/models")
    assert response.status_code == 200
    data = response.json()
    assert data["object"] == "list"
    assert len(data["data"]) > 0
    assert all(model["object"] == "model" for model in data["data"])


def test_retrieve_model():
    """Test retrieving a specific model."""
    response = client.get("/v1/models/ibm/granite-3-8b-instruct")
    assert response.status_code == 200
    data = response.json()
    assert data["id"] == "ibm/granite-3-8b-instruct"
    assert data["object"] == "model"


def test_retrieve_nonexistent_model():
    """Test retrieving a model that doesn't exist."""
    response = client.get("/v1/models/nonexistent-model")
    assert response.status_code == 404


# Note: The following tests require valid IBM Cloud credentials
# and should be run with pytest markers or in integration tests

@pytest.mark.skip(reason="Requires valid IBM Cloud credentials")
def test_chat_completion():
    """Test chat completion endpoint."""
    response = client.post(
        "/v1/chat/completions",
        json={
            "model": "ibm/granite-3-8b-instruct",
            "messages": [
                {"role": "user", "content": "Hello!"}
            ],
        },
    )
    assert response.status_code == 200
    data = response.json()
    assert "choices" in data
    assert len(data["choices"]) > 0


@pytest.mark.skip(reason="Requires valid IBM Cloud credentials")
def test_embeddings():
    """Test embeddings endpoint."""
    response = client.post(
        "/v1/embeddings",
        json={
            "model": "ibm/slate-125m-english-rtrvr",
            "input": "Test text",
        },
    )
    assert response.status_code == 200
    data = response.json()
    assert "data" in data
    assert len(data["data"]) > 0