Add AGENTS.md documentation for AI agent guidance
This commit is contained in:
3
app/__init__.py
Normal file
3
app/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""watsonx-openai-proxy application package."""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
91
app/config.py
Normal file
91
app/config.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Configuration management for watsonx-openai-proxy."""
|
||||
|
||||
import os
|
||||
from typing import Dict, Optional
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Values come from the process environment and an optional ``.env`` file
    (see ``model_config``).  Extra ``MODEL_MAP_*`` variables are collected
    lazily by :meth:`get_model_mapping` rather than declared as fields.
    """

    # IBM Cloud Configuration
    ibm_cloud_api_key: str
    watsonx_project_id: str
    watsonx_cluster: str = "us-south"

    # Server Configuration
    host: str = "0.0.0.0"
    port: int = 8000
    log_level: str = "info"

    # API Configuration
    api_key: Optional[str] = None
    allowed_origins: str = "*"

    # Token Management
    token_refresh_interval: int = 3000  # 50 minutes in seconds

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="allow",  # Allow extra fields for model mapping
    )

    @property
    def watsonx_base_url(self) -> str:
        """Construct the watsonx.ai base URL from cluster."""
        return f"https://{self.watsonx_cluster}.ml.cloud.ibm.com/ml/v1"

    @property
    def cors_origins(self) -> list[str]:
        """Parse CORS origins from comma-separated string."""
        if self.allowed_origins == "*":
            return ["*"]
        return [origin.strip() for origin in self.allowed_origins.split(",")]

    @staticmethod
    def _openai_alias(raw_name: str) -> str:
        """Normalize a ``MODEL_MAP_*`` suffix (e.g. ``GPT4``) to an
        OpenAI-style alias (e.g. ``gpt-4``).

        Shared by both lookup paths in :meth:`get_model_mapping` so the
        normalization cannot drift between them.
        """
        import re

        # Insert a dash between a letter run and trailing digits: GPT4 -> GPT-4.
        dashed = re.sub(r'([A-Z]+)(\d+)', r'\1-\2', raw_name)
        return dashed.lower().replace("_", "-")

    def get_model_mapping(self) -> Dict[str, str]:
        """Extract model mappings from environment variables.

        Looks for variables like MODEL_MAP_GPT4=ibm/granite-4-h-small
        and creates a mapping dict.
        """
        mapping: Dict[str, str] = {}

        # Check os.environ first.
        for key, value in os.environ.items():
            if key.startswith("MODEL_MAP_"):
                mapping[self._openai_alias(key.removeprefix("MODEL_MAP_"))] = value

        # Also check pydantic's extra fields (from the .env file).  These
        # arrive lowercase: model_map_gpt4 instead of MODEL_MAP_GPT4.
        extra = getattr(self, '__pydantic_extra__', {}) or {}
        for key, value in extra.items():
            if key.startswith("model_map_"):
                # Uppercase first so normalization matches the os.environ path.
                raw = key.removeprefix("model_map_").upper()
                mapping[self._openai_alias(raw)] = value

        return mapping

    def map_model(self, openai_model: str) -> str:
        """Map an OpenAI model name to a watsonx model ID.

        Args:
            openai_model: OpenAI model name (e.g., "gpt-4", "gpt-3.5-turbo")

        Returns:
            Corresponding watsonx model ID, or the original name if no mapping exists
        """
        model_map = self.get_model_mapping()
        return model_map.get(openai_model, openai_model)
|
||||
|
||||
|
||||
# Global settings instance shared by the rest of the application; reads the
# environment / .env file once at import time.
settings = Settings()
|
||||
157
app/main.py
Normal file
157
app/main.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""Main FastAPI application for watsonx-openai-proxy."""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from app.config import settings
|
||||
from app.routers import chat, completions, embeddings, models
|
||||
from app.services.watsonx_service import watsonx_service
|
||||
|
||||
# Configure root logging once at import time.  LOG_LEVEL comes from Settings
# (e.g. "info" -> logging.INFO via getattr).
logging.basicConfig(
    level=getattr(logging, settings.log_level.upper()),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup and shutdown work around the application's lifetime."""
    # --- Startup ---
    logger.info("Starting watsonx-openai-proxy...")
    logger.info(f"Cluster: {settings.watsonx_cluster}")
    logger.info(f"Project ID: {settings.watsonx_project_id[:8]}...")

    # Fetch an initial IAM bearer token so the first request does not pay the
    # authentication cost; abort startup entirely if it cannot be obtained.
    # NOTE(review): this reaches into a private method of the service.
    try:
        await watsonx_service._refresh_token()
    except Exception as exc:
        logger.error(f"Failed to obtain initial bearer token: {exc}")
        raise
    else:
        logger.info("Initial bearer token obtained successfully")

    yield

    # --- Shutdown ---
    logger.info("Shutting down watsonx-openai-proxy...")
    await watsonx_service.close()
|
||||
|
||||
|
||||
# Create FastAPI app; lifespan handles token bootstrap and client shutdown.
app = FastAPI(
    title="watsonx-openai-proxy",
    description="OpenAI-compatible API proxy for IBM watsonx.ai",
    version="1.0.0",
    lifespan=lifespan,
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    # NOTE(review): allow_credentials=True combined with the default "*"
    # origin is rejected by browsers — confirm the intended origin list.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
|
||||
# Optional API key authentication middleware
@app.middleware("http")
async def authenticate(request: Request, call_next):
    """Authenticate requests if API key is configured.

    When ``settings.api_key`` is unset, every request passes through
    untouched.  Otherwise the request must carry
    ``Authorization: Bearer <key>``; the health-check endpoint is always
    exempt.  Failures return an OpenAI-style 401 payload.
    """
    import hmac  # local import: stdlib, used only on the auth path

    def _auth_error(message: str, code: str) -> JSONResponse:
        """Build the OpenAI-style 401 error response."""
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content={
                "error": {
                    "message": message,
                    "type": "authentication_error",
                    "code": code,
                }
            },
        )

    if settings.api_key:
        # Skip authentication for health check
        if request.url.path == "/health":
            return await call_next(request)

        # Check Authorization header
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return _auth_error(
                "Missing Authorization header", "missing_authorization"
            )

        if not auth_header.startswith("Bearer "):
            return _auth_error(
                "Invalid Authorization header format",
                "invalid_authorization_format",
            )

        token = auth_header[7:]  # Remove "Bearer " prefix
        # Constant-time comparison: plain != leaks key prefixes via timing.
        if not hmac.compare_digest(token, settings.api_key):
            return _auth_error("Invalid API key", "invalid_api_key")

    return await call_next(request)
|
||||
|
||||
|
||||
# Include routers; each declares its own /v1/* paths internally.
app.include_router(chat.router, tags=["Chat"])
app.include_router(completions.router, tags=["Completions"])
app.include_router(embeddings.router, tags=["Embeddings"])
app.include_router(models.router, tags=["Models"])
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "watsonx-openai-proxy",
|
||||
"cluster": settings.watsonx_cluster,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint with API information."""
|
||||
return {
|
||||
"service": "watsonx-openai-proxy",
|
||||
"description": "OpenAI-compatible API proxy for IBM watsonx.ai",
|
||||
"version": "1.0.0",
|
||||
"endpoints": {
|
||||
"chat": "/v1/chat/completions",
|
||||
"completions": "/v1/completions",
|
||||
"embeddings": "/v1/embeddings",
|
||||
"models": "/v1/models",
|
||||
"health": "/health",
|
||||
},
|
||||
"documentation": "/docs",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Local development entry point: `python -m app.main`.
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        log_level=settings.log_level,
        reload=False,  # use `uvicorn --reload` directly when hot-reload is wanted
    )
|
||||
31
app/models/__init__.py
Normal file
31
app/models/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""OpenAI-compatible data models."""
|
||||
|
||||
from app.models.openai_models import (
|
||||
ChatMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChunk,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
ModelsResponse,
|
||||
ModelInfo,
|
||||
ErrorResponse,
|
||||
ErrorDetail,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ChatMessage",
|
||||
"ChatCompletionRequest",
|
||||
"ChatCompletionResponse",
|
||||
"ChatCompletionChunk",
|
||||
"CompletionRequest",
|
||||
"CompletionResponse",
|
||||
"EmbeddingRequest",
|
||||
"EmbeddingResponse",
|
||||
"ModelsResponse",
|
||||
"ModelInfo",
|
||||
"ErrorResponse",
|
||||
"ErrorDetail",
|
||||
]
|
||||
213
app/models/openai_models.py
Normal file
213
app/models/openai_models.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""OpenAI-compatible request and response models."""
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ============================================================================
# Chat Completions Models
# ============================================================================

class ChatMessage(BaseModel):
    """A chat message in the conversation."""
    role: Literal["system", "user", "assistant", "function", "tool"]
    # Content may be None, e.g. for assistant messages carrying only tool calls.
    content: Optional[str] = None
    name: Optional[str] = None
    # Legacy function-calling field; superseded by tool_calls.
    function_call: Optional[Dict[str, Any]] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None


class FunctionCall(BaseModel):
    """Function call specification."""
    name: str
    # JSON-encoded argument string, per the OpenAI convention.
    arguments: str


class ToolCall(BaseModel):
    """Tool call specification."""
    id: str
    type: Literal["function"]
    function: FunctionCall


class ChatCompletionRequest(BaseModel):
    """OpenAI chat completion request.

    Field names, defaults, and validation ranges mirror the OpenAI
    /v1/chat/completions request schema.
    """
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(default=1.0, ge=0, le=2)
    top_p: Optional[float] = Field(default=1.0, ge=0, le=1)
    n: Optional[int] = Field(default=1, ge=1)
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = Field(default=None, ge=1)
    presence_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    frequency_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    # Legacy functions API; tools/tool_choice are the current equivalents.
    functions: Optional[List[Dict[str, Any]]] = None
    function_call: Optional[Union[str, Dict[str, str]]] = None
    tools: Optional[List[Dict[str, Any]]] = None
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None


class ChatCompletionChoice(BaseModel):
    """A single chat completion choice."""
    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None
    logprobs: Optional[Dict[str, Any]] = None


class ChatCompletionUsage(BaseModel):
    """Token usage information."""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    """OpenAI chat completion response."""
    id: str
    object: Literal["chat.completion"] = "chat.completion"
    # Unix timestamp in seconds.
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: ChatCompletionUsage
    system_fingerprint: Optional[str] = None


class ChatCompletionChunkDelta(BaseModel):
    """Delta content in streaming response."""
    role: Optional[str] = None
    content: Optional[str] = None
    function_call: Optional[Dict[str, Any]] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None


class ChatCompletionChunkChoice(BaseModel):
    """A single streaming chunk choice."""
    index: int
    delta: ChatCompletionChunkDelta
    finish_reason: Optional[str] = None
    logprobs: Optional[Dict[str, Any]] = None


class ChatCompletionChunk(BaseModel):
    """OpenAI streaming chat completion chunk."""
    id: str
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int
    model: str
    choices: List[ChatCompletionChunkChoice]
    system_fingerprint: Optional[str] = None
|
||||
|
||||
|
||||
# ============================================================================
# Completions Models (Legacy)
# ============================================================================

class CompletionRequest(BaseModel):
    """OpenAI completion request (legacy)."""
    model: str
    # Prompt may be a string, list of strings, or (unsupported downstream)
    # token IDs / batches of token IDs.
    prompt: Union[str, List[str], List[int], List[List[int]]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = Field(default=16, ge=1)
    temperature: Optional[float] = Field(default=1.0, ge=0, le=2)
    top_p: Optional[float] = Field(default=1.0, ge=0, le=1)
    n: Optional[int] = Field(default=1, ge=1)
    stream: Optional[bool] = False
    logprobs: Optional[int] = Field(default=None, ge=0, le=5)
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    presence_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    frequency_penalty: Optional[float] = Field(default=0, ge=-2, le=2)
    best_of: Optional[int] = Field(default=1, ge=1)
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None


class CompletionChoice(BaseModel):
    """A single completion choice."""
    text: str
    index: int
    logprobs: Optional[Dict[str, Any]] = None
    finish_reason: Optional[str] = None


class CompletionResponse(BaseModel):
    """OpenAI completion response."""
    id: str
    object: Literal["text_completion"] = "text_completion"
    created: int
    model: str
    choices: List[CompletionChoice]
    # Reuses the chat usage model: same three token counters.
    usage: ChatCompletionUsage
|
||||
|
||||
|
||||
# ============================================================================
# Embeddings Models
# ============================================================================

class EmbeddingRequest(BaseModel):
    """OpenAI embedding request."""
    model: str
    # String, list of strings, or (unsupported downstream) token IDs.
    input: Union[str, List[str], List[int], List[List[int]]]
    encoding_format: Optional[Literal["float", "base64"]] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None


class EmbeddingData(BaseModel):
    """A single embedding result."""
    object: Literal["embedding"] = "embedding"
    embedding: List[float]
    index: int


class EmbeddingUsage(BaseModel):
    """Token usage for embeddings."""
    prompt_tokens: int
    total_tokens: int


class EmbeddingResponse(BaseModel):
    """OpenAI embedding response."""
    object: Literal["list"] = "list"
    data: List[EmbeddingData]
    model: str
    usage: EmbeddingUsage
|
||||
|
||||
|
||||
# ============================================================================
# Models List
# ============================================================================

class ModelInfo(BaseModel):
    """Information about a single model."""
    id: str
    object: Literal["model"] = "model"
    # Unix timestamp in seconds.
    created: int
    owned_by: str


class ModelsResponse(BaseModel):
    """List of available models."""
    object: Literal["list"] = "list"
    data: List[ModelInfo]
|
||||
|
||||
|
||||
# ============================================================================
# Error Models
# ============================================================================

class ErrorDetail(BaseModel):
    """Error detail information (OpenAI error payload shape)."""
    message: str
    type: str
    param: Optional[str] = None
    code: Optional[str] = None


class ErrorResponse(BaseModel):
    """OpenAI error response."""
    error: ErrorDetail
|
||||
5
app/routers/__init__.py
Normal file
5
app/routers/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""API routers for watsonx-openai-proxy."""
|
||||
|
||||
from app.routers import chat, completions, embeddings, models
|
||||
|
||||
__all__ = ["chat", "completions", "embeddings", "models"]
|
||||
156
app/routers/chat.py
Normal file
156
app/routers/chat.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Chat completions endpoint router."""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Union
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from app.models.openai_models import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ErrorResponse,
|
||||
ErrorDetail,
|
||||
)
|
||||
from app.services.watsonx_service import watsonx_service
|
||||
from app.utils.transformers import (
|
||||
transform_messages_to_watsonx,
|
||||
transform_tools_to_watsonx,
|
||||
transform_watsonx_to_openai_chat,
|
||||
transform_watsonx_to_openai_chat_chunk,
|
||||
format_sse_event,
|
||||
)
|
||||
from app.config import settings
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)

# Mounted by app.main under the "Chat" tag.
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/v1/chat/completions",
    response_model=Union[ChatCompletionResponse, ErrorResponse],
    responses={
        200: {"model": ChatCompletionResponse},
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def create_chat_completion(
    request: ChatCompletionRequest,
    http_request: Request,
):
    """Create a chat completion using OpenAI-compatible API.

    This endpoint accepts OpenAI-formatted requests and translates them
    to watsonx.ai API calls.  Streaming requests return an SSE
    ``StreamingResponse``; everything else returns a single
    ``ChatCompletionResponse``.

    Raises:
        HTTPException: 500 with an OpenAI-style error payload on failure.
    """
    try:
        # Map model name if needed (e.g. gpt-4 -> a watsonx model ID).
        watsonx_model = settings.map_model(request.model)
        logger.info(f"Chat completion request: {request.model} -> {watsonx_model}")

        # Transform messages into the watsonx wire format.
        watsonx_messages = transform_messages_to_watsonx(request.messages)

        # Transform tools if present.
        watsonx_tools = transform_tools_to_watsonx(request.tools)

        # Normalize `stop` to a list (or None), as the service layer expects.
        if isinstance(request.stop, list):
            stop = request.stop
        elif request.stop:
            stop = [request.stop]
        else:
            stop = None

        # Streaming: hand off to the SSE generator.
        if request.stream:
            return StreamingResponse(
                stream_chat_completion(
                    watsonx_model,
                    watsonx_messages,
                    request,
                    watsonx_tools,
                ),
                media_type="text/event-stream",
            )

        # Non-streaming response.
        watsonx_response = await watsonx_service.chat_completion(
            model_id=watsonx_model,
            messages=watsonx_messages,
            temperature=request.temperature or 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p or 1.0,
            stop=stop,
            tools=watsonx_tools,
        )

        # Transform back to the OpenAI response shape, echoing the
        # client-supplied model name.
        return transform_watsonx_to_openai_chat(
            watsonx_response,
            request.model,
        )

    except HTTPException:
        # Let deliberate HTTP errors propagate unchanged (consistent with
        # the completions and embeddings routers).
        raise
    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": "internal_error",
                }
            },
        )
|
||||
|
||||
|
||||
async def stream_chat_completion(
    watsonx_model: str,
    watsonx_messages: list,
    request: ChatCompletionRequest,
    watsonx_tools: Union[list, None] = None,
):
    """Stream chat completion responses.

    Args:
        watsonx_model: The watsonx model ID
        watsonx_messages: Transformed messages
        request: Original OpenAI request
        watsonx_tools: Transformed tools, or None when the request had none

    Yields:
        Server-Sent Events with chat completion chunks, terminated by a
        ``[DONE]`` event; on failure, a single OpenAI-style error event.
    """
    # One id shared by every chunk of this stream, per OpenAI semantics.
    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"

    try:
        async for chunk in watsonx_service.chat_completion_stream(
            model_id=watsonx_model,
            messages=watsonx_messages,
            temperature=request.temperature or 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p or 1.0,
            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
            tools=watsonx_tools,
        ):
            # Transform chunk to OpenAI format.
            openai_chunk = transform_watsonx_to_openai_chat_chunk(
                chunk,
                request.model,
                request_id,
            )

            # Send as SSE.
            yield format_sse_event(openai_chunk.model_dump_json())

        # Send [DONE] message.
        yield format_sse_event("[DONE]")

    except Exception as e:
        logger.error(f"Error in streaming chat completion: {str(e)}", exc_info=True)
        error_response = ErrorResponse(
            error=ErrorDetail(
                message=str(e),
                type="internal_error",
                code="stream_error",
            )
        )
        yield format_sse_event(error_response.model_dump_json())
|
||||
109
app/routers/completions.py
Normal file
109
app/routers/completions.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Text completions endpoint router (legacy)."""
|
||||
|
||||
import uuid
|
||||
from typing import Union
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from app.models.openai_models import (
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
ErrorResponse,
|
||||
ErrorDetail,
|
||||
)
|
||||
from app.services.watsonx_service import watsonx_service
|
||||
from app.utils.transformers import transform_watsonx_to_openai_completion
|
||||
from app.config import settings
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)

# Mounted by app.main under the "Completions" tag.
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/v1/completions",
    response_model=Union[CompletionResponse, ErrorResponse],
    responses={
        200: {"model": CompletionResponse},
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def create_completion(
    request: CompletionRequest,
    http_request: Request,
):
    """Create a text completion using OpenAI-compatible API (legacy).

    This endpoint accepts OpenAI-formatted completion requests and translates
    them to watsonx.ai text generation API calls.

    Raises:
        HTTPException: 400 for an empty prompt list or a streaming request;
            500 with an OpenAI-style error payload for unexpected failures.
    """
    try:
        # Map model name if needed
        watsonx_model = settings.map_model(request.model)
        logger.info(f"Completion request: {request.model} -> {watsonx_model}")

        # Handle prompt (can be string or list)
        if isinstance(request.prompt, list):
            if len(request.prompt) == 0:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": {
                            "message": "Prompt cannot be empty",
                            "type": "invalid_request_error",
                            "code": "invalid_prompt",
                        }
                    },
                )
            # For now, just use the first prompt
            # TODO: Handle multiple prompts with n parameter
            # NOTE(review): a token-ID prompt (list of ints) silently becomes
            # an empty string here; consider rejecting with 400 like the
            # embeddings router does for token-ID input.
            prompt = request.prompt[0] if isinstance(request.prompt[0], str) else ""
        else:
            prompt = request.prompt

        # Note: Streaming not implemented for completions yet
        if request.stream:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "Streaming not supported for completions endpoint",
                        "type": "invalid_request_error",
                        "code": "streaming_not_supported",
                    }
                },
            )

        # Call watsonx text generation
        watsonx_response = await watsonx_service.text_generation(
            model_id=watsonx_model,
            prompt=prompt,
            temperature=request.temperature or 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p or 1.0,
            # Normalize `stop` to a list (or None) for the service layer.
            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
        )

        # Transform response back to the OpenAI shape, echoing the
        # client-supplied model name.
        openai_response = transform_watsonx_to_openai_completion(
            watsonx_response,
            request.model,
        )

        return openai_response

    except HTTPException:
        # Re-raise deliberate 4xx errors unchanged.
        raise
    except Exception as e:
        logger.error(f"Error in completion: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": "internal_error",
                }
            },
        )
|
||||
114
app/routers/embeddings.py
Normal file
114
app/routers/embeddings.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Embeddings endpoint router."""
|
||||
|
||||
from typing import Union
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from app.models.openai_models import (
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
ErrorResponse,
|
||||
ErrorDetail,
|
||||
)
|
||||
from app.services.watsonx_service import watsonx_service
|
||||
from app.utils.transformers import transform_watsonx_to_openai_embeddings
|
||||
from app.config import settings
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)

# Mounted by app.main under the "Embeddings" tag.
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
    "/v1/embeddings",
    response_model=Union[EmbeddingResponse, ErrorResponse],
    responses={
        200: {"model": EmbeddingResponse},
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def create_embeddings(
    request: EmbeddingRequest,
    http_request: Request,
):
    """Create embeddings using OpenAI-compatible API.

    This endpoint accepts OpenAI-formatted embedding requests and translates
    them to watsonx.ai embeddings API calls.

    Raises:
        HTTPException: 400 for empty, token-ID, or otherwise invalid input;
            500 with an OpenAI-style error payload for unexpected failures.
    """
    try:
        # Map model name if needed
        watsonx_model = settings.map_model(request.model)
        logger.info(f"Embeddings request: {request.model} -> {watsonx_model}")

        # Handle input (can be string or list); the service layer always
        # receives a list of strings.
        if isinstance(request.input, str):
            inputs = [request.input]
        elif isinstance(request.input, list):
            if len(request.input) == 0:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": {
                            "message": "Input cannot be empty",
                            "type": "invalid_request_error",
                            "code": "invalid_input",
                        }
                    },
                )
            # Handle list of strings or list of token IDs.
            # NOTE(review): only the first element's type is inspected; a
            # mixed list would pass through — confirm acceptable.
            if isinstance(request.input[0], str):
                inputs = request.input
            else:
                # Token IDs not supported yet
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": {
                            "message": "Token ID input not supported",
                            "type": "invalid_request_error",
                            "code": "unsupported_input_type",
                        }
                    },
                )
        else:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": {
                        "message": "Invalid input type",
                        "type": "invalid_request_error",
                        "code": "invalid_input_type",
                    }
                },
            )

        # Call watsonx embeddings
        watsonx_response = await watsonx_service.embeddings(
            model_id=watsonx_model,
            inputs=inputs,
        )

        # Transform response back to the OpenAI shape, echoing the
        # client-supplied model name.
        openai_response = transform_watsonx_to_openai_embeddings(
            watsonx_response,
            request.model,
        )

        return openai_response

    except HTTPException:
        # Re-raise deliberate 4xx errors unchanged.
        raise
    except Exception as e:
        logger.error(f"Error in embeddings: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": "internal_error",
                }
            },
        )
|
||||
120
app/routers/models.py
Normal file
120
app/routers/models.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Models endpoint router."""
|
||||
|
||||
import time
|
||||
from fastapi import APIRouter
|
||||
from app.models.openai_models import ModelsResponse, ModelInfo
|
||||
from app.config import settings
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)

# Mounted by app.main under the "Models" tag.
router = APIRouter()
|
||||
|
||||
|
||||
# Predefined list of available models
# This can be extended or made dynamic based on watsonx.ai API
AVAILABLE_MODELS = [
    # Granite Models
    "ibm/granite-3-1-8b-base",
    "ibm/granite-3-2-8b-instruct",
    "ibm/granite-3-3-8b-instruct",
    "ibm/granite-3-8b-instruct",
    "ibm/granite-4-h-small",
    "ibm/granite-8b-code-instruct",

    # Llama Models
    "meta-llama/llama-3-1-70b-gptq",
    "meta-llama/llama-3-1-8b",
    "meta-llama/llama-3-2-11b-vision-instruct",
    "meta-llama/llama-3-2-90b-vision-instruct",
    "meta-llama/llama-3-3-70b-instruct",
    "meta-llama/llama-3-405b-instruct",
    "meta-llama/llama-4-maverick-17b-128e-instruct-fp8",

    # Mistral Models
    "mistral-large-2512",
    "mistralai/mistral-medium-2505",
    "mistralai/mistral-small-3-1-24b-instruct-2503",

    # Other Models
    "openai/gpt-oss-120b",

    # Embedding Models
    "ibm/slate-125m-english-rtrvr",
    "ibm/slate-30m-english-rtrvr",
]
|
||||
|
||||
|
||||
@router.get(
    "/v1/models",
    response_model=ModelsResponse,
)
async def list_models():
    """List available models in OpenAI-compatible format.

    Returns a list of models that can be used with the API.
    Includes both the actual watsonx model IDs and any mapped names.
    """
    now = int(time.time())

    def _info(identifier: str) -> ModelInfo:
        # All entries share the same timestamp and owner.
        return ModelInfo(id=identifier, created=now, owned_by="ibm-watsonx")

    # All known watsonx model IDs.
    catalog = [_info(model_id) for model_id in AVAILABLE_MODELS]

    # Plus any OpenAI-style aliases (e.g. gpt-4 -> ibm/granite-4-h-small)
    # whose targets are actually present in the catalog.
    catalog.extend(
        _info(alias)
        for alias, target in settings.get_model_mapping().items()
        if target in AVAILABLE_MODELS
    )

    return ModelsResponse(data=catalog)
|
||||
|
||||
|
||||
@router.get(
    "/v1/models/{model_id}",
    response_model=ModelInfo,
)
async def retrieve_model(model_id: str):
    """Retrieve information about a specific model.

    Args:
        model_id: The model ID to retrieve

    Returns:
        Model information
    """
    # Resolve any alias (e.g. gpt-4) to its watsonx ID before checking
    # membership in the catalog.
    if settings.map_model(model_id) not in AVAILABLE_MODELS:
        from fastapi import HTTPException
        raise HTTPException(
            status_code=404,
            detail={
                "error": {
                    "message": f"Model '{model_id}' not found",
                    "type": "invalid_request_error",
                    "code": "model_not_found",
                }
            },
        )

    # Echo back the requested id, not the resolved watsonx id.
    return ModelInfo(
        id=model_id,
        created=int(time.time()),
        owned_by="ibm-watsonx",
    )
|
||||
5
app/services/__init__.py
Normal file
5
app/services/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Services for watsonx-openai-proxy."""
|
||||
|
||||
from app.services.watsonx_service import watsonx_service, WatsonxService
|
||||
|
||||
__all__ = ["watsonx_service", "WatsonxService"]
|
||||
316
app/services/watsonx_service.py
Normal file
316
app/services/watsonx_service.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""Service for interacting with IBM watsonx.ai APIs."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import AsyncIterator, Dict, List, Optional
|
||||
import httpx
|
||||
from app.config import settings
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WatsonxService:
    """Service for managing watsonx.ai API interactions.

    Owns the IBM Cloud IAM bearer-token lifecycle and a shared async HTTP
    client, and exposes chat, chat-stream, text-generation and embeddings
    calls against the watsonx.ai REST API.
    """

    # Version query parameter required by the watsonx.ai endpoints used below.
    _API_VERSION = "2024-02-13"

    def __init__(self):
        self.base_url = settings.watsonx_base_url
        self.project_id = settings.watsonx_project_id
        self.api_key = settings.ibm_cloud_api_key
        self._bearer_token: Optional[str] = None
        self._token_expiry: Optional[float] = None  # epoch seconds
        self._token_lock = asyncio.Lock()  # serializes concurrent refreshes
        self._client: Optional[httpx.AsyncClient] = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or lazily create the shared HTTP client."""
        if self._client is None:
            # Generous timeout: generation requests can run for minutes.
            self._client = httpx.AsyncClient(timeout=300.0)
        return self._client

    async def close(self):
        """Close the HTTP client. Safe to call multiple times."""
        if self._client:
            await self._client.aclose()
            self._client = None

    async def _refresh_token(self) -> str:
        """Return a valid bearer token, refreshing via IBM Cloud IAM if needed.

        Raises:
            Exception: If the IAM token endpoint returns a non-200 status.
        """
        async with self._token_lock:
            # Reuse the cached token while it has at least 5 minutes left.
            if self._bearer_token and self._token_expiry:
                if time.time() < self._token_expiry - 300:
                    return self._bearer_token

            logger.info("Refreshing IBM Cloud bearer token...")

            client = await self._get_client()
            response = await client.post(
                "https://iam.cloud.ibm.com/identity/token",
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                data=f"grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey={self.api_key}",
            )

            if response.status_code != 200:
                raise Exception(f"Failed to get bearer token: {response.text}")

            data = response.json()
            self._bearer_token = data["access_token"]
            self._token_expiry = time.time() + data.get("expires_in", 3600)

            # Lazy %-formatting so the message is only built when logged.
            logger.info("Bearer token refreshed. Expires in %s seconds", data.get("expires_in", 3600))
            return self._bearer_token

    async def _get_headers(self) -> Dict[str, str]:
        """Build JSON request headers carrying a valid bearer token."""
        token = await self._refresh_token()
        return {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

    def _build_chat_payload(
        self,
        model_id: str,
        messages: List[Dict],
        temperature: float,
        top_p: float,
        max_tokens: Optional[int],
        stop: Optional[List[str]],
        tools: Optional[List[Dict]],
    ) -> Dict:
        """Assemble the request body shared by chat and chat-stream calls."""
        payload: Dict = {
            "model_id": model_id,
            "project_id": self.project_id,
            "messages": messages,
            "parameters": {
                "temperature": temperature,
                "top_p": top_p,
            },
        }
        # `is not None` so an explicit max_tokens=0 is still forwarded
        # (a plain truthiness check would silently drop it).
        if max_tokens is not None:
            payload["parameters"]["max_tokens"] = max_tokens
        if stop:
            payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]
        if tools:
            payload["tools"] = tools
        return payload

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Dict],
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        stream: bool = False,
        tools: Optional[List[Dict]] = None,
        **kwargs,
    ) -> Dict:
        """Create a chat completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID
            messages: List of chat messages
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            top_p: Nucleus sampling parameter
            stop: Stop sequences
            stream: Whether to request streaming (see note below)
            tools: Tool/function definitions
            **kwargs: Additional parameters (currently ignored)

        Returns:
            Chat completion response (parsed JSON)

        Raises:
            Exception: If the watsonx API returns a non-200 status.
        """
        headers = await self._get_headers()
        client = await self._get_client()

        payload = self._build_chat_payload(
            model_id, messages, temperature, top_p, max_tokens, stop, tools
        )

        url = f"{self.base_url}/text/chat"
        params = {"version": self._API_VERSION}

        # NOTE(review): stream=True only sets the query flag; the request is
        # still a buffered POST and returns a single JSON body. Use
        # chat_completion_stream() for actual SSE streaming — confirm whether
        # this flag is needed at all here.
        if stream:
            params["stream"] = "true"

        response = await client.post(
            url,
            headers=headers,
            json=payload,
            params=params,
        )

        if response.status_code != 200:
            raise Exception(f"watsonx API error: {response.status_code} - {response.text}")

        return response.json()

    async def chat_completion_stream(
        self,
        model_id: str,
        messages: List[Dict],
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        tools: Optional[List[Dict]] = None,
        **kwargs,
    ) -> AsyncIterator[Dict]:
        """Stream a chat completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID
            messages: List of chat messages
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            top_p: Nucleus sampling parameter
            stop: Stop sequences
            tools: Tool/function definitions
            **kwargs: Additional parameters (currently ignored)

        Yields:
            Parsed chat completion chunks (one dict per SSE data line)

        Raises:
            Exception: If the watsonx API returns a non-200 status.
        """
        import json  # hoisted out of the per-line loop below

        headers = await self._get_headers()
        client = await self._get_client()

        payload = self._build_chat_payload(
            model_id, messages, temperature, top_p, max_tokens, stop, tools
        )

        url = f"{self.base_url}/text/chat_stream"
        params = {"version": self._API_VERSION}

        async with client.stream(
            "POST",
            url,
            headers=headers,
            json=payload,
            params=params,
        ) as response:
            if response.status_code != 200:
                text = await response.aread()
                raise Exception(f"watsonx API error: {response.status_code} - {text.decode()}")

            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data = line[6:]  # Remove "data: " prefix
                    if data.strip() == "[DONE]":
                        break
                    try:
                        yield json.loads(data)
                    except json.JSONDecodeError:
                        # Skip malformed/partial SSE payloads.
                        continue

    async def text_generation(
        self,
        model_id: str,
        prompt: str,
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> Dict:
        """Generate text completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID
            prompt: The input prompt
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            top_p: Nucleus sampling parameter
            stop: Stop sequences
            **kwargs: Additional parameters (currently ignored)

        Returns:
            Text generation response (parsed JSON)

        Raises:
            Exception: If the watsonx API returns a non-200 status.
        """
        headers = await self._get_headers()
        client = await self._get_client()

        payload = {
            "model_id": model_id,
            "project_id": self.project_id,
            "input": prompt,
            "parameters": {
                "temperature": temperature,
                "top_p": top_p,
            },
        }

        # The generation endpoint uses "max_new_tokens" (chat uses "max_tokens").
        if max_tokens is not None:
            payload["parameters"]["max_new_tokens"] = max_tokens

        if stop:
            payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]

        url = f"{self.base_url}/text/generation"
        params = {"version": self._API_VERSION}

        response = await client.post(
            url,
            headers=headers,
            json=payload,
            params=params,
        )

        if response.status_code != 200:
            raise Exception(f"watsonx API error: {response.status_code} - {response.text}")

        return response.json()

    async def embeddings(
        self,
        model_id: str,
        inputs: List[str],
        **kwargs,
    ) -> Dict:
        """Generate embeddings using watsonx.ai.

        Args:
            model_id: The watsonx embedding model ID
            inputs: List of texts to embed
            **kwargs: Additional parameters (currently ignored)

        Returns:
            Embeddings response (parsed JSON)

        Raises:
            Exception: If the watsonx API returns a non-200 status.
        """
        headers = await self._get_headers()
        client = await self._get_client()

        payload = {
            "model_id": model_id,
            "project_id": self.project_id,
            "inputs": inputs,
        }

        url = f"{self.base_url}/text/embeddings"
        params = {"version": self._API_VERSION}

        response = await client.post(
            url,
            headers=headers,
            json=payload,
            params=params,
        )

        if response.status_code != 200:
            raise Exception(f"watsonx API error: {response.status_code} - {response.text}")

        return response.json()
||||
|
||||
|
||||
# Module-level singleton; importers share this one instance
# (it is re-exported via app.services.__init__).
watsonx_service = WatsonxService()
|
||||
21
app/utils/__init__.py
Normal file
21
app/utils/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Utility functions for watsonx-openai-proxy."""
|
||||
|
||||
from app.utils.transformers import (
|
||||
transform_messages_to_watsonx,
|
||||
transform_tools_to_watsonx,
|
||||
transform_watsonx_to_openai_chat,
|
||||
transform_watsonx_to_openai_chat_chunk,
|
||||
transform_watsonx_to_openai_completion,
|
||||
transform_watsonx_to_openai_embeddings,
|
||||
format_sse_event,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"transform_messages_to_watsonx",
|
||||
"transform_tools_to_watsonx",
|
||||
"transform_watsonx_to_openai_chat",
|
||||
"transform_watsonx_to_openai_chat_chunk",
|
||||
"transform_watsonx_to_openai_completion",
|
||||
"transform_watsonx_to_openai_embeddings",
|
||||
"format_sse_event",
|
||||
]
|
||||
272
app/utils/transformers.py
Normal file
272
app/utils/transformers.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""Utilities for transforming between OpenAI and watsonx formats."""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
from app.models.openai_models import (
|
||||
ChatMessage,
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionUsage,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionChunkChoice,
|
||||
ChatCompletionChunkDelta,
|
||||
CompletionChoice,
|
||||
CompletionResponse,
|
||||
EmbeddingData,
|
||||
EmbeddingUsage,
|
||||
EmbeddingResponse,
|
||||
)
|
||||
|
||||
|
||||
def transform_messages_to_watsonx(messages: List[ChatMessage]) -> List[Dict[str, Any]]:
|
||||
"""Transform OpenAI messages to watsonx format.
|
||||
|
||||
Args:
|
||||
messages: List of OpenAI ChatMessage objects
|
||||
|
||||
Returns:
|
||||
List of watsonx-compatible message dicts
|
||||
"""
|
||||
watsonx_messages = []
|
||||
|
||||
for msg in messages:
|
||||
watsonx_msg = {
|
||||
"role": msg.role,
|
||||
}
|
||||
|
||||
if msg.content:
|
||||
watsonx_msg["content"] = msg.content
|
||||
|
||||
if msg.name:
|
||||
watsonx_msg["name"] = msg.name
|
||||
|
||||
if msg.tool_calls:
|
||||
watsonx_msg["tool_calls"] = msg.tool_calls
|
||||
|
||||
if msg.function_call:
|
||||
watsonx_msg["function_call"] = msg.function_call
|
||||
|
||||
watsonx_messages.append(watsonx_msg)
|
||||
|
||||
return watsonx_messages
|
||||
|
||||
|
||||
def transform_tools_to_watsonx(tools: Optional[List[Dict]]) -> Optional[List[Dict]]:
    """Transform OpenAI tools to watsonx format.

    watsonx accepts the OpenAI tool schema as-is, so this only normalizes
    "no tools" (None or an empty list) to None and passes the list through
    otherwise.

    Args:
        tools: List of OpenAI tool definitions

    Returns:
        The same tool list, or None when no tools were given.
    """
    return tools if tools else None
||||
|
||||
|
||||
def transform_watsonx_to_openai_chat(
    watsonx_response: Dict[str, Any],
    model: str,
    request_id: Optional[str] = None,
) -> ChatCompletionResponse:
    """Transform watsonx chat response to OpenAI format.

    Args:
        watsonx_response: Response from watsonx chat API
        model: Model name to include in response
        request_id: Optional request ID, generates one if not provided

    Returns:
        OpenAI-compatible ChatCompletionResponse
    """

    def _to_choice(position: int, raw: Dict[str, Any]) -> ChatCompletionChoice:
        # Each watsonx choice carries an OpenAI-like "message" object.
        msg = raw.get("message", {})
        return ChatCompletionChoice(
            index=position,
            message=ChatMessage(
                role=msg.get("role", "assistant"),
                content=msg.get("content"),
                tool_calls=msg.get("tool_calls"),
                function_call=msg.get("function_call"),
            ),
            finish_reason=raw.get("finish_reason"),
        )

    token_counts = watsonx_response.get("usage", {})

    return ChatCompletionResponse(
        id=request_id or f"chatcmpl-{uuid.uuid4().hex[:24]}",
        created=int(time.time()),
        model=model,
        choices=[
            _to_choice(i, c)
            for i, c in enumerate(watsonx_response.get("choices", []))
        ],
        usage=ChatCompletionUsage(
            prompt_tokens=token_counts.get("prompt_tokens", 0),
            completion_tokens=token_counts.get("completion_tokens", 0),
            total_tokens=token_counts.get("total_tokens", 0),
        ),
    )
||||
|
||||
|
||||
def transform_watsonx_to_openai_chat_chunk(
    watsonx_chunk: Dict[str, Any],
    model: str,
    request_id: str,
) -> ChatCompletionChunk:
    """Transform watsonx streaming chunk to OpenAI format.

    Args:
        watsonx_chunk: Streaming chunk from watsonx
        model: Model name to include in response
        request_id: Request ID for this stream (shared by all chunks)

    Returns:
        OpenAI-compatible ChatCompletionChunk
    """

    def _to_chunk_choice(position: int, raw: Dict[str, Any]) -> ChatCompletionChunkChoice:
        # Each watsonx chunk choice carries an incremental "delta" object.
        partial = raw.get("delta", {})
        return ChatCompletionChunkChoice(
            index=position,
            delta=ChatCompletionChunkDelta(
                role=partial.get("role"),
                content=partial.get("content"),
                tool_calls=partial.get("tool_calls"),
                function_call=partial.get("function_call"),
            ),
            finish_reason=raw.get("finish_reason"),
        )

    return ChatCompletionChunk(
        id=request_id,
        created=int(time.time()),
        model=model,
        choices=[
            _to_chunk_choice(i, c)
            for i, c in enumerate(watsonx_chunk.get("choices", []))
        ],
    )
||||
|
||||
|
||||
def transform_watsonx_to_openai_completion(
    watsonx_response: Dict[str, Any],
    model: str,
    request_id: Optional[str] = None,
) -> CompletionResponse:
    """Transform watsonx text generation response to OpenAI completion format.

    Args:
        watsonx_response: Response from watsonx text generation API
        model: Model name to include in response
        request_id: Optional request ID, generates one if not provided

    Returns:
        OpenAI-compatible CompletionResponse
    """
    generated = watsonx_response.get("results", [])
    choices = [
        CompletionChoice(
            text=item.get("generated_text", ""),
            index=position,
            finish_reason=item.get("stop_reason"),
        )
        for position, item in enumerate(generated)
    ]

    # The generation endpoint reports "generated_tokens" rather than
    # "completion_tokens"; the total is derived from the two counts.
    token_counts = watsonx_response.get("usage", {})
    prompt_toks = token_counts.get("prompt_tokens", 0)
    output_toks = token_counts.get("generated_tokens", 0)

    return CompletionResponse(
        id=request_id or f"cmpl-{uuid.uuid4().hex[:24]}",
        created=int(time.time()),
        model=model,
        choices=choices,
        usage=ChatCompletionUsage(
            prompt_tokens=prompt_toks,
            completion_tokens=output_toks,
            total_tokens=prompt_toks + output_toks,
        ),
    )
||||
|
||||
|
||||
def transform_watsonx_to_openai_embeddings(
    watsonx_response: Dict[str, Any],
    model: str,
) -> EmbeddingResponse:
    """Transform watsonx embeddings response to OpenAI format.

    Args:
        watsonx_response: Response from watsonx embeddings API
        model: Model name to include in response

    Returns:
        OpenAI-compatible EmbeddingResponse
    """
    vectors = [
        EmbeddingData(embedding=item.get("embedding", []), index=position)
        for position, item in enumerate(watsonx_response.get("results", []))
    ]

    # Embeddings consume only input tokens, so both counts are identical.
    token_count = watsonx_response.get("input_token_count", 0)

    return EmbeddingResponse(
        data=vectors,
        model=model,
        usage=EmbeddingUsage(
            prompt_tokens=token_count,
            total_tokens=token_count,
        ),
    )
||||
|
||||
|
||||
def format_sse_event(data: str) -> str:
    """Format data as Server-Sent Event.

    Args:
        data: JSON string to send

    Returns:
        The payload wrapped as ``data: <payload>`` followed by the blank
        line that terminates an SSE event.
    """
    return "data: " + data + "\n\n"
||||
Reference in New Issue
Block a user