# File: watsonx-openai-proxy/app/routers/chat.py (157 lines, 4.8 KiB, Python)
"""Chat completions endpoint router."""
import json
import uuid
from typing import Union
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from app.models.openai_models import (
ChatCompletionRequest,
ChatCompletionResponse,
ErrorResponse,
ErrorDetail,
)
from app.services.watsonx_service import watsonx_service
from app.utils.transformers import (
transform_messages_to_watsonx,
transform_tools_to_watsonx,
transform_watsonx_to_openai_chat,
transform_watsonx_to_openai_chat_chunk,
format_sse_event,
)
from app.config import settings
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
def _normalize_stop(stop):
    """Normalize OpenAI's ``stop`` field (str | list[str] | None) for watsonx.

    A list passes through unchanged (even if empty); a non-empty string is
    wrapped in a single-element list; None and empty string become None.
    """
    if isinstance(stop, list):
        return stop
    return [stop] if stop else None


@router.post(
    "/v1/chat/completions",
    response_model=Union[ChatCompletionResponse, ErrorResponse],
    responses={
        200: {"model": ChatCompletionResponse},
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def create_chat_completion(
    request: ChatCompletionRequest,
    http_request: Request,
):
    """Create a chat completion using the OpenAI-compatible API.

    Accepts an OpenAI-formatted request, translates it to a watsonx.ai
    call, and translates the result back into the OpenAI schema.

    Args:
        request: Parsed OpenAI chat-completion request body.
        http_request: Raw FastAPI request (unused here; kept for
            middleware/auth hooks and interface stability).

    Returns:
        A ChatCompletionResponse for non-streaming requests, or a
        StreamingResponse emitting SSE chunks when ``request.stream`` is set.

    Raises:
        HTTPException: 500 with an OpenAI-style error payload on any failure.
    """
    try:
        # Map the client-facing model name to a watsonx model id if configured.
        watsonx_model = settings.map_model(request.model)
        logger.info(f"Chat completion request: {request.model} -> {watsonx_model}")

        # Transform messages and (optional) tools into watsonx format.
        watsonx_messages = transform_messages_to_watsonx(request.messages)
        watsonx_tools = transform_tools_to_watsonx(request.tools)

        # Streaming: hand off to the SSE generator; FastAPI drives iteration.
        if request.stream:
            return StreamingResponse(
                stream_chat_completion(
                    watsonx_model,
                    watsonx_messages,
                    request,
                    watsonx_tools,
                ),
                media_type="text/event-stream",
            )

        # Non-streaming response.
        # NOTE: explicit `is not None` checks -- `request.temperature or 1.0`
        # would silently turn a legitimate temperature/top_p of 0.0 into 1.0,
        # because 0.0 is falsy.
        watsonx_response = await watsonx_service.chat_completion(
            model_id=watsonx_model,
            messages=watsonx_messages,
            temperature=request.temperature if request.temperature is not None else 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p if request.top_p is not None else 1.0,
            stop=_normalize_stop(request.stop),
            tools=watsonx_tools,
        )

        # Translate back to the OpenAI schema, echoing the requested model name.
        openai_response = transform_watsonx_to_openai_chat(
            watsonx_response,
            request.model,
        )
        return openai_response
    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}", exc_info=True)
        # Surface the failure in OpenAI's error envelope shape.
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": "internal_error",
                }
            },
        )
async def stream_chat_completion(
    watsonx_model: str,
    watsonx_messages: list,
    request: ChatCompletionRequest,
    watsonx_tools: Union[list, None] = None,
):
    """Stream chat completion responses as Server-Sent Events.

    Args:
        watsonx_model: The watsonx model ID.
        watsonx_messages: Messages already transformed to watsonx format.
        request: Original OpenAI request (source of sampling params and the
            model name echoed back in each chunk).
        watsonx_tools: Tools already transformed to watsonx format, or None.

    Yields:
        SSE-formatted strings: one event per chunk, then a ``[DONE]``
        terminator. On failure, an OpenAI-style error payload is emitted as
        a final in-band event -- the HTTP status is already committed as 200
        once streaming has begun, so errors cannot change it.
    """
    # A single id shared by every chunk of this completion, matching the
    # OpenAI "chatcmpl-..." id convention.
    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
    try:
        async for chunk in watsonx_service.chat_completion_stream(
            model_id=watsonx_model,
            messages=watsonx_messages,
            # Explicit `is not None` checks: `or 1.0` would silently override
            # a valid temperature/top_p of 0.0 (0.0 is falsy).
            temperature=request.temperature if request.temperature is not None else 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p if request.top_p is not None else 1.0,
            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
            tools=watsonx_tools,
        ):
            # Transform the watsonx chunk to an OpenAI delta chunk and emit it.
            openai_chunk = transform_watsonx_to_openai_chat_chunk(
                chunk,
                request.model,
                request_id,
            )
            yield format_sse_event(openai_chunk.model_dump_json())
        # Signal end-of-stream per the OpenAI streaming protocol.
        yield format_sse_event("[DONE]")
    except Exception as e:
        logger.error(f"Error in streaming chat completion: {str(e)}", exc_info=True)
        error_response = ErrorResponse(
            error=ErrorDetail(
                message=str(e),
                type="internal_error",
                code="stream_error",
            )
        )
        yield format_sse_event(error_response.model_dump_json())