Add AGENTS.md documentation for AI agent guidance

This commit is contained in:
2026-02-23 09:59:52 -05:00
commit 2e2b817435
21 changed files with 2513 additions and 0 deletions

3
app/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""watsonx-openai-proxy application package."""
# NOTE: keep in sync with the version passed to FastAPI in app/main.py.
__version__ = "1.0.0"

91
app/config.py Normal file
View File

@@ -0,0 +1,91 @@
"""Configuration management for watsonx-openai-proxy."""
import os
from typing import Dict, Optional
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Values are read from the process environment and from a local ``.env``
    file (case-insensitive). Extra ``MODEL_MAP_*`` variables are kept via
    ``extra="allow"`` and turned into an OpenAI->watsonx model mapping.
    """
    # IBM Cloud Configuration
    ibm_cloud_api_key: str
    watsonx_project_id: str
    watsonx_cluster: str = "us-south"
    # Server Configuration
    host: str = "0.0.0.0"
    port: int = 8000
    log_level: str = "info"
    # API Configuration
    api_key: Optional[str] = None
    allowed_origins: str = "*"
    # Token Management
    token_refresh_interval: int = 3000  # 50 minutes in seconds
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="allow",  # Allow extra fields for model mapping
    )
    @property
    def watsonx_base_url(self) -> str:
        """Construct the watsonx.ai base URL from cluster."""
        return f"https://{self.watsonx_cluster}.ml.cloud.ibm.com/ml/v1"
    @property
    def cors_origins(self) -> list[str]:
        """Parse CORS origins from comma-separated string."""
        if self.allowed_origins == "*":
            return ["*"]
        return [origin.strip() for origin in self.allowed_origins.split(",")]
    @staticmethod
    def _key_to_openai_name(raw_key: str) -> str:
        """Normalize an upper-case MODEL_MAP_* suffix to an OpenAI-style name.

        Inserts a dash between a letter run and a following digit run
        (``GPT4`` -> ``GPT-4``), then lowercases and converts underscores
        to dashes (``GPT4`` -> ``gpt-4``).
        """
        import re
        dashed = re.sub(r'([A-Z]+)(\d+)', r'\1-\2', raw_key)
        return dashed.lower().replace("_", "-")
    def get_model_mapping(self) -> Dict[str, str]:
        """Extract model mappings from environment variables.

        Looks for variables like MODEL_MAP_GPT4=ibm/granite-4-h-small
        and creates a mapping dict. Both real environment variables and
        pydantic's extra fields (loaded from .env) are considered; .env
        entries win on key collisions because they are applied last.
        """
        mapping: Dict[str, str] = {}
        # Check os.environ first
        for key, value in os.environ.items():
            if key.startswith("MODEL_MAP_"):
                openai_model = self._key_to_openai_name(key.replace("MODEL_MAP_", ""))
                mapping[openai_model] = value
        # Also check pydantic's extra fields (from .env file).
        # These come in as lowercase: model_map_gpt4 instead of MODEL_MAP_GPT4.
        extra = getattr(self, '__pydantic_extra__', {}) or {}
        for key, value in extra.items():
            if key.startswith("model_map_"):
                # Convert back to uppercase so the shared normalizer applies.
                openai_model = self._key_to_openai_name(key.replace("model_map_", "").upper())
                mapping[openai_model] = value
        return mapping
    def map_model(self, openai_model: str) -> str:
        """Map an OpenAI model name to a watsonx model ID.

        Args:
            openai_model: OpenAI model name (e.g., "gpt-4", "gpt-3.5-turbo")

        Returns:
            Corresponding watsonx model ID, or the original name if no mapping exists
        """
        model_map = self.get_model_mapping()
        return model_map.get(openai_model, openai_model)
# Global settings instance
settings = Settings()

157
app/main.py Normal file
View File

@@ -0,0 +1,157 @@
"""Main FastAPI application for watsonx-openai-proxy."""
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.config import settings
from app.routers import chat, completions, embeddings, models
from app.services.watsonx_service import watsonx_service
# Configure logging
# The level string from settings (e.g. "info") is upper-cased to match the
# logging module's constant names (logging.INFO, ...).
logging.basicConfig(
    level=getattr(logging, settings.log_level.upper()),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan events.

    On startup, eagerly obtains an IBM Cloud bearer token so that bad
    credentials fail fast at boot rather than on the first request; on
    shutdown, closes the shared watsonx HTTP client.
    """
    # Startup
    logger.info("Starting watsonx-openai-proxy...")
    logger.info(f"Cluster: {settings.watsonx_cluster}")
    # Only a prefix of the project ID is logged, keeping the full value out of logs.
    logger.info(f"Project ID: {settings.watsonx_project_id[:8]}...")
    # Initialize token
    # NOTE(review): reaches into the service's private _refresh_token; consider
    # exposing a public warm-up method on WatsonxService instead.
    try:
        await watsonx_service._refresh_token()
        logger.info("Initial bearer token obtained successfully")
    except Exception as e:
        logger.error(f"Failed to obtain initial bearer token: {e}")
        raise
    yield
    # Shutdown
    logger.info("Shutting down watsonx-openai-proxy...")
    await watsonx_service.close()
# Create FastAPI app
app = FastAPI(
    title="watsonx-openai-proxy",
    description="OpenAI-compatible API proxy for IBM watsonx.ai",
    version="1.0.0",
    lifespan=lifespan,
)
# Add CORS middleware
# Origins come from the ALLOWED_ORIGINS setting (comma-separated, default "*").
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Optional API key authentication middleware
@app.middleware("http")
async def authenticate(request: Request, call_next):
"""Authenticate requests if API key is configured."""
if settings.api_key:
# Skip authentication for health check
if request.url.path == "/health":
return await call_next(request)
# Check Authorization header
auth_header = request.headers.get("Authorization")
if not auth_header:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"error": {
"message": "Missing Authorization header",
"type": "authentication_error",
"code": "missing_authorization",
}
},
)
# Validate API key
if not auth_header.startswith("Bearer "):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"error": {
"message": "Invalid Authorization header format",
"type": "authentication_error",
"code": "invalid_authorization_format",
}
},
)
token = auth_header[7:] # Remove "Bearer " prefix
if token != settings.api_key:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"error": {
"message": "Invalid API key",
"type": "authentication_error",
"code": "invalid_api_key",
}
},
)
return await call_next(request)
# Include routers
# Each router registers its OpenAI-compatible endpoints (paths under /v1).
app.include_router(chat.router, tags=["Chat"])
app.include_router(completions.router, tags=["Completions"])
app.include_router(embeddings.router, tags=["Embeddings"])
app.include_router(models.router, tags=["Models"])
@app.get("/health")
async def health_check():
    """Liveness probe: report service health and the configured cluster."""
    payload = {
        "status": "healthy",
        "service": "watsonx-openai-proxy",
        "cluster": settings.watsonx_cluster,
    }
    return payload
@app.get("/")
async def root():
    """Root endpoint with API information."""
    endpoint_map = {
        "chat": "/v1/chat/completions",
        "completions": "/v1/completions",
        "embeddings": "/v1/embeddings",
        "models": "/v1/models",
        "health": "/health",
    }
    info = {
        "service": "watsonx-openai-proxy",
        "description": "OpenAI-compatible API proxy for IBM watsonx.ai",
        "version": "1.0.0",
        "endpoints": endpoint_map,
        "documentation": "/docs",
    }
    return info
if __name__ == "__main__":
    # Allow running the app directly (e.g. `python -m app.main`) for local use;
    # reload is disabled so behavior matches a production invocation.
    import uvicorn
    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        log_level=settings.log_level,
        reload=False,
    )

31
app/models/__init__.py Normal file
View File

@@ -0,0 +1,31 @@
"""OpenAI-compatible data models."""
# Re-export the public request/response models so callers can import them
# directly from app.models.
from app.models.openai_models import (
    ChatMessage,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionChunk,
    CompletionRequest,
    CompletionResponse,
    EmbeddingRequest,
    EmbeddingResponse,
    ModelsResponse,
    ModelInfo,
    ErrorResponse,
    ErrorDetail,
)
__all__ = [
    "ChatMessage",
    "ChatCompletionRequest",
    "ChatCompletionResponse",
    "ChatCompletionChunk",
    "CompletionRequest",
    "CompletionResponse",
    "EmbeddingRequest",
    "EmbeddingResponse",
    "ModelsResponse",
    "ModelInfo",
    "ErrorResponse",
    "ErrorDetail",
]

213
app/models/openai_models.py Normal file
View File

@@ -0,0 +1,213 @@
"""OpenAI-compatible request and response models."""
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
# ============================================================================
# Chat Completions Models
# ============================================================================
class ChatMessage(BaseModel):
    """A chat message in the conversation."""
    role: Literal["system", "user", "assistant", "function", "tool"]
    # None is allowed, e.g. for assistant messages that only carry tool calls.
    content: Optional[str] = None
    name: Optional[str] = None
    function_call: Optional[Dict[str, Any]] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None
class FunctionCall(BaseModel):
    """Function call specification."""
    name: str
    # JSON-encoded arguments string, as in the OpenAI API.
    arguments: str
class ToolCall(BaseModel):
    """Tool call specification.

    NOTE(review): not referenced by ChatMessage in this module, which models
    tool calls as raw dicts; kept as schema documentation.
    """
    id: str
    type: Literal["function"]
    function: FunctionCall
class ChatCompletionRequest(BaseModel):
    """OpenAI chat completion request.

    Defaults and bounds mirror the OpenAI API surface; the model name is
    resolved to a watsonx ID by the server's mapping configuration.
    """
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(default=1.0, ge=0, le=2)
    top_p: Optional[float] = Field(default=1.0, ge=0, le=1)
    n: Optional[int] = Field(default=1, ge=1)
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = Field(default=None, ge=1)
    presence_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    frequency_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    # Legacy function-calling fields, accepted alongside the newer tools API.
    functions: Optional[List[Dict[str, Any]]] = None
    function_call: Optional[Union[str, Dict[str, str]]] = None
    tools: Optional[List[Dict[str, Any]]] = None
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
class ChatCompletionChoice(BaseModel):
    """A single chat completion choice."""
    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None
    logprobs: Optional[Dict[str, Any]] = None
class ChatCompletionUsage(BaseModel):
    """Token usage information."""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
class ChatCompletionResponse(BaseModel):
    """OpenAI chat completion response."""
    id: str
    object: Literal["chat.completion"] = "chat.completion"
    # Unix timestamp (seconds) of response creation.
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: ChatCompletionUsage
    system_fingerprint: Optional[str] = None
class ChatCompletionChunkDelta(BaseModel):
    """Delta content in streaming response."""
    role: Optional[str] = None
    content: Optional[str] = None
    function_call: Optional[Dict[str, Any]] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None
class ChatCompletionChunkChoice(BaseModel):
    """A single streaming chunk choice."""
    index: int
    delta: ChatCompletionChunkDelta
    finish_reason: Optional[str] = None
    logprobs: Optional[Dict[str, Any]] = None
class ChatCompletionChunk(BaseModel):
    """OpenAI streaming chat completion chunk."""
    id: str
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int
    model: str
    choices: List[ChatCompletionChunkChoice]
    system_fingerprint: Optional[str] = None
# ============================================================================
# Completions Models (Legacy)
# ============================================================================
class CompletionRequest(BaseModel):
    """OpenAI completion request (legacy)."""
    model: str
    # Prompt may be a string, a list of strings, or token IDs (ints).
    prompt: Union[str, List[str], List[int], List[List[int]]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = Field(default=16, ge=1)
    temperature: Optional[float] = Field(default=1.0, ge=0, le=2)
    top_p: Optional[float] = Field(default=1.0, ge=0, le=1)
    n: Optional[int] = Field(default=1, ge=1)
    stream: Optional[bool] = False
    logprobs: Optional[int] = Field(default=None, ge=0, le=5)
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    frequency_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    best_of: Optional[int] = Field(default=1, ge=1)
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
class CompletionChoice(BaseModel):
    """A single completion choice."""
    text: str
    index: int
    logprobs: Optional[Dict[str, Any]] = None
    finish_reason: Optional[str] = None
class CompletionResponse(BaseModel):
    """OpenAI completion response."""
    id: str
    object: Literal["text_completion"] = "text_completion"
    created: int
    model: str
    choices: List[CompletionChoice]
    # Reuses the chat usage model; same prompt/completion/total token fields.
    usage: ChatCompletionUsage
# ============================================================================
# Embeddings Models
# ============================================================================
class EmbeddingRequest(BaseModel):
    """OpenAI embedding request."""
    model: str
    # Input may be a string, a list of strings, or token IDs (ints).
    input: Union[str, List[str], List[int], List[List[int]]]
    encoding_format: Optional[Literal["float", "base64"]] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
class EmbeddingData(BaseModel):
    """A single embedding result."""
    object: Literal["embedding"] = "embedding"
    embedding: List[float]
    # Position of this embedding within the request's input list.
    index: int
class EmbeddingUsage(BaseModel):
    """Token usage for embeddings (no completion tokens for this endpoint)."""
    prompt_tokens: int
    total_tokens: int
class EmbeddingResponse(BaseModel):
    """OpenAI embedding response."""
    object: Literal["list"] = "list"
    data: List[EmbeddingData]
    model: str
    usage: EmbeddingUsage
# ============================================================================
# Models List
# ============================================================================
class ModelInfo(BaseModel):
    """Information about a single model."""
    id: str
    object: Literal["model"] = "model"
    # Unix timestamp (seconds); the proxy fills this with the listing time.
    created: int
    owned_by: str
class ModelsResponse(BaseModel):
    """List of available models."""
    object: Literal["list"] = "list"
    data: List[ModelInfo]
# ============================================================================
# Error Models
# ============================================================================
class ErrorDetail(BaseModel):
    """Error detail information (the inner object of OpenAI's error envelope)."""
    message: str
    type: str
    param: Optional[str] = None
    code: Optional[str] = None
class ErrorResponse(BaseModel):
    """OpenAI error response."""
    error: ErrorDetail

5
app/routers/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
"""API routers for watsonx-openai-proxy."""
# Re-export router modules so app.main can `from app.routers import chat, ...`.
from app.routers import chat, completions, embeddings, models
__all__ = ["chat", "completions", "embeddings", "models"]

156
app/routers/chat.py Normal file
View File

@@ -0,0 +1,156 @@
"""Chat completions endpoint router."""
import json
import uuid
from typing import Union
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from app.models.openai_models import (
ChatCompletionRequest,
ChatCompletionResponse,
ErrorResponse,
ErrorDetail,
)
from app.services.watsonx_service import watsonx_service
from app.utils.transformers import (
transform_messages_to_watsonx,
transform_tools_to_watsonx,
transform_watsonx_to_openai_chat,
transform_watsonx_to_openai_chat_chunk,
format_sse_event,
)
from app.config import settings
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post(
    "/v1/chat/completions",
    response_model=Union[ChatCompletionResponse, ErrorResponse],
    responses={
        200: {"model": ChatCompletionResponse},
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def create_chat_completion(
    request: ChatCompletionRequest,
    http_request: Request,
):
    """Create a chat completion using OpenAI-compatible API.

    This endpoint accepts OpenAI-formatted requests and translates them
    to watsonx.ai API calls. Streaming requests return Server-Sent Events;
    non-streaming requests return a single OpenAI-shaped response body.
    """
    try:
        # Map model name if needed
        watsonx_model = settings.map_model(request.model)
        logger.info(f"Chat completion request: {request.model} -> {watsonx_model}")
        # Transform messages
        watsonx_messages = transform_messages_to_watsonx(request.messages)
        # Transform tools if present
        watsonx_tools = transform_tools_to_watsonx(request.tools)
        # Handle streaming
        if request.stream:
            return StreamingResponse(
                stream_chat_completion(
                    watsonx_model,
                    watsonx_messages,
                    request,
                    watsonx_tools,
                ),
                media_type="text/event-stream",
            )
        # Non-streaming response
        watsonx_response = await watsonx_service.chat_completion(
            model_id=watsonx_model,
            messages=watsonx_messages,
            temperature=request.temperature or 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p or 1.0,
            # OpenAI accepts a bare string or a list for `stop`; normalize to a list.
            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
            tools=watsonx_tools,
        )
        # Transform response back to OpenAI shape, echoing the requested model name.
        openai_response = transform_watsonx_to_openai_chat(
            watsonx_response,
            request.model,
        )
        return openai_response
    except HTTPException:
        # Pass deliberate HTTP errors through unchanged (consistent with the
        # completions and embeddings routers) instead of re-wrapping as 500.
        raise
    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": "internal_error",
                }
            },
        )
async def stream_chat_completion(
    watsonx_model: str,
    watsonx_messages: list,
    request: ChatCompletionRequest,
    watsonx_tools: list = None,
):
    """Stream chat completion responses.

    Args:
        watsonx_model: The watsonx model ID
        watsonx_messages: Transformed messages
        request: Original OpenAI request
        watsonx_tools: Transformed tools (None when the request has no tools)

    Yields:
        Server-Sent Events with chat completion chunks, terminated by a
        ``[DONE]`` sentinel. Errors are emitted as a final SSE error payload
        rather than raised, because the HTTP response has already started.
    """
    # One response id shared by every chunk of this stream.
    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
    try:
        async for chunk in watsonx_service.chat_completion_stream(
            model_id=watsonx_model,
            messages=watsonx_messages,
            temperature=request.temperature or 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p or 1.0,
            # Normalize `stop` (string or list in OpenAI requests) to a list.
            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
            tools=watsonx_tools,
        ):
            # Transform chunk to OpenAI format
            openai_chunk = transform_watsonx_to_openai_chat_chunk(
                chunk,
                request.model,
                request_id,
            )
            # Send as SSE
            yield format_sse_event(openai_chunk.model_dump_json())
        # Send [DONE] message
        yield format_sse_event("[DONE]")
    except Exception as e:
        logger.error(f"Error in streaming chat completion: {str(e)}", exc_info=True)
        error_response = ErrorResponse(
            error=ErrorDetail(
                message=str(e),
                type="internal_error",
                code="stream_error",
            )
        )
        yield format_sse_event(error_response.model_dump_json())

109
app/routers/completions.py Normal file
View File

@@ -0,0 +1,109 @@
"""Text completions endpoint router (legacy)."""
import uuid
from typing import Union
from fastapi import APIRouter, HTTPException, Request
from app.models.openai_models import (
CompletionRequest,
CompletionResponse,
ErrorResponse,
ErrorDetail,
)
from app.services.watsonx_service import watsonx_service
from app.utils.transformers import transform_watsonx_to_openai_completion
from app.config import settings
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post(
    "/v1/completions",
    response_model=Union[CompletionResponse, ErrorResponse],
    responses={
        200: {"model": CompletionResponse},
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def create_completion(
    request: CompletionRequest,
    http_request: Request,
):
    """Create a text completion using OpenAI-compatible API (legacy).

    This endpoint accepts OpenAI-formatted completion requests and translates
    them to watsonx.ai text generation API calls. Token-ID prompts and
    streaming are rejected with a 400 error.
    """
    try:
        # Map model name if needed
        watsonx_model = settings.map_model(request.model)
        logger.info(f"Completion request: {request.model} -> {watsonx_model}")
        # Handle prompt (can be string or list)
        if isinstance(request.prompt, list):
            if len(request.prompt) == 0:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": {
                            "message": "Prompt cannot be empty",
                            "type": "invalid_request_error",
                            "code": "invalid_prompt",
                        }
                    },
                )
            # Reject token-ID prompts explicitly rather than silently sending
            # an empty prompt to the model (same behavior as /v1/embeddings).
            if not isinstance(request.prompt[0], str):
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": {
                            "message": "Token ID input not supported",
                            "type": "invalid_request_error",
                            "code": "unsupported_input_type",
                        }
                    },
                )
            # For now, just use the first prompt
            # TODO: Handle multiple prompts with n parameter
            prompt = request.prompt[0]
        else:
            prompt = request.prompt
        # Note: Streaming not implemented for completions yet
        if request.stream:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "Streaming not supported for completions endpoint",
                        "type": "invalid_request_error",
                        "code": "streaming_not_supported",
                    }
                },
            )
        # Call watsonx text generation
        watsonx_response = await watsonx_service.text_generation(
            model_id=watsonx_model,
            prompt=prompt,
            temperature=request.temperature or 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p or 1.0,
            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
        )
        # Transform response back to OpenAI shape, echoing the requested model name.
        openai_response = transform_watsonx_to_openai_completion(
            watsonx_response,
            request.model,
        )
        return openai_response
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in completion: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": "internal_error",
                }
            },
        )

114
app/routers/embeddings.py Normal file
View File

@@ -0,0 +1,114 @@
"""Embeddings endpoint router."""
from typing import Union
from fastapi import APIRouter, HTTPException, Request
from app.models.openai_models import (
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
ErrorDetail,
)
from app.services.watsonx_service import watsonx_service
from app.utils.transformers import transform_watsonx_to_openai_embeddings
from app.config import settings
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post(
    "/v1/embeddings",
    response_model=Union[EmbeddingResponse, ErrorResponse],
    responses={
        200: {"model": EmbeddingResponse},
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def create_embeddings(
    request: EmbeddingRequest,
    http_request: Request,
):
    """Create embeddings using OpenAI-compatible API.

    Accepts an OpenAI-formatted embedding request, normalizes its input to a
    list of strings, and forwards it to the watsonx.ai embeddings API.
    """
    try:
        # Resolve any configured model alias to its watsonx ID.
        watsonx_model = settings.map_model(request.model)
        logger.info(f"Embeddings request: {request.model} -> {watsonx_model}")
        # Normalize input: a bare string becomes a one-element list; anything
        # other than a non-empty list of strings is rejected with a 400.
        if isinstance(request.input, str):
            inputs = [request.input]
        elif not isinstance(request.input, list):
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "Invalid input type",
                        "type": "invalid_request_error",
                        "code": "invalid_input_type",
                    }
                },
            )
        elif len(request.input) == 0:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "Input cannot be empty",
                        "type": "invalid_request_error",
                        "code": "invalid_input",
                    }
                },
            )
        elif not isinstance(request.input[0], str):
            # Lists of token IDs are not supported by this proxy.
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "Token ID input not supported",
                        "type": "invalid_request_error",
                        "code": "unsupported_input_type",
                    }
                },
            )
        else:
            inputs = request.input
        # Forward to watsonx, then translate the response to OpenAI shape.
        watsonx_response = await watsonx_service.embeddings(
            model_id=watsonx_model,
            inputs=inputs,
        )
        return transform_watsonx_to_openai_embeddings(
            watsonx_response,
            request.model,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in embeddings: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": "internal_error",
                }
            },
        )

120
app/routers/models.py Normal file
View File

@@ -0,0 +1,120 @@
"""Models endpoint router."""
import time
from fastapi import APIRouter
from app.models.openai_models import ModelsResponse, ModelInfo
from app.config import settings
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
# Predefined list of available models
# This can be extended or made dynamic based on watsonx.ai API
# NOTE(review): this is a static snapshot; entries may drift from what the
# configured watsonx cluster actually serves — verify periodically.
AVAILABLE_MODELS = [
    # Granite Models
    "ibm/granite-3-1-8b-base",
    "ibm/granite-3-2-8b-instruct",
    "ibm/granite-3-3-8b-instruct",
    "ibm/granite-3-8b-instruct",
    "ibm/granite-4-h-small",
    "ibm/granite-8b-code-instruct",
    # Llama Models
    "meta-llama/llama-3-1-70b-gptq",
    "meta-llama/llama-3-1-8b",
    "meta-llama/llama-3-2-11b-vision-instruct",
    "meta-llama/llama-3-2-90b-vision-instruct",
    "meta-llama/llama-3-3-70b-instruct",
    "meta-llama/llama-3-405b-instruct",
    "meta-llama/llama-4-maverick-17b-128e-instruct-fp8",
    # Mistral Models
    "mistral-large-2512",
    "mistralai/mistral-medium-2505",
    "mistralai/mistral-small-3-1-24b-instruct-2503",
    # Other Models
    "openai/gpt-oss-120b",
    # Embedding Models
    "ibm/slate-125m-english-rtrvr",
    "ibm/slate-30m-english-rtrvr",
]
@router.get(
    "/v1/models",
    response_model=ModelsResponse,
)
async def list_models():
    """List available models in OpenAI-compatible format.

    Returns every known watsonx model, followed by any configured alias
    names (e.g. gpt-4 -> ibm/granite-4-h-small) whose target is a known
    model.
    """
    created_time = int(time.time())
    catalog = [
        ModelInfo(id=model_id, created=created_time, owned_by="ibm-watsonx")
        for model_id in AVAILABLE_MODELS
    ]
    # Expose mapped alias names alongside the real watsonx IDs.
    catalog.extend(
        ModelInfo(id=alias, created=created_time, owned_by="ibm-watsonx")
        for alias, target in settings.get_model_mapping().items()
        if target in AVAILABLE_MODELS
    )
    return ModelsResponse(data=catalog)
@router.get(
    "/v1/models/{model_id}",
    response_model=ModelInfo,
)
async def retrieve_model(model_id: str):
    """Retrieve information about a specific model.

    Args:
        model_id: The model ID to retrieve; a configured alias (MODEL_MAP_*)
            is resolved before the existence check

    Returns:
        Model information (echoing the requested ID, not the mapped one)

    Raises:
        HTTPException: 404 in OpenAI's error envelope when neither the ID
            nor its mapped target is a known watsonx model
    """
    # Map the model if needed
    watsonx_model = settings.map_model(model_id)
    # Check if model exists
    if watsonx_model not in AVAILABLE_MODELS:
        # Local import: this module only imports APIRouter at the top level.
        from fastapi import HTTPException
        raise HTTPException(
            status_code=404,
            detail={
                "error": {
                    "message": f"Model '{model_id}' not found",
                    "type": "invalid_request_error",
                    "code": "model_not_found",
                }
            },
        )
    return ModelInfo(
        id=model_id,
        created=int(time.time()),
        owned_by="ibm-watsonx",
    )

5
app/services/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
"""Services for watsonx-openai-proxy."""
# Re-export the shared service singleton and its class.
from app.services.watsonx_service import watsonx_service, WatsonxService
__all__ = ["watsonx_service", "WatsonxService"]

View File

@@ -0,0 +1,316 @@
"""Service for interacting with IBM watsonx.ai APIs."""
import asyncio
import time
from typing import AsyncIterator, Dict, List, Optional
import httpx
from app.config import settings
import logging
logger = logging.getLogger(__name__)
class WatsonxService:
    """Service for managing watsonx.ai API interactions.

    Holds a shared async HTTP client and an IBM Cloud IAM bearer token that
    is lazily refreshed shortly before it expires. A single module-level
    instance (``watsonx_service``) is shared across the application.
    """
    def __init__(self):
        self.base_url = settings.watsonx_base_url
        self.project_id = settings.watsonx_project_id
        self.api_key = settings.ibm_cloud_api_key
        self._bearer_token: Optional[str] = None
        self._token_expiry: Optional[float] = None  # epoch seconds
        self._token_lock = asyncio.Lock()  # serializes token refreshes
        self._client: Optional[httpx.AsyncClient] = None
    async def _get_client(self) -> httpx.AsyncClient:
        """Get or create the shared HTTP client (lazy singleton)."""
        if self._client is None:
            # Generous timeout: generation requests can run for minutes.
            self._client = httpx.AsyncClient(timeout=300.0)
        return self._client
    async def close(self):
        """Close the HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None
    async def _refresh_token(self) -> str:
        """Get a fresh bearer token from IBM Cloud IAM.

        Returns the cached token while it remains valid (with a 5 minute
        safety buffer); otherwise exchanges the API key for a new one.

        Raises:
            Exception: if the IAM token endpoint returns a non-200 status.
        """
        async with self._token_lock:
            # Check if token is still valid
            if self._bearer_token and self._token_expiry:
                if time.time() < self._token_expiry - 300:  # 5 min buffer
                    return self._bearer_token
            logger.info("Refreshing IBM Cloud bearer token...")
            client = await self._get_client()
            # Send the form as a dict so httpx URL-encodes the API key.
            # Interpolating the raw key into a string would break (or leak)
            # for keys containing characters like '&', '%' or '+'.
            response = await client.post(
                "https://iam.cloud.ibm.com/identity/token",
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                data={
                    "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
                    "apikey": self.api_key,
                },
            )
            if response.status_code != 200:
                raise Exception(f"Failed to get bearer token: {response.text}")
            data = response.json()
            self._bearer_token = data["access_token"]
            self._token_expiry = time.time() + data.get("expires_in", 3600)
            logger.info(f"Bearer token refreshed. Expires in {data.get('expires_in', 3600)} seconds")
            return self._bearer_token
    async def _get_headers(self) -> Dict[str, str]:
        """Get headers with a valid bearer token (refreshing if needed)."""
        token = await self._refresh_token()
        return {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
    async def chat_completion(
        self,
        model_id: str,
        messages: List[Dict],
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        stream: bool = False,
        tools: Optional[List[Dict]] = None,
        **kwargs,
    ) -> Dict:
        """Create a chat completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID
            messages: List of chat messages
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            top_p: Nucleus sampling parameter
            stop: Stop sequences
            stream: Whether to request streaming. NOTE(review): even with
                stream=True this method issues a buffered POST and returns a
                single JSON body; callers wanting true streaming should use
                chat_completion_stream — confirm whether this flag is needed.
            tools: Tool/function definitions
            **kwargs: Additional parameters (currently ignored)

        Returns:
            Chat completion response (raw watsonx JSON)
        """
        headers = await self._get_headers()
        client = await self._get_client()
        # Build watsonx request
        payload = {
            "model_id": model_id,
            "project_id": self.project_id,
            "messages": messages,
            "parameters": {
                "temperature": temperature,
                "top_p": top_p,
            },
        }
        if max_tokens:
            payload["parameters"]["max_tokens"] = max_tokens
        if stop:
            payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]
        if tools:
            payload["tools"] = tools
        url = f"{self.base_url}/text/chat"
        params = {"version": "2024-02-13"}
        if stream:
            params["stream"] = "true"
        response = await client.post(
            url,
            headers=headers,
            json=payload,
            params=params,
        )
        if response.status_code != 200:
            raise Exception(f"watsonx API error: {response.status_code} - {response.text}")
        return response.json()
    async def chat_completion_stream(
        self,
        model_id: str,
        messages: List[Dict],
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        tools: Optional[List[Dict]] = None,
        **kwargs,
    ) -> AsyncIterator[Dict]:
        """Stream a chat completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID
            messages: List of chat messages
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            top_p: Nucleus sampling parameter
            stop: Stop sequences
            tools: Tool/function definitions
            **kwargs: Additional parameters (currently ignored)

        Yields:
            Chat completion chunks (parsed watsonx SSE payloads)
        """
        headers = await self._get_headers()
        client = await self._get_client()
        # Build watsonx request
        payload = {
            "model_id": model_id,
            "project_id": self.project_id,
            "messages": messages,
            "parameters": {
                "temperature": temperature,
                "top_p": top_p,
            },
        }
        if max_tokens:
            payload["parameters"]["max_tokens"] = max_tokens
        if stop:
            payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]
        if tools:
            payload["tools"] = tools
        url = f"{self.base_url}/text/chat_stream"
        params = {"version": "2024-02-13"}
        # json is not imported at module level; import it once here rather
        # than on every received line inside the loop.
        import json
        async with client.stream(
            "POST",
            url,
            headers=headers,
            json=payload,
            params=params,
        ) as response:
            if response.status_code != 200:
                text = await response.aread()
                raise Exception(f"watsonx API error: {response.status_code} - {text.decode()}")
            # watsonx streams Server-Sent Events; payload lines start with "data: ".
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data = line[6:]  # Remove "data: " prefix
                    if data.strip() == "[DONE]":
                        break
                    try:
                        yield json.loads(data)
                    except json.JSONDecodeError:
                        # Skip malformed keep-alive/partial lines.
                        continue
    async def text_generation(
        self,
        model_id: str,
        prompt: str,
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> Dict:
        """Generate text completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID
            prompt: The input prompt
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate (sent as max_new_tokens)
            top_p: Nucleus sampling parameter
            stop: Stop sequences
            **kwargs: Additional parameters (currently ignored)

        Returns:
            Text generation response (raw watsonx JSON)
        """
        headers = await self._get_headers()
        client = await self._get_client()
        payload = {
            "model_id": model_id,
            "project_id": self.project_id,
            "input": prompt,
            "parameters": {
                "temperature": temperature,
                "top_p": top_p,
            },
        }
        if max_tokens:
            payload["parameters"]["max_new_tokens"] = max_tokens
        if stop:
            payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]
        url = f"{self.base_url}/text/generation"
        params = {"version": "2024-02-13"}
        response = await client.post(
            url,
            headers=headers,
            json=payload,
            params=params,
        )
        if response.status_code != 200:
            raise Exception(f"watsonx API error: {response.status_code} - {response.text}")
        return response.json()
    async def embeddings(
        self,
        model_id: str,
        inputs: List[str],
        **kwargs,
    ) -> Dict:
        """Generate embeddings using watsonx.ai.

        Args:
            model_id: The watsonx embedding model ID
            inputs: List of texts to embed
            **kwargs: Additional parameters (currently ignored)

        Returns:
            Embeddings response (raw watsonx JSON)
        """
        headers = await self._get_headers()
        client = await self._get_client()
        payload = {
            "model_id": model_id,
            "project_id": self.project_id,
            "inputs": inputs,
        }
        url = f"{self.base_url}/text/embeddings"
        params = {"version": "2024-02-13"}
        response = await client.post(
            url,
            headers=headers,
            json=payload,
            params=params,
        )
        if response.status_code != 200:
            raise Exception(f"watsonx API error: {response.status_code} - {response.text}")
        return response.json()
# Global service instance
watsonx_service = WatsonxService()

21
app/utils/__init__.py Normal file
View File

@@ -0,0 +1,21 @@
"""Utility functions for watsonx-openai-proxy."""
from app.utils.transformers import (
transform_messages_to_watsonx,
transform_tools_to_watsonx,
transform_watsonx_to_openai_chat,
transform_watsonx_to_openai_chat_chunk,
transform_watsonx_to_openai_completion,
transform_watsonx_to_openai_embeddings,
format_sse_event,
)
__all__ = [
"transform_messages_to_watsonx",
"transform_tools_to_watsonx",
"transform_watsonx_to_openai_chat",
"transform_watsonx_to_openai_chat_chunk",
"transform_watsonx_to_openai_completion",
"transform_watsonx_to_openai_embeddings",
"format_sse_event",
]

272
app/utils/transformers.py Normal file
View File

@@ -0,0 +1,272 @@
"""Utilities for transforming between OpenAI and watsonx formats."""
import time
import uuid
from typing import Any, Dict, List, Optional
from app.models.openai_models import (
ChatMessage,
ChatCompletionChoice,
ChatCompletionUsage,
ChatCompletionResponse,
ChatCompletionChunk,
ChatCompletionChunkChoice,
ChatCompletionChunkDelta,
CompletionChoice,
CompletionResponse,
EmbeddingData,
EmbeddingUsage,
EmbeddingResponse,
)
def transform_messages_to_watsonx(messages: List[ChatMessage]) -> List[Dict[str, Any]]:
    """Transform OpenAI messages to watsonx format.

    Args:
        messages: List of OpenAI ChatMessage objects.

    Returns:
        List of watsonx-compatible message dicts; only fields that are
        present on each message are included.
    """
    watsonx_messages = []
    for msg in messages:
        watsonx_msg = {
            "role": msg.role,
        }
        # Explicit None check: an empty-string content ("") is a valid
        # message body and must still be forwarded, so don't test truthiness.
        if msg.content is not None:
            watsonx_msg["content"] = msg.content
        if msg.name:
            watsonx_msg["name"] = msg.name
        if msg.tool_calls:
            watsonx_msg["tool_calls"] = msg.tool_calls
        if msg.function_call:
            watsonx_msg["function_call"] = msg.function_call
        watsonx_messages.append(watsonx_msg)
    return watsonx_messages
def transform_tools_to_watsonx(tools: Optional[List[Dict]]) -> Optional[List[Dict]]:
    """Transform OpenAI tool definitions for use with watsonx.

    Args:
        tools: OpenAI tool definitions, or None.

    Returns:
        The definitions unchanged (watsonx accepts the OpenAI tool schema
        as-is), or None when no tools were supplied.
    """
    # Both None and an empty list mean "no tools".
    return tools or None
def transform_watsonx_to_openai_chat(
    watsonx_response: Dict[str, Any],
    model: str,
    request_id: Optional[str] = None,
) -> ChatCompletionResponse:
    """Convert a watsonx chat response into an OpenAI ChatCompletionResponse.

    Args:
        watsonx_response: Raw response dict from the watsonx chat API.
        model: Model name to echo back in the response.
        request_id: Pre-assigned response ID; generated when omitted.

    Returns:
        OpenAI-compatible ChatCompletionResponse.
    """
    choices = []
    for position, raw_choice in enumerate(watsonx_response.get("choices", [])):
        raw_message = raw_choice.get("message", {})
        choices.append(
            ChatCompletionChoice(
                index=position,
                message=ChatMessage(
                    role=raw_message.get("role", "assistant"),
                    content=raw_message.get("content"),
                    tool_calls=raw_message.get("tool_calls"),
                    function_call=raw_message.get("function_call"),
                ),
                finish_reason=raw_choice.get("finish_reason"),
            )
        )
    token_counts = watsonx_response.get("usage", {})
    return ChatCompletionResponse(
        id=request_id or f"chatcmpl-{uuid.uuid4().hex[:24]}",
        created=int(time.time()),
        model=model,
        choices=choices,
        usage=ChatCompletionUsage(
            prompt_tokens=token_counts.get("prompt_tokens", 0),
            completion_tokens=token_counts.get("completion_tokens", 0),
            total_tokens=token_counts.get("total_tokens", 0),
        ),
    )
def transform_watsonx_to_openai_chat_chunk(
    watsonx_chunk: Dict[str, Any],
    model: str,
    request_id: str,
) -> ChatCompletionChunk:
    """Convert one watsonx streaming chunk into an OpenAI ChatCompletionChunk.

    Args:
        watsonx_chunk: A single streamed chunk from the watsonx chat API.
        model: Model name to echo back in the chunk.
        request_id: Stable ID shared by every chunk of the stream.

    Returns:
        OpenAI-compatible ChatCompletionChunk.
    """
    chunk_choices = []
    for position, raw_choice in enumerate(watsonx_chunk.get("choices", [])):
        raw_delta = raw_choice.get("delta", {})
        chunk_choices.append(
            ChatCompletionChunkChoice(
                index=position,
                delta=ChatCompletionChunkDelta(
                    role=raw_delta.get("role"),
                    content=raw_delta.get("content"),
                    tool_calls=raw_delta.get("tool_calls"),
                    function_call=raw_delta.get("function_call"),
                ),
                finish_reason=raw_choice.get("finish_reason"),
            )
        )
    return ChatCompletionChunk(
        id=request_id,
        created=int(time.time()),
        model=model,
        choices=chunk_choices,
    )
def transform_watsonx_to_openai_completion(
    watsonx_response: Dict[str, Any],
    model: str,
    request_id: Optional[str] = None,
) -> CompletionResponse:
    """Convert a watsonx text-generation response into an OpenAI completion.

    Args:
        watsonx_response: Raw response from the watsonx text generation API.
        model: Model name to echo back in the response.
        request_id: Pre-assigned response ID; generated when omitted.

    Returns:
        OpenAI-compatible CompletionResponse.
    """
    completion_choices = [
        CompletionChoice(
            text=entry.get("generated_text", ""),
            index=position,
            finish_reason=entry.get("stop_reason"),
        )
        for position, entry in enumerate(watsonx_response.get("results", []))
    ]
    token_counts = watsonx_response.get("usage", {})
    prompt_total = token_counts.get("prompt_tokens", 0)
    generated_total = token_counts.get("generated_tokens", 0)
    # watsonx reports generated tokens under a separate key; the OpenAI
    # completion usage block reuses the chat usage shape.
    return CompletionResponse(
        id=request_id or f"cmpl-{uuid.uuid4().hex[:24]}",
        created=int(time.time()),
        model=model,
        choices=completion_choices,
        usage=ChatCompletionUsage(
            prompt_tokens=prompt_total,
            completion_tokens=generated_total,
            total_tokens=prompt_total + generated_total,
        ),
    )
def transform_watsonx_to_openai_embeddings(
    watsonx_response: Dict[str, Any],
    model: str,
) -> EmbeddingResponse:
    """Convert a watsonx embeddings response into OpenAI format.

    Args:
        watsonx_response: Raw response from the watsonx embeddings API.
        model: Model name to echo back in the response.

    Returns:
        OpenAI-compatible EmbeddingResponse.
    """
    vectors = [
        EmbeddingData(embedding=entry.get("embedding", []), index=position)
        for position, entry in enumerate(watsonx_response.get("results", []))
    ]
    # Only input tokens are reported for embeddings, so prompt == total.
    token_total = watsonx_response.get("input_token_count", 0)
    return EmbeddingResponse(
        data=vectors,
        model=model,
        usage=EmbeddingUsage(
            prompt_tokens=token_total,
            total_tokens=token_total,
        ),
    )
def format_sse_event(data: str) -> str:
    """Wrap ``data`` in the Server-Sent Events wire format.

    Args:
        data: JSON-encoded payload to send.

    Returns:
        The payload prefixed with ``data: `` and terminated by the
        blank line that ends an SSE event.
    """
    return "data: " + data + "\n\n"