Files
Michael d924b7c45f Add vLLM message normalization for OpenAI content format compatibility
- Normalize 'developer' role to 'system' (vLLM doesn't support developer role)
- Flatten array content to string for text-only messages
- Preserve mixed content (text + images) as array
- Add comprehensive unit tests for normalization logic

Fixes HTTP 422 errors when clients send OpenAI multi-content format
2026-02-23 11:59:23 -05:00

204 lines
6.8 KiB
Python

"""Chat completions endpoint router."""
import json
import uuid
from typing import Union
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from app.models.openai_models import (
ChatCompletionRequest,
ChatCompletionResponse,
ErrorResponse,
ErrorDetail,
)
from app.services.watsonx_service import watsonx_service
from app.utils.transformers import (
transform_messages_to_watsonx,
transform_tools_to_watsonx,
transform_watsonx_to_openai_chat,
transform_watsonx_to_openai_chat_chunk,
format_sse_event,
)
from app.config import settings
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
def normalize_messages_for_vllm(messages: list) -> list:
    """Rewrite OpenAI-style messages so that vLLM will accept them.

    vLLM enforces a stricter schema than the OpenAI API:
    1. Message content must be a plain string, not an array of content parts.
    2. The role must be one of system/user/assistant/function/tool; the
       newer "developer" role is rejected.

    Args:
        messages: List of message dicts.

    Returns:
        Normalized list of messages.
    """
    result = []
    for original in messages:
        message = original.copy()

        # vLLM does not know the "developer" role; map it onto "system".
        if message.get("role") == "developer":
            message["role"] = "system"
            logger.debug("Normalized 'developer' role to 'system'")

        parts = message.get("content")
        if isinstance(parts, list):
            text_only = all(
                isinstance(part, dict) and part.get("type") == "text"
                for part in parts
            )
            if text_only:
                # Collapse the parts into one newline-joined string.
                message["content"] = "\n".join(part.get("text", "") for part in parts)
                logger.debug(f"Normalized array content to string: {len(parts)} parts")
            else:
                # Non-text parts (e.g. image_url) present: leave the array
                # untouched. vLLM may reject it, but the input is preserved.
                logger.warning("Message contains non-text content parts, keeping array format")

        result.append(message)
    return result
@router.post(
    "/v1/chat/completions",
    response_model=Union[ChatCompletionResponse, ErrorResponse],
    responses={
        200: {"model": ChatCompletionResponse},
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def create_chat_completion(
    request: ChatCompletionRequest,
    http_request: Request,
):
    """Create a chat completion using OpenAI-compatible API.

    This endpoint accepts OpenAI-formatted requests and translates them
    to watsonx.ai API calls.

    Args:
        request: Parsed OpenAI-style chat completion request body.
        http_request: Raw FastAPI request object (currently unused in the
            body; kept for signature parity with other routes).

    Returns:
        ChatCompletionResponse for non-streaming requests, or a
        StreamingResponse of SSE chunks when ``request.stream`` is set.

    Raises:
        HTTPException: 500 with an OpenAI-style error payload on failure.
    """
    try:
        # Map model name if needed
        watsonx_model = settings.map_model(request.model)
        logger.info(f"Chat completion request: {request.model} -> {watsonx_model}")
        # Normalize messages for vLLM compatibility (handles array content and developer role)
        normalized_messages = normalize_messages_for_vllm(
            [msg.model_dump() for msg in request.messages]
        )
        # Transform normalized messages to watsonx format
        # Convert back to ChatMessage objects for the transformer
        from app.models.openai_models import ChatMessage
        normalized_chat_messages = [ChatMessage(**msg) for msg in normalized_messages]
        watsonx_messages = transform_messages_to_watsonx(normalized_chat_messages)
        # Transform tools if present
        watsonx_tools = transform_tools_to_watsonx(request.tools)
        # Handle streaming
        if request.stream:
            return StreamingResponse(
                stream_chat_completion(
                    watsonx_model,
                    watsonx_messages,
                    request,
                    watsonx_tools,
                ),
                media_type="text/event-stream",
            )
        # Non-streaming response.
        # Use explicit None checks (not `x or 1.0`): `or` would silently
        # replace the legitimate value 0.0 (e.g. temperature=0 for greedy
        # decoding) with the default.
        watsonx_response = await watsonx_service.chat_completion(
            model_id=watsonx_model,
            messages=watsonx_messages,
            temperature=request.temperature if request.temperature is not None else 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p if request.top_p is not None else 1.0,
            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
            tools=watsonx_tools,
        )
        # Transform response back to the OpenAI shape, echoing the client's
        # original model name rather than the mapped watsonx ID.
        openai_response = transform_watsonx_to_openai_chat(
            watsonx_response,
            request.model,
        )
        return openai_response
    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": "internal_error",
                }
            },
        )
async def stream_chat_completion(
    watsonx_model: str,
    watsonx_messages: list,
    request: ChatCompletionRequest,
    watsonx_tools: list = None,
):
    """Stream chat completion responses.

    Args:
        watsonx_model: The watsonx model ID
        watsonx_messages: Transformed messages
        request: Original OpenAI request
        watsonx_tools: Transformed tools

    Yields:
        Server-Sent Events with chat completion chunks; on error, a single
        SSE event carrying an OpenAI-style error payload.
    """
    # One request ID shared by every chunk of this stream, OpenAI-style.
    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
    try:
        # Use explicit None checks (not `x or 1.0`): `or` would silently
        # replace the legitimate value 0.0 (e.g. temperature=0 for greedy
        # decoding) with the default.
        async for chunk in watsonx_service.chat_completion_stream(
            model_id=watsonx_model,
            messages=watsonx_messages,
            temperature=request.temperature if request.temperature is not None else 1.0,
            max_tokens=request.max_tokens,
            top_p=request.top_p if request.top_p is not None else 1.0,
            stop=request.stop if isinstance(request.stop, list) else [request.stop] if request.stop else None,
            tools=watsonx_tools,
        ):
            # Transform chunk to OpenAI format
            openai_chunk = transform_watsonx_to_openai_chat_chunk(
                chunk,
                request.model,
                request_id,
            )
            # Send as SSE
            yield format_sse_event(openai_chunk.model_dump_json())
        # Send [DONE] message
        yield format_sse_event("[DONE]")
    except Exception as e:
        # Errors mid-stream cannot change the HTTP status (headers already
        # sent), so emit the error as a final SSE event instead.
        logger.error(f"Error in streaming chat completion: {str(e)}", exc_info=True)
        error_response = ErrorResponse(
            error=ErrorDetail(
                message=str(e),
                type="internal_error",
                code="stream_error",
            )
        )
        yield format_sse_event(error_response.model_dump_json())