# Source: watsonx-openai-proxy/app/services/watsonx_service.py
"""Service for interacting with IBM watsonx.ai APIs."""
import asyncio
import json
import logging
import time
from typing import AsyncIterator, Dict, List, Optional

import httpx

from app.config import settings
logger = logging.getLogger(__name__)
class WatsonxService:
    """Service for managing watsonx.ai API interactions.

    Owns the IBM Cloud IAM bearer-token lifecycle (cached, refreshed under a
    lock before expiry), a lazily created shared ``httpx.AsyncClient``, and
    thin async wrappers around the watsonx.ai chat, text-generation, and
    embeddings endpoints.
    """

    # Version query parameter sent with every watsonx.ai request.
    API_VERSION = "2024-02-13"
    # Refresh the IAM token this many seconds before its reported expiry.
    TOKEN_REFRESH_BUFFER_SECONDS = 300
    # Generation calls can run for minutes; applies to the whole client.
    REQUEST_TIMEOUT_SECONDS = 300.0

    def __init__(self):
        self.base_url = settings.watsonx_base_url
        self.project_id = settings.watsonx_project_id
        self.api_key = settings.ibm_cloud_api_key
        self._bearer_token: Optional[str] = None
        self._token_expiry: Optional[float] = None  # epoch seconds
        self._token_lock = asyncio.Lock()
        self._client: Optional[httpx.AsyncClient] = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Get or lazily create the shared HTTP client."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=self.REQUEST_TIMEOUT_SECONDS)
        return self._client

    async def close(self):
        """Close the HTTP client. Safe to call when no client exists."""
        if self._client:
            await self._client.aclose()
            self._client = None

    async def _refresh_token(self) -> str:
        """Return a valid bearer token, refreshing from IBM Cloud IAM if needed.

        Returns:
            The current (possibly newly fetched) bearer token.

        Raises:
            Exception: If the IAM token endpoint returns a non-200 response.
        """
        async with self._token_lock:
            # Reuse the cached token while it is comfortably inside its
            # lifetime (refresh-buffer seconds of slack).
            if self._bearer_token and self._token_expiry:
                if time.time() < self._token_expiry - self.TOKEN_REFRESH_BUFFER_SECONDS:
                    return self._bearer_token
            logger.info("Refreshing IBM Cloud bearer token...")
            client = await self._get_client()
            # Pass form fields as a dict so httpx URL-encodes them. Building
            # the body with an f-string breaks for API keys containing
            # reserved characters such as '&', '=', or '%'.
            response = await client.post(
                "https://iam.cloud.ibm.com/identity/token",
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                data={
                    "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
                    "apikey": self.api_key,
                },
            )
            if response.status_code != 200:
                raise Exception(f"Failed to get bearer token: {response.text}")
            data = response.json()
            self._bearer_token = data["access_token"]
            expires_in = data.get("expires_in", 3600)
            self._token_expiry = time.time() + expires_in
            logger.info("Bearer token refreshed. Expires in %s seconds", expires_in)
            return self._bearer_token

    async def _get_headers(self) -> Dict[str, str]:
        """Get request headers carrying a valid bearer token."""
        token = await self._refresh_token()
        return {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

    def _build_chat_payload(
        self,
        model_id: str,
        messages: List[Dict],
        temperature: float,
        top_p: float,
        max_tokens: Optional[int],
        stop: Optional[List[str]],
        tools: Optional[List[Dict]],
    ) -> Dict:
        """Build the request body shared by streaming and non-streaming chat."""
        payload: Dict = {
            "model_id": model_id,
            "project_id": self.project_id,
            "messages": messages,
            "parameters": {
                "temperature": temperature,
                "top_p": top_p,
            },
        }
        # Explicit None check: a caller-supplied 0 should not be dropped
        # silently the way a truthiness test would drop it.
        if max_tokens is not None:
            payload["parameters"]["max_tokens"] = max_tokens
        if stop:
            # Tolerate a single string even though the annotation says list.
            payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]
        if tools:
            payload["tools"] = tools
        return payload

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Dict],
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        stream: bool = False,
        tools: Optional[List[Dict]] = None,
        **kwargs,
    ) -> Dict:
        """Create a chat completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID.
            messages: List of chat messages.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.
            top_p: Nucleus sampling parameter.
            stop: Stop sequences.
            stream: Whether to request a streamed response.
            tools: Tool/function definitions.
            **kwargs: Additional parameters (currently ignored).

        Returns:
            Chat completion response as a parsed JSON dict.

        Raises:
            Exception: If the watsonx API returns a non-200 response.
        """
        headers = await self._get_headers()
        client = await self._get_client()
        payload = self._build_chat_payload(
            model_id, messages, temperature, top_p, max_tokens, stop, tools
        )
        url = f"{self.base_url}/text/chat"
        params = {"version": self.API_VERSION}
        if stream:
            params["stream"] = "true"
        response = await client.post(
            url,
            headers=headers,
            json=payload,
            params=params,
        )
        if response.status_code != 200:
            raise Exception(f"watsonx API error: {response.status_code} - {response.text}")
        return response.json()

    async def chat_completion_stream(
        self,
        model_id: str,
        messages: List[Dict],
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        tools: Optional[List[Dict]] = None,
        **kwargs,
    ) -> AsyncIterator[Dict]:
        """Stream a chat completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID.
            messages: List of chat messages.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.
            top_p: Nucleus sampling parameter.
            stop: Stop sequences.
            tools: Tool/function definitions.
            **kwargs: Additional parameters (currently ignored).

        Yields:
            Parsed chat completion chunks (SSE ``data:`` payloads).

        Raises:
            Exception: If the watsonx API returns a non-200 response.
        """
        headers = await self._get_headers()
        client = await self._get_client()
        payload = self._build_chat_payload(
            model_id, messages, temperature, top_p, max_tokens, stop, tools
        )
        url = f"{self.base_url}/text/chat_stream"
        params = {"version": self.API_VERSION}
        async with client.stream(
            "POST",
            url,
            headers=headers,
            json=payload,
            params=params,
        ) as response:
            if response.status_code != 200:
                text = await response.aread()
                raise Exception(f"watsonx API error: {response.status_code} - {text.decode()}")
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data = line[6:]  # Remove "data: " prefix
                    if data.strip() == "[DONE]":
                        break
                    try:
                        yield json.loads(data)
                    except json.JSONDecodeError:
                        # Skip keep-alives / partial frames rather than abort
                        # the whole stream.
                        continue

    async def text_generation(
        self,
        model_id: str,
        prompt: str,
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> Dict:
        """Generate text completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID.
            prompt: The input prompt.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate.
            top_p: Nucleus sampling parameter.
            stop: Stop sequences.
            **kwargs: Additional parameters (currently ignored).

        Returns:
            Text generation response as a parsed JSON dict.

        Raises:
            Exception: If the watsonx API returns a non-200 response.
        """
        headers = await self._get_headers()
        client = await self._get_client()
        payload = {
            "model_id": model_id,
            "project_id": self.project_id,
            "input": prompt,
            "parameters": {
                "temperature": temperature,
                "top_p": top_p,
            },
        }
        if max_tokens is not None:
            # Note: the generation endpoint calls this "max_new_tokens",
            # unlike the chat endpoint's "max_tokens".
            payload["parameters"]["max_new_tokens"] = max_tokens
        if stop:
            payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]
        url = f"{self.base_url}/text/generation"
        params = {"version": self.API_VERSION}
        response = await client.post(
            url,
            headers=headers,
            json=payload,
            params=params,
        )
        if response.status_code != 200:
            raise Exception(f"watsonx API error: {response.status_code} - {response.text}")
        return response.json()

    async def embeddings(
        self,
        model_id: str,
        inputs: List[str],
        **kwargs,
    ) -> Dict:
        """Generate embeddings using watsonx.ai.

        Args:
            model_id: The watsonx embedding model ID.
            inputs: List of texts to embed.
            **kwargs: Additional parameters (currently ignored).

        Returns:
            Embeddings response as a parsed JSON dict.

        Raises:
            Exception: If the watsonx API returns a non-200 response.
        """
        headers = await self._get_headers()
        client = await self._get_client()
        payload = {
            "model_id": model_id,
            "project_id": self.project_id,
            "inputs": inputs,
        }
        url = f"{self.base_url}/text/embeddings"
        params = {"version": self.API_VERSION}
        response = await client.post(
            url,
            headers=headers,
            json=payload,
            params=params,
        )
        if response.status_code != 200:
            raise Exception(f"watsonx API error: {response.status_code} - {response.text}")
        return response.json()
# Global service instance: module-level singleton shared across the app.
# Constructing it only reads settings attributes; no network I/O happens
# until the first request method is awaited.
watsonx_service = WatsonxService()