115 lines
3.5 KiB
Python
115 lines
3.5 KiB
Python
"""Embeddings endpoint router."""
|
|
|
|
from typing import Union
|
|
from fastapi import APIRouter, HTTPException, Request
|
|
from app.models.openai_models import (
|
|
EmbeddingRequest,
|
|
EmbeddingResponse,
|
|
ErrorResponse,
|
|
ErrorDetail,
|
|
)
|
|
from app.services.watsonx_service import watsonx_service
|
|
from app.utils.transformers import transform_watsonx_to_openai_embeddings
|
|
from app.config import settings
|
|
import logging
|
|
|
|
# Router object on which the embeddings route in this module is registered.
router: APIRouter = APIRouter()

# Logger scoped to this module's dotted import path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def _invalid_request(message: str, code: str) -> HTTPException:
    """Build a 400 HTTPException carrying an OpenAI-style error payload.

    Args:
        message: Human-readable description placed in ``error.message``.
        code: Machine-readable identifier placed in ``error.code``.

    Returns:
        An ``HTTPException`` (status 400, type ``invalid_request_error``)
        for the caller to raise.
    """
    return HTTPException(
        status_code=400,
        detail={
            "error": {
                "message": message,
                "type": "invalid_request_error",
                "code": code,
            }
        },
    )


@router.post(
    "/v1/embeddings",
    response_model=Union[EmbeddingResponse, ErrorResponse],
    responses={
        200: {"model": EmbeddingResponse},
        400: {"model": ErrorResponse},
        401: {"model": ErrorResponse},
        500: {"model": ErrorResponse},
    },
)
async def create_embeddings(
    request: EmbeddingRequest,
    http_request: Request,
):
    """Create embeddings using OpenAI-compatible API.

    This endpoint accepts OpenAI-formatted embedding requests and translates
    them to watsonx.ai embeddings API calls.

    Args:
        request: Parsed OpenAI-style embedding request. ``request.input`` must
            be a single string or a non-empty list made up only of strings.
        http_request: Raw FastAPI request object. It is not used by this
            handler's body.

    Returns:
        The value produced by ``transform_watsonx_to_openai_embeddings`` for
        the watsonx response, labelled with the caller's original model name.

    Raises:
        HTTPException: 400 for an empty list, a list with no strings (token
            IDs, unsupported), a list mixing strings with other values, or a
            non-str/non-list input. 500 for any other failure while mapping
            the model, calling watsonx, or transforming the response.
    """
    try:
        # Translate the OpenAI-facing model name to the watsonx model id.
        watsonx_model = settings.map_model(request.model)
        logger.info("Embeddings request: %s -> %s", request.model, watsonx_model)

        # Input may be a single string or a list.
        if isinstance(request.input, str):
            inputs = [request.input]
        elif isinstance(request.input, list):
            if not request.input:
                raise _invalid_request("Input cannot be empty", "invalid_input")

            # Check every element, not just the first. Otherwise a list
            # such as ["a", 1] would slip through as "all strings".
            is_str = [isinstance(item, str) for item in request.input]
            if all(is_str):
                inputs = request.input
            elif not any(is_str):
                # Token ID input (list of ints / list of int lists) is not
                # supported yet.
                raise _invalid_request(
                    "Token ID input not supported", "unsupported_input_type"
                )
            else:
                # Strings mixed with non-string values.
                raise _invalid_request("Invalid input type", "invalid_input_type")
        else:
            raise _invalid_request("Invalid input type", "invalid_input_type")

        # Call watsonx embeddings.
        watsonx_response = await watsonx_service.embeddings(
            model_id=watsonx_model,
            inputs=inputs,
        )

        # Transform the response back into the OpenAI shape, reporting the
        # model name the caller asked for.
        return transform_watsonx_to_openai_embeddings(
            watsonx_response,
            request.model,
        )

    except HTTPException:
        # Deliberate client-facing errors propagate unchanged.
        raise
    except Exception as e:
        # logger.exception logs at ERROR level and includes the traceback.
        logger.exception("Error in embeddings: %s", e)
        # NOTE(review): str(e) exposes internal exception text to API
        # clients; consider a generic message once callers no longer rely
        # on it.
        raise HTTPException(
            status_code=500,
            detail={
                "error": {
                    "message": str(e),
                    "type": "internal_error",
                    "code": "internal_error",
                }
            },
        ) from e
|