"""Service for interacting with IBM watsonx.ai APIs.""" import asyncio import time from typing import AsyncIterator, Dict, List, Optional import httpx from app.config import settings import logging logger = logging.getLogger(__name__) class WatsonxService: """Service for managing watsonx.ai API interactions.""" def __init__(self): self.base_url = settings.watsonx_base_url self.project_id = settings.watsonx_project_id self.api_key = settings.ibm_cloud_api_key self._bearer_token: Optional[str] = None self._token_expiry: Optional[float] = None self._token_lock = asyncio.Lock() self._client: Optional[httpx.AsyncClient] = None async def _get_client(self) -> httpx.AsyncClient: """Get or create HTTP client.""" if self._client is None: self._client = httpx.AsyncClient(timeout=300.0) return self._client async def close(self): """Close the HTTP client.""" if self._client: await self._client.aclose() self._client = None async def _refresh_token(self) -> str: """Get a fresh bearer token from IBM Cloud IAM.""" async with self._token_lock: # Check if token is still valid if self._bearer_token and self._token_expiry: if time.time() < self._token_expiry - 300: # 5 min buffer return self._bearer_token logger.info("Refreshing IBM Cloud bearer token...") client = await self._get_client() response = await client.post( "https://iam.cloud.ibm.com/identity/token", headers={"Content-Type": "application/x-www-form-urlencoded"}, data=f"grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey={self.api_key}", ) if response.status_code != 200: raise Exception(f"Failed to get bearer token: {response.text}") data = response.json() self._bearer_token = data["access_token"] self._token_expiry = time.time() + data.get("expires_in", 3600) logger.info(f"Bearer token refreshed. Expires in {data.get('expires_in', 3600)} seconds") return self._bearer_token async def _get_headers(self) -> Dict[str, str]: """Get headers with valid bearer token.""" token = await self._refresh_token() return { "Authorization": f"Bearer {token}", "Content-Type": "application/json", "Accept": "application/json", } async def chat_completion( self, model_id: str, messages: List[Dict], temperature: float = 1.0, max_tokens: Optional[int] = None, top_p: float = 1.0, stop: Optional[List[str]] = None, stream: bool = False, tools: Optional[List[Dict]] = None, **kwargs, ) -> Dict: """Create a chat completion using watsonx.ai. Args: model_id: The watsonx model ID messages: List of chat messages temperature: Sampling temperature max_tokens: Maximum tokens to generate top_p: Nucleus sampling parameter stop: Stop sequences stream: Whether to stream the response tools: Tool/function definitions **kwargs: Additional parameters Returns: Chat completion response """ headers = await self._get_headers() client = await self._get_client() # Build watsonx request payload = { "model_id": model_id, "project_id": self.project_id, "messages": messages, "parameters": { "temperature": temperature, "top_p": top_p, }, } if max_tokens: payload["parameters"]["max_tokens"] = max_tokens if stop: payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop] if tools: payload["tools"] = tools url = f"{self.base_url}/text/chat" params = {"version": "2024-02-13"} if stream: params["stream"] = "true" response = await client.post( url, headers=headers, json=payload, params=params, ) if response.status_code != 200: raise Exception(f"watsonx API error: {response.status_code} - {response.text}") return response.json() async def chat_completion_stream( self, model_id: str, messages: List[Dict], temperature: float = 1.0, max_tokens: Optional[int] = None, top_p: float = 1.0, stop: Optional[List[str]] = None, tools: Optional[List[Dict]] = None, **kwargs, ) -> AsyncIterator[Dict]: """Stream a chat completion using watsonx.ai. Args: model_id: The watsonx model ID messages: List of chat messages temperature: Sampling temperature max_tokens: Maximum tokens to generate top_p: Nucleus sampling parameter stop: Stop sequences tools: Tool/function definitions **kwargs: Additional parameters Yields: Chat completion chunks """ headers = await self._get_headers() client = await self._get_client() # Build watsonx request payload = { "model_id": model_id, "project_id": self.project_id, "messages": messages, "parameters": { "temperature": temperature, "top_p": top_p, }, } if max_tokens: payload["parameters"]["max_tokens"] = max_tokens if stop: payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop] if tools: payload["tools"] = tools url = f"{self.base_url}/text/chat_stream" params = {"version": "2024-02-13"} async with client.stream( "POST", url, headers=headers, json=payload, params=params, ) as response: if response.status_code != 200: text = await response.aread() raise Exception(f"watsonx API error: {response.status_code} - {text.decode()}") async for line in response.aiter_lines(): if line.startswith("data: "): data = line[6:] # Remove "data: " prefix if data.strip() == "[DONE]": break try: import json yield json.loads(data) except json.JSONDecodeError: continue async def text_generation( self, model_id: str, prompt: str, temperature: float = 1.0, max_tokens: Optional[int] = None, top_p: float = 1.0, stop: Optional[List[str]] = None, **kwargs, ) -> Dict: """Generate text completion using watsonx.ai. Args: model_id: The watsonx model ID prompt: The input prompt temperature: Sampling temperature max_tokens: Maximum tokens to generate top_p: Nucleus sampling parameter stop: Stop sequences **kwargs: Additional parameters Returns: Text generation response """ headers = await self._get_headers() client = await self._get_client() payload = { "model_id": model_id, "project_id": self.project_id, "input": prompt, "parameters": { "temperature": temperature, "top_p": top_p, }, } if max_tokens: payload["parameters"]["max_new_tokens"] = max_tokens if stop: payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop] url = f"{self.base_url}/text/generation" params = {"version": "2024-02-13"} response = await client.post( url, headers=headers, json=payload, params=params, ) if response.status_code != 200: raise Exception(f"watsonx API error: {response.status_code} - {response.text}") return response.json() async def embeddings( self, model_id: str, inputs: List[str], **kwargs, ) -> Dict: """Generate embeddings using watsonx.ai. Args: model_id: The watsonx embedding model ID inputs: List of texts to embed **kwargs: Additional parameters Returns: Embeddings response """ headers = await self._get_headers() client = await self._get_client() payload = { "model_id": model_id, "project_id": self.project_id, "inputs": inputs, } url = f"{self.base_url}/text/embeddings" params = {"version": "2024-02-13"} response = await client.post( url, headers=headers, json=payload, params=params, ) if response.status_code != 200: raise Exception(f"watsonx API error: {response.status_code} - {response.text}") return response.json() # Global service instance watsonx_service = WatsonxService()