92 lines
3.0 KiB
Python
92 lines
3.0 KiB
Python
"""Configuration management for watsonx-openai-proxy."""
|
|
|
|
import os
|
|
from typing import Dict, Optional
|
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
|
|
|
|
class Settings(BaseSettings):
    """Application settings loaded from environment variables and ``.env``.

    Field names are matched case-insensitively (``case_sensitive=False``), so
    e.g. ``IBM_CLOUD_API_KEY`` populates ``ibm_cloud_api_key``. Unknown keys
    are retained (``extra="allow"``) so that ``MODEL_MAP_*`` entries from the
    ``.env`` file are available to ``get_model_mapping``.
    """

    # IBM Cloud Configuration
    ibm_cloud_api_key: str  # required (no default)
    watsonx_project_id: str  # required (no default)
    watsonx_cluster: str = "us-south"  # region prefix used by watsonx_base_url

    # Server Configuration
    host: str = "0.0.0.0"
    port: int = 8000
    log_level: str = "info"

    # API Configuration
    api_key: Optional[str] = None
    allowed_origins: str = "*"  # "*" or a comma-separated list of origins

    # Token Management
    token_refresh_interval: int = 3000  # 50 minutes in seconds

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="allow",  # Allow extra fields for model mapping
    )

    @property
    def watsonx_base_url(self) -> str:
        """Return the watsonx.ai REST base URL for the configured cluster."""
        return f"https://{self.watsonx_cluster}.ml.cloud.ibm.com/ml/v1"

    @property
    def cors_origins(self) -> list[str]:
        """Parse ``allowed_origins`` into a list of CORS origins.

        ``"*"`` yields ``["*"]``. Otherwise the value is split on commas, each
        entry is stripped, and blank entries (e.g. from a trailing comma or
        ``",,"``) are dropped rather than emitted as empty-string origins.
        """
        raw = self.allowed_origins.strip()
        if raw == "*":
            return ["*"]
        return [origin for origin in (part.strip() for part in raw.split(",")) if origin]

    @staticmethod
    def _model_alias_from_key(key: str) -> Optional[str]:
        """Convert a ``MODEL_MAP_*`` variable name into an OpenAI-style alias.

        The prefix is matched case-insensitively, consistent with
        ``case_sensitive=False``. Returns ``None`` when *key* is not a
        model-mapping variable or carries no name after the prefix.

        Examples: ``MODEL_MAP_GPT4`` -> ``"gpt-4"``;
        ``model_map_gpt3_5_turbo`` -> ``"gpt-3-5-turbo"``.
        """
        import re

        prefix = "MODEL_MAP_"
        upper_key = key.upper()
        if not upper_key.startswith(prefix):
            return None
        # Slice off only the leading prefix (str.replace would also strip any
        # later occurrence inside the name).
        name = upper_key[len(prefix):]
        if not name:
            return None
        # Insert a hyphen between a letter run and the digits that follow it,
        # e.g. GPT4 -> GPT-4, so the alias matches OpenAI naming.
        name = re.sub(r"([A-Z]+)(\d+)", r"\1-\2", name)
        return name.lower().replace("_", "-")

    def get_model_mapping(self) -> Dict[str, str]:
        """Collect ``MODEL_MAP_*`` aliases into ``{openai_alias: watsonx_model_id}``.

        A variable such as ``MODEL_MAP_GPT4=ibm/granite-4-h-small`` yields
        ``{"gpt-4": "ibm/granite-4-h-small"}``. Variable names are matched
        case-insensitively. Two sources are read on every call:

        * extra fields pydantic collected from the ``.env`` file
          (these arrive lowercased, e.g. ``model_map_gpt4``), and
        * the current process environment (``os.environ``).

        When both define the same alias, the process environment wins. This
        matches pydantic-settings' precedence, where environment variables
        override dotenv values. Entries with an empty value are ignored.
        """
        mapping: Dict[str, str] = {}
        extra = getattr(self, "__pydantic_extra__", None) or {}

        # Apply the lowest-precedence source first so later sources overwrite it.
        for source in (extra, os.environ):
            for key, value in source.items():
                alias = self._model_alias_from_key(key)
                if alias and value:
                    mapping[alias] = value

        return mapping

    def map_model(self, openai_model: str) -> str:
        """Map an OpenAI model name to a watsonx model ID.

        Args:
            openai_model: OpenAI model name (e.g., "gpt-4", "gpt-3.5-turbo")

        Returns:
            Corresponding watsonx model ID, or the original name if no mapping exists
        """
        # The mapping is rebuilt on each call, so environment changes are
        # picked up without restarting.
        model_map = self.get_model_mapping()
        return model_map.get(openai_model, openai_model)
|
|
|
|
|
|
# Global settings instance, built once when this module is imported.
# Settings() reads the process environment and the ".env" file; because
# ibm_cloud_api_key and watsonx_project_id have no defaults, the import fails
# with a pydantic validation error if either one cannot be resolved.
settings = Settings()
|