"""Chat completions endpoint router.""" import json import uuid from typing import Union from fastapi import APIRouter, HTTPException, Request from fastapi.responses import StreamingResponse from app.models.openai_models import ( ChatCompletionRequest, ChatCompletionResponse, ErrorResponse, ErrorDetail, ) from app.services.watsonx_service import watsonx_service from app.utils.transformers import ( transform_messages_to_watsonx, transform_tools_to_watsonx, transform_watsonx_to_openai_chat, transform_watsonx_to_openai_chat_chunk, format_sse_event, ) from app.config import settings import logging logger = logging.getLogger(__name__) router = APIRouter() @router.post( "/v1/chat/completions", response_model=Union[ChatCompletionResponse, ErrorResponse], responses={ 200: {"model": ChatCompletionResponse}, 400: {"model": ErrorResponse}, 401: {"model": ErrorResponse}, 500: {"model": ErrorResponse}, }, ) async def create_chat_completion( request: ChatCompletionRequest, http_request: Request, ): """Create a chat completion using OpenAI-compatible API. This endpoint accepts OpenAI-formatted requests and translates them to watsonx.ai API calls. """ try: # Map model name if needed watsonx_model = settings.map_model(request.model) logger.info(f"Chat completion request: {request.model} -> {watsonx_model}") # Transform messages watsonx_messages = transform_messages_to_watsonx(request.messages) # Transform tools if present watsonx_tools = transform_tools_to_watsonx(request.tools) # Handle streaming if request.stream: return StreamingResponse( stream_chat_completion( watsonx_model, watsonx_messages, request, watsonx_tools, ), media_type="text/event-stream", ) # Non-streaming response watsonx_response = await watsonx_service.chat_completion( model_id=watsonx_model, messages=watsonx_messages, temperature=request.temperature or 1.0, max_tokens=request.max_tokens, top_p=request.top_p or 1.0, stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None, tools=watsonx_tools, ) # Transform response openai_response = transform_watsonx_to_openai_chat( watsonx_response, request.model, ) return openai_response except Exception as e: logger.error(f"Error in chat completion: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail={ "error": { "message": str(e), "type": "internal_error", "code": "internal_error", } }, ) async def stream_chat_completion( watsonx_model: str, watsonx_messages: list, request: ChatCompletionRequest, watsonx_tools: list = None, ): """Stream chat completion responses. Args: watsonx_model: The watsonx model ID watsonx_messages: Transformed messages request: Original OpenAI request watsonx_tools: Transformed tools Yields: Server-Sent Events with chat completion chunks """ request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" try: async for chunk in watsonx_service.chat_completion_stream( model_id=watsonx_model, messages=watsonx_messages, temperature=request.temperature or 1.0, max_tokens=request.max_tokens, top_p=request.top_p or 1.0, stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None, tools=watsonx_tools, ): # Transform chunk to OpenAI format openai_chunk = transform_watsonx_to_openai_chat_chunk( chunk, request.model, request_id, ) # Send as SSE yield format_sse_event(openai_chunk.model_dump_json()) # Send [DONE] message yield format_sse_event("[DONE]") except Exception as e: logger.error(f"Error in streaming chat completion: {str(e)}", exc_info=True) error_response = ErrorResponse( error=ErrorDetail( message=str(e), type="internal_error", code="stream_error", ) ) yield format_sse_event(error_response.model_dump_json())