158 lines
4.6 KiB
Python
158 lines
4.6 KiB
Python
"""Main FastAPI application for watsonx-openai-proxy."""
|
|
|
|
import logging
|
|
from contextlib import asynccontextmanager
|
|
from fastapi import FastAPI, Request, status
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
from app.config import settings
|
|
from app.routers import chat, completions, embeddings, models
|
|
from app.services.watsonx_service import watsonx_service
|
|
|
|
# Root logging setup, performed once at import time. The level name is taken
# from settings and upper-cased so it resolves to the matching ``logging``
# module constant (e.g. "info" -> logging.INFO).
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=getattr(logging, settings.log_level.upper()),
)

# Module-level logger used by the rest of this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown.

    Startup:
        Logs the configured cluster and the first 8 characters of the
        project ID, then requests an initial bearer token through
        ``watsonx_service``. If that fails, the error is logged and
        re-raised so it propagates out of the lifespan context.

    Shutdown:
        Logs the shutdown and closes ``watsonx_service``. This runs in a
        ``finally`` clause so the service is also closed when an exception
        or cancellation is thrown into the generator at the ``yield``.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    # Startup
    logger.info("Starting watsonx-openai-proxy...")
    logger.info("Cluster: %s", settings.watsonx_cluster)
    # Only a short prefix of the project ID is written to the log.
    logger.info("Project ID: %s...", settings.watsonx_project_id[:8])

    # Obtain the initial token before serving.
    try:
        await watsonx_service._refresh_token()
    except Exception as e:
        logger.error("Failed to obtain initial bearer token: %s", e)
        raise
    else:
        logger.info("Initial bearer token obtained successfully")

    try:
        yield
    finally:
        # Shutdown: always release the service, even on abnormal exit.
        logger.info("Shutting down watsonx-openai-proxy...")
        await watsonx_service.close()
|
|
|
|
|
|
# The ASGI application object. Startup and shutdown are delegated to the
# ``lifespan`` context manager defined in this module.
app = FastAPI(
    lifespan=lifespan,
    title="watsonx-openai-proxy",
    version="1.0.0",
    description="OpenAI-compatible API proxy for IBM watsonx.ai",
)
|
|
|
|
# CORS: allowed origins come from settings; every method and header is
# accepted, and credentialed requests are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
|
|
|
|
|
def _auth_error(message: str, code: str) -> JSONResponse:
    """Build the 401 JSON error response shared by all authentication failures."""
    return JSONResponse(
        status_code=status.HTTP_401_UNAUTHORIZED,
        content={
            "error": {
                "message": message,
                "type": "authentication_error",
                "code": code,
            }
        },
    )


# Optional API key authentication middleware
@app.middleware("http")
async def authenticate(request: Request, call_next):
    """Authenticate requests if an API key is configured.

    When ``settings.api_key`` is falsy, every request is passed through
    unchanged. Otherwise every path except ``/health`` must carry an
    ``Authorization: Bearer <key>`` header whose key equals
    ``settings.api_key``.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request to the next handler.

    Returns:
        The downstream response on success, or a 401 ``JSONResponse`` with an
        ``error`` object (``message``, ``type``, ``code``) when the header is
        missing, is not a Bearer header, or carries the wrong key.
    """
    # Function-scope import keeps this block self-contained.
    import hmac

    # Authentication disabled, or the always-open health check.
    if not settings.api_key or request.url.path == "/health":
        return await call_next(request)

    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return _auth_error(
            "Missing Authorization header", "missing_authorization"
        )

    bearer_prefix = "Bearer "
    if not auth_header.startswith(bearer_prefix):
        return _auth_error(
            "Invalid Authorization header format",
            "invalid_authorization_format",
        )

    token = auth_header[len(bearer_prefix):]

    # Constant-time comparison so response timing does not leak how much of
    # the key matched. Both sides are encoded to bytes because
    # compare_digest rejects non-ASCII str arguments.
    key_matches = hmac.compare_digest(
        token.encode("utf-8"), settings.api_key.encode("utf-8")
    )
    if not key_matches:
        return _auth_error("Invalid API key", "invalid_api_key")

    return await call_next(request)
|
|
|
|
|
|
# Register the API routers in a fixed order, each under its own docs tag.
for _router_module, _tag in (
    (chat, "Chat"),
    (completions, "Completions"),
    (embeddings, "Embeddings"),
    (models, "Models"),
):
    app.include_router(_router_module.router, tags=[_tag])
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Return a static health payload with the service name and cluster."""
    payload = {
        "status": "healthy",
        "service": "watsonx-openai-proxy",
        "cluster": settings.watsonx_cluster,
    }
    return payload
|
|
|
|
|
|
@app.get("/")
async def root():
    """Describe the service: name, version, endpoint paths and docs location."""
    # Map of logical endpoint name -> URL path.
    endpoint_paths = {
        "chat": "/v1/chat/completions",
        "completions": "/v1/completions",
        "embeddings": "/v1/embeddings",
        "models": "/v1/models",
        "health": "/health",
    }
    service_info = {
        "service": "watsonx-openai-proxy",
        "description": "OpenAI-compatible API proxy for IBM watsonx.ai",
        "version": "1.0.0",
        "endpoints": endpoint_paths,
        "documentation": "/docs",
    }
    return service_info
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn using host and port
    # from settings. Auto-reload is disabled.
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host=settings.host,
        port=settings.port,
        # uvicorn's string log levels are lower-case names, while the logging
        # setup in this module upper-cases the same setting; normalise here
        # so a value such as "INFO" works for both.
        log_level=settings.log_level.lower(),
        reload=False,
    )
|