Normalize OpenAI chat messages for vLLM compatibility:
- Normalize the 'developer' role to 'system' (vLLM does not support the developer role)
- Flatten array content to a single string for text-only messages
- Preserve mixed content (text + images) as an array
- Add comprehensive unit tests for the normalization logic

Fixes HTTP 422 errors when clients send the OpenAI multi-content format.
204 lines · 6.8 KiB · Python
"""Chat completions endpoint router."""
|
|
|
|
import json
|
|
import uuid
|
|
from typing import Union
|
|
from fastapi import APIRouter, HTTPException, Request
|
|
from fastapi.responses import StreamingResponse
|
|
from app.models.openai_models import (
|
|
ChatCompletionRequest,
|
|
ChatCompletionResponse,
|
|
ErrorResponse,
|
|
ErrorDetail,
|
|
)
|
|
from app.services.watsonx_service import watsonx_service
|
|
from app.utils.transformers import (
|
|
transform_messages_to_watsonx,
|
|
transform_tools_to_watsonx,
|
|
transform_watsonx_to_openai_chat,
|
|
transform_watsonx_to_openai_chat_chunk,
|
|
format_sse_event,
|
|
)
|
|
from app.config import settings
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
def normalize_messages_for_vllm(messages: list) -> list:
    """Normalize OpenAI message format for vLLM compatibility.

    vLLM is stricter than the OpenAI API and requires:
    1. Message content as a string (not an array of content parts)
    2. Role must be system/user/assistant/function/tool (not "developer")

    The input messages are not mutated; shallow copies are returned.

    Args:
        messages: List of message dicts in OpenAI chat format.

    Returns:
        Normalized list of message dicts safe to send to vLLM.
    """
    normalized = []

    for msg in messages:
        normalized_msg = msg.copy()

        # vLLM rejects the OpenAI "developer" role; map it to "system",
        # which carries equivalent instruction-priority semantics.
        if normalized_msg.get("role") == "developer":
            normalized_msg["role"] = "system"
            logger.debug("Normalized 'developer' role to 'system'")

        content = normalized_msg.get("content")
        if isinstance(content, list):
            # Flatten a text-only part array into one newline-joined string;
            # vLLM returns HTTP 422 for array content.
            if all(isinstance(p, dict) and p.get("type") == "text" for p in content):
                normalized_msg["content"] = "\n".join(p.get("text", "") for p in content)
                # Lazy %-args: avoids formatting when DEBUG is disabled.
                logger.debug("Normalized array content to string: %d parts", len(content))
            else:
                # Mixed content (e.g. image_url parts) has no lossless string
                # form - preserve the array and let the backend decide.
                logger.warning("Message contains non-text content parts, keeping array format")

        normalized.append(normalized_msg)

    return normalized
|
|
|
|
|
|
@router.post(
    "/v1/chat/completions",
    response_model=Union[ChatCompletionResponse, ErrorResponse],
    responses={
        200: {"model": ChatCompletionResponse},
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def create_chat_completion(
    request: ChatCompletionRequest,
    http_request: Request,
):
    """Create a chat completion using the OpenAI-compatible API.

    Accepts an OpenAI-formatted request, normalizes it for vLLM quirks,
    translates it to a watsonx.ai call, and converts the result back to
    the OpenAI response shape (streaming via SSE when requested).

    Args:
        request: Parsed OpenAI-style chat completion request body.
        http_request: Raw FastAPI request (kept for middleware/auth access).

    Returns:
        ChatCompletionResponse for non-streaming calls, or a
        StreamingResponse emitting SSE chunks when ``request.stream`` is set.

    Raises:
        HTTPException: 500 with an OpenAI-style error payload on any failure.
    """
    try:
        # Map the client-facing model name to the watsonx model ID.
        watsonx_model = settings.map_model(request.model)
        logger.info("Chat completion request: %s -> %s", request.model, watsonx_model)

        # Normalize messages for vLLM compatibility (flattens array content,
        # maps the unsupported "developer" role to "system").
        normalized_messages = normalize_messages_for_vllm(
            [msg.model_dump() for msg in request.messages]
        )

        # Rebuild ChatMessage objects because the transformer expects model
        # instances, not plain dicts.
        from app.models.openai_models import ChatMessage
        normalized_chat_messages = [ChatMessage(**msg) for msg in normalized_messages]
        watsonx_messages = transform_messages_to_watsonx(normalized_chat_messages)

        # Transform tools if present.
        watsonx_tools = transform_tools_to_watsonx(request.tools)

        # Streaming: hand the async generator to Starlette as SSE.
        if request.stream:
            return StreamingResponse(
                stream_chat_completion(
                    watsonx_model,
                    watsonx_messages,
                    request,
                    watsonx_tools,
                ),
                media_type="text/event-stream",
            )

        # Non-streaming response. Explicit None checks so a valid
        # temperature/top_p of 0 is not silently coerced to 1.0 (the
        # previous `or 1.0` idiom dropped falsy-but-valid values).
        watsonx_response = await watsonx_service.chat_completion(
            model_id=watsonx_model,
            messages=watsonx_messages,
            temperature=request.temperature if request.temperature is not None else 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p if request.top_p is not None else 1.0,
            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
            tools=watsonx_tools,
        )

        # Convert back to the OpenAI response schema, echoing the model
        # name the client originally asked for.
        openai_response = transform_watsonx_to_openai_chat(
            watsonx_response,
            request.model,
        )

        return openai_response

    except Exception as e:
        # Boundary handler: log with traceback and surface an OpenAI-style
        # error body. NOTE(review): str(e) may leak internal details to
        # clients - confirm this is acceptable for the deployment.
        logger.error("Error in chat completion: %s", e, exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": "internal_error",
                }
            },
        )
|
|
|
|
|
|
async def stream_chat_completion(
    watsonx_model: str,
    watsonx_messages: list,
    request: ChatCompletionRequest,
    watsonx_tools: list = None,
):
    """Stream chat completion responses as server-sent events.

    Args:
        watsonx_model: The watsonx model ID.
        watsonx_messages: Messages already transformed to watsonx format.
        request: Original OpenAI request (source of sampling parameters).
        watsonx_tools: Tools already transformed to watsonx format, or None.

    Yields:
        SSE-formatted strings: one per upstream chunk, then "[DONE]", or a
        single OpenAI-style error event if the upstream stream fails.
    """
    # One request id shared by every chunk so clients can correlate them.
    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"

    try:
        async for chunk in watsonx_service.chat_completion_stream(
            model_id=watsonx_model,
            messages=watsonx_messages,
            # Explicit None checks: `or 1.0` would discard a legitimate
            # temperature/top_p of 0 sent by the client.
            temperature=request.temperature if request.temperature is not None else 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p if request.top_p is not None else 1.0,
            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
            tools=watsonx_tools,
        ):
            # Re-shape each upstream chunk into the OpenAI chunk schema.
            openai_chunk = transform_watsonx_to_openai_chat_chunk(
                chunk,
                request.model,
                request_id,
            )

            # Send as SSE.
            yield format_sse_event(openai_chunk.model_dump_json())

        # OpenAI protocol: terminate the stream with a literal [DONE] event.
        yield format_sse_event("[DONE]")

    except Exception as e:
        # Headers are already sent mid-stream, so we cannot raise an HTTP
        # error here; emit an error event in-band instead.
        logger.error("Error in streaming chat completion: %s", e, exc_info=True)
        error_response = ErrorResponse(
            error=ErrorDetail(
                message=str(e),
                type="internal_error",
                code="stream_error",
            )
        )
        yield format_sse_event(error_response.model_dump_json())
|