watsonx requires `tool_call_id` on role=tool messages. Added the field to the ChatMessage model and a passthrough in the transformer.
276 lines · 7.3 KiB · Python
"""Utilities for transforming between OpenAI and watsonx formats."""
|
|
|
|
import time
|
|
import uuid
|
|
from typing import Any, Dict, List, Optional
|
|
from app.models.openai_models import (
|
|
ChatMessage,
|
|
ChatCompletionChoice,
|
|
ChatCompletionUsage,
|
|
ChatCompletionResponse,
|
|
ChatCompletionChunk,
|
|
ChatCompletionChunkChoice,
|
|
ChatCompletionChunkDelta,
|
|
CompletionChoice,
|
|
CompletionResponse,
|
|
EmbeddingData,
|
|
EmbeddingUsage,
|
|
EmbeddingResponse,
|
|
)
|
|
|
|
|
|
def transform_messages_to_watsonx(messages: List["ChatMessage"]) -> List[Dict[str, Any]]:
    """Transform OpenAI messages to watsonx format.

    Only fields that are actually set on a message are emitted, keeping the
    payload minimal. ``content`` is checked against ``None`` rather than
    truthiness so that an explicitly empty string (legal, e.g., for a
    role="tool" result) is preserved instead of being silently dropped.

    Args:
        messages: List of OpenAI ChatMessage objects

    Returns:
        List of watsonx-compatible message dicts
    """
    watsonx_messages: List[Dict[str, Any]] = []

    for msg in messages:
        watsonx_msg: Dict[str, Any] = {"role": msg.role}

        # Empty-string content is meaningful; only skip when unset.
        if msg.content is not None:
            watsonx_msg["content"] = msg.content

        if msg.name:
            watsonx_msg["name"] = msg.name

        if msg.tool_calls:
            watsonx_msg["tool_calls"] = msg.tool_calls

        if msg.function_call:
            watsonx_msg["function_call"] = msg.function_call

        # watsonx requires tool_call_id on role=tool messages.
        if msg.tool_call_id:
            watsonx_msg["tool_call_id"] = msg.tool_call_id

        watsonx_messages.append(watsonx_msg)

    return watsonx_messages
def transform_tools_to_watsonx(tools: Optional[List[Dict]]) -> Optional[List[Dict]]:
    """Transform OpenAI tools to watsonx format.

    Args:
        tools: List of OpenAI tool definitions

    Returns:
        List of watsonx-compatible tool definitions, or None when no
        tools were supplied
    """
    # watsonx accepts the OpenAI tool schema as-is, so an empty or
    # missing list maps to None and anything else passes straight through.
    return tools if tools else None
def transform_watsonx_to_openai_chat(
    watsonx_response: Dict[str, Any],
    model: str,
    request_id: Optional[str] = None,
) -> ChatCompletionResponse:
    """Transform watsonx chat response to OpenAI format.

    Args:
        watsonx_response: Response from watsonx chat API
        model: Model name to include in response
        request_id: Optional request ID, generates one if not provided

    Returns:
        OpenAI-compatible ChatCompletionResponse
    """
    choices = []
    for idx, raw_choice in enumerate(watsonx_response.get("choices", [])):
        raw_message = raw_choice.get("message", {})
        choices.append(
            ChatCompletionChoice(
                index=idx,
                message=ChatMessage(
                    # Missing role defaults to assistant, the only role
                    # the model itself produces.
                    role=raw_message.get("role", "assistant"),
                    content=raw_message.get("content"),
                    tool_calls=raw_message.get("tool_calls"),
                    function_call=raw_message.get("function_call"),
                ),
                finish_reason=raw_choice.get("finish_reason"),
            )
        )

    raw_usage = watsonx_response.get("usage", {})

    return ChatCompletionResponse(
        id=request_id or f"chatcmpl-{uuid.uuid4().hex[:24]}",
        created=int(time.time()),
        model=model,
        choices=choices,
        usage=ChatCompletionUsage(
            prompt_tokens=raw_usage.get("prompt_tokens", 0),
            completion_tokens=raw_usage.get("completion_tokens", 0),
            total_tokens=raw_usage.get("total_tokens", 0),
        ),
    )
def transform_watsonx_to_openai_chat_chunk(
    watsonx_chunk: Dict[str, Any],
    model: str,
    request_id: str,
) -> ChatCompletionChunk:
    """Transform watsonx streaming chunk to OpenAI format.

    Args:
        watsonx_chunk: Streaming chunk from watsonx
        model: Model name to include in response
        request_id: Request ID for this stream (shared by every chunk)

    Returns:
        OpenAI-compatible ChatCompletionChunk
    """
    choices = []
    for idx, raw_choice in enumerate(watsonx_chunk.get("choices", [])):
        raw_delta = raw_choice.get("delta", {})
        choices.append(
            ChatCompletionChunkChoice(
                index=idx,
                delta=ChatCompletionChunkDelta(
                    role=raw_delta.get("role"),
                    content=raw_delta.get("content"),
                    tool_calls=raw_delta.get("tool_calls"),
                    function_call=raw_delta.get("function_call"),
                ),
                finish_reason=raw_choice.get("finish_reason"),
            )
        )

    return ChatCompletionChunk(
        id=request_id,
        created=int(time.time()),
        model=model,
        choices=choices,
    )
# watsonx text-generation stop reasons -> OpenAI finish_reason vocabulary.
# Unknown reasons pass through unchanged so no information is lost for
# callers that inspect the raw value.
_WATSONX_STOP_REASON_MAP = {
    "eos_token": "stop",
    "stop_sequence": "stop",
    "max_tokens": "length",
    "token_limit": "length",
    "not_finished": None,
}


def transform_watsonx_to_openai_completion(
    watsonx_response: Dict[str, Any],
    model: str,
    request_id: Optional[str] = None,
) -> CompletionResponse:
    """Transform watsonx text generation response to OpenAI completion format.

    Args:
        watsonx_response: Response from watsonx text generation API
        model: Model name to include in response
        request_id: Optional request ID, generates one if not provided

    Returns:
        OpenAI-compatible CompletionResponse
    """
    response_id = request_id or f"cmpl-{uuid.uuid4().hex[:24]}"
    created = int(time.time())

    choices = []
    for idx, result in enumerate(watsonx_response.get("results", [])):
        stop_reason = result.get("stop_reason")
        choices.append(
            CompletionChoice(
                text=result.get("generated_text", ""),
                index=idx,
                # Translate watsonx stop reasons (e.g. "eos_token",
                # "max_tokens") into OpenAI's "stop"/"length" values;
                # unrecognized reasons are forwarded as-is.
                finish_reason=_WATSONX_STOP_REASON_MAP.get(stop_reason, stop_reason),
            )
        )

    # The text-generation API reports "generated_tokens" rather than
    # "completion_tokens" and does not supply a total, so compute it.
    # NOTE(review): assumes a response-level "usage" dict — confirm
    # against the watsonx API response shape.
    usage_data = watsonx_response.get("usage", {})
    prompt_tokens = usage_data.get("prompt_tokens", 0)
    completion_tokens = usage_data.get("generated_tokens", 0)
    usage = ChatCompletionUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )

    return CompletionResponse(
        id=response_id,
        created=created,
        model=model,
        choices=choices,
        usage=usage,
    )
def transform_watsonx_to_openai_embeddings(
    watsonx_response: Dict[str, Any],
    model: str,
) -> EmbeddingResponse:
    """Transform watsonx embeddings response to OpenAI format.

    Args:
        watsonx_response: Response from watsonx embeddings API
        model: Model name to include in response

    Returns:
        OpenAI-compatible EmbeddingResponse
    """
    data = [
        EmbeddingData(embedding=result.get("embedding", []), index=idx)
        for idx, result in enumerate(watsonx_response.get("results", []))
    ]

    # watsonx reports only input tokens for embeddings, so prompt and
    # total counts are the same value.
    token_count = watsonx_response.get("input_token_count", 0)

    return EmbeddingResponse(
        data=data,
        model=model,
        usage=EmbeddingUsage(
            prompt_tokens=token_count,
            total_tokens=token_count,
        ),
    )
def format_sse_event(data: str) -> str:
    """Format data as Server-Sent Event.

    Args:
        data: JSON string to send

    Returns:
        The payload wrapped in SSE ``data:`` framing, terminated by the
        blank line that marks the end of the event
    """
    return "data: " + data + "\n\n"