"""Chat completions endpoint router.""" import json import uuid from typing import Union from fastapi import APIRouter, HTTPException, Request from fastapi.responses import StreamingResponse from app.models.openai_models import ( ChatCompletionRequest, ChatCompletionResponse, ErrorResponse, ErrorDetail, ) from app.services.watsonx_service import watsonx_service from app.utils.transformers import ( transform_messages_to_watsonx, transform_tools_to_watsonx, transform_watsonx_to_openai_chat, transform_watsonx_to_openai_chat_chunk, format_sse_event, ) from app.config import settings import logging logger = logging.getLogger(__name__) router = APIRouter() def normalize_messages_for_vllm(messages: list) -> list: """Normalize OpenAI message format for vLLM compatibility. vLLM is stricter than OpenAI API and requires: 1. Message content as string (not array of content parts) 2. Role must be system/user/assistant/function/tool (not "developer") Args: messages: List of message dicts Returns: Normalized list of messages """ normalized = [] for msg in messages: normalized_msg = msg.copy() # Normalize "developer" role to "system" if normalized_msg.get("role") == "developer": normalized_msg["role"] = "system" logger.debug("Normalized 'developer' role to 'system'") # Normalize array content to string for text-only messages content = normalized_msg.get("content") if isinstance(content, list): # Check if all parts are text-only if all(isinstance(p, dict) and p.get("type") == "text" for p in content): # Flatten to concatenated string normalized_msg["content"] = "\n".join(p.get("text", "") for p in content) logger.debug(f"Normalized array content to string: {len(content)} parts") else: # Has image_url or other non-text types - keep as is # vLLM may reject this, but we preserve the original format logger.warning("Message contains non-text content parts, keeping array format") normalized.append(normalized_msg) return normalized @router.post( "/v1/chat/completions", response_model=Union[ChatCompletionResponse, ErrorResponse], responses={ 200: {"model": ChatCompletionResponse}, 400: {"model": ErrorResponse}, 401: {"model": ErrorResponse}, 500: {"model": ErrorResponse}, }, ) async def create_chat_completion( request: ChatCompletionRequest, http_request: Request, ): """Create a chat completion using OpenAI-compatible API. This endpoint accepts OpenAI-formatted requests and translates them to watsonx.ai API calls. """ try: # Map model name if needed watsonx_model = settings.map_model(request.model) logger.info(f"Chat completion request: {request.model} -> {watsonx_model}") # Normalize messages for vLLM compatibility (handles array content and developer role) normalized_messages = normalize_messages_for_vllm([msg.model_dump() for msg in request.messages]) # Transform normalized messages to watsonx format # Convert back to ChatMessage objects for the transformer from app.models.openai_models import ChatMessage normalized_chat_messages = [ChatMessage(**msg) for msg in normalized_messages] watsonx_messages = transform_messages_to_watsonx(normalized_chat_messages) # Transform tools if present watsonx_tools = transform_tools_to_watsonx(request.tools) # Handle streaming if request.stream: return StreamingResponse( stream_chat_completion( watsonx_model, watsonx_messages, request, watsonx_tools, ), media_type="text/event-stream", ) # Non-streaming response watsonx_response = await watsonx_service.chat_completion( model_id=watsonx_model, messages=watsonx_messages, temperature=request.temperature or 1.0, max_tokens=request.max_tokens, top_p=request.top_p or 1.0, stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None, tools=watsonx_tools, ) # Transform response openai_response = transform_watsonx_to_openai_chat( watsonx_response, request.model, ) return openai_response except Exception as e: logger.error(f"Error in chat completion: {str(e)}", exc_info=True) raise HTTPException( status_code=500, detail={ "error": { "message": str(e), "type": "internal_error", "code": "internal_error", } }, ) async def stream_chat_completion( watsonx_model: str, watsonx_messages: list, request: ChatCompletionRequest, watsonx_tools: list = None, ): """Stream chat completion responses. Args: watsonx_model: The watsonx model ID watsonx_messages: Transformed messages request: Original OpenAI request watsonx_tools: Transformed tools Yields: Server-Sent Events with chat completion chunks """ request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" try: async for chunk in watsonx_service.chat_completion_stream( model_id=watsonx_model, messages=watsonx_messages, temperature=request.temperature or 1.0, max_tokens=request.max_tokens, top_p=request.top_p or 1.0, stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None, tools=watsonx_tools, ): # Transform chunk to OpenAI format openai_chunk = transform_watsonx_to_openai_chat_chunk( chunk, request.model, request_id, ) # Send as SSE yield format_sse_event(openai_chunk.model_dump_json()) # Send [DONE] message yield format_sse_event("[DONE]") except Exception as e: logger.error(f"Error in streaming chat completion: {str(e)}", exc_info=True) error_response = ErrorResponse( error=ErrorDetail( message=str(e), type="internal_error", code="stream_error", ) ) yield format_sse_event(error_response.model_dump_json())