Files
watsonx-openai-proxy/app/main.py

158 lines
4.6 KiB
Python

"""Main FastAPI application for watsonx-openai-proxy."""
import logging
import secrets
from contextlib import asynccontextmanager

from fastapi import FastAPI, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from app.config import settings
from app.routers import chat, completions, embeddings, models
from app.services.watsonx_service import watsonx_service
# Logging setup: the configured level name is upper-cased and looked up as an
# attribute of the ``logging`` module before being handed to basicConfig.
_log_level = getattr(logging, settings.log_level.upper())
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=_log_level,
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown.

    Startup logs the target cluster and a truncated project ID, then obtains
    an initial bearer token through ``watsonx_service``. A failure there is
    logged and re-raised, so the application does not start. Shutdown closes
    ``watsonx_service``.

    Args:
        app: The FastAPI application this lifespan is attached to (not used
            by the body).

    Yields:
        None, for as long as the application is serving.

    Raises:
        Exception: Whatever ``watsonx_service._refresh_token()`` raised.
    """
    # Startup
    logger.info("Starting watsonx-openai-proxy...")
    logger.info("Cluster: %s", settings.watsonx_cluster)
    # Only the first 8 characters of the project ID are written to the log.
    logger.info("Project ID: %s...", settings.watsonx_project_id[:8])

    # Initialize token
    try:
        await watsonx_service._refresh_token()
    except Exception as e:
        logger.error("Failed to obtain initial bearer token: %s", e)
        raise
    logger.info("Initial bearer token obtained successfully")

    try:
        yield
    finally:
        # Shutdown. The ``finally`` guarantees the service is closed even when
        # an exception or cancellation is thrown into the generator at the
        # ``yield`` instead of it being resumed normally.
        logger.info("Shutting down watsonx-openai-proxy...")
        await watsonx_service.close()
# Application instance; startup and shutdown are delegated to ``lifespan``.
app = FastAPI(
    lifespan=lifespan,
    title="watsonx-openai-proxy",
    version="1.0.0",
    description="OpenAI-compatible API proxy for IBM watsonx.ai",
)

# CORS: allowed origins come from settings; every method and header is
# accepted and credentials are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=settings.cors_origins,
)
# Optional API key authentication middleware
def _auth_error(message: str, code: str) -> JSONResponse:
    """Build the 401 response used for every authentication failure.

    Args:
        message: Human-readable description placed in ``error.message``.
        code: Machine-readable identifier placed in ``error.code``.

    Returns:
        A ``JSONResponse`` with status 401 and an ``error`` object whose
        ``type`` is always ``"authentication_error"``.
    """
    return JSONResponse(
        status_code=status.HTTP_401_UNAUTHORIZED,
        content={
            "error": {
                "message": message,
                "type": "authentication_error",
                "code": code,
            }
        },
    )


@app.middleware("http")
async def authenticate(request: Request, call_next):
    """Authenticate requests if an API key is configured.

    When ``settings.api_key`` is falsy, every request passes through. When it
    is set, requests must carry ``Authorization: Bearer <api_key>``; the
    ``/health`` path and CORS preflight requests are exempt.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request to the rest of the
            application and returns its response.

    Returns:
        The downstream response, or a 401 JSON error response when the
        header is missing, is not a ``Bearer`` credential, or does not match
        the configured key.
    """
    # Authentication disabled, or the unauthenticated health check.
    if not settings.api_key or request.url.path == "/health":
        return await call_next(request)

    # This middleware is registered after CORSMiddleware and therefore wraps
    # it. Browsers never attach credentials to a preflight, so without this
    # pass-through every cross-origin call would fail with a 401 before the
    # CORS layer could answer the preflight.
    is_cors_preflight = (
        request.method == "OPTIONS"
        and "origin" in request.headers
        and "access-control-request-method" in request.headers
    )
    if is_cors_preflight:
        return await call_next(request)

    # Check Authorization header
    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return _auth_error("Missing Authorization header", "missing_authorization")

    # Validate API key
    bearer_prefix = "Bearer "
    if not auth_header.startswith(bearer_prefix):
        return _auth_error(
            "Invalid Authorization header format",
            "invalid_authorization_format",
        )

    token = auth_header[len(bearer_prefix):]
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Both sides are encoded to bytes because
    # compare_digest rejects non-ASCII ``str`` arguments.
    key_matches = secrets.compare_digest(
        token.encode("utf-8"),
        settings.api_key.encode("utf-8"),
    )
    if not key_matches:
        return _auth_error("Invalid API key", "invalid_api_key")

    return await call_next(request)
# Register each API router under its OpenAPI tag, in the original order.
for _router_module, _tag in (
    (chat, "Chat"),
    (completions, "Completions"),
    (embeddings, "Embeddings"),
    (models, "Models"),
):
    app.include_router(_router_module.router, tags=[_tag])
@app.get("/health")
async def health_check():
    """Return a static health payload plus the configured watsonx cluster."""
    payload = {"status": "healthy", "service": "watsonx-openai-proxy"}
    payload["cluster"] = settings.watsonx_cluster
    return payload
@app.get("/")
async def root():
    """Describe the service: name, version, endpoint paths and docs location."""
    # Map of logical endpoint name -> URL path exposed by this proxy.
    endpoint_paths = {
        "chat": "/v1/chat/completions",
        "completions": "/v1/completions",
        "embeddings": "/v1/embeddings",
        "models": "/v1/models",
        "health": "/health",
    }
    service_info = {
        "service": "watsonx-openai-proxy",
        "description": "OpenAI-compatible API proxy for IBM watsonx.ai",
        "version": "1.0.0",
        "endpoints": endpoint_paths,
        "documentation": "/docs",
    }
    return service_info
# Script entry point: serve the app with uvicorn using host/port from settings.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        # uvicorn looks string levels up by their lower-case names, while the
        # logging setup in this module accepts the setting in any case;
        # normalise so a value such as "INFO" works for both.
        log_level=settings.log_level.lower(),
        reload=False,
    )