Add AGENTS.md documentation for AI agent guidance
This commit is contained in:
5
app/services/__init__.py
Normal file
5
app/services/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Services for watsonx-openai-proxy."""
|
||||
|
||||
from app.services.watsonx_service import watsonx_service, WatsonxService
|
||||
|
||||
__all__ = ["watsonx_service", "WatsonxService"]
|
||||
316
app/services/watsonx_service.py
Normal file
316
app/services/watsonx_service.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""Service for interacting with IBM watsonx.ai APIs."""
|
||||
|
||||
import asyncio
import json
import logging
import time
from typing import AsyncIterator, Dict, List, Optional

import httpx

from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WatsonxService:
    """Async client for the IBM watsonx.ai text, chat, and embeddings APIs.

    Manages IBM Cloud IAM bearer-token refresh (with caching and a refresh
    buffer), a lazily-created shared ``httpx.AsyncClient``, and request
    building for the watsonx.ai REST endpoints.
    """

    # watsonx.ai REST API version pin, shared by every endpoint call.
    API_VERSION = "2024-02-13"
    # Refresh the IAM token this many seconds before its reported expiry.
    TOKEN_REFRESH_BUFFER = 300

    def __init__(self):
        self.base_url = settings.watsonx_base_url
        self.project_id = settings.watsonx_project_id
        self.api_key = settings.ibm_cloud_api_key
        self._bearer_token: Optional[str] = None
        # Absolute expiry time (epoch seconds) of the cached bearer token.
        self._token_expiry: Optional[float] = None
        # Serializes token refreshes so concurrent requests don't all hit IAM.
        self._token_lock = asyncio.Lock()
        self._client: Optional["httpx.AsyncClient"] = None

    async def _get_client(self) -> "httpx.AsyncClient":
        """Lazily create and return the shared HTTP client."""
        if self._client is None:
            # Generous timeout: generation requests can run for minutes.
            self._client = httpx.AsyncClient(timeout=300.0)
        return self._client

    async def close(self):
        """Close the HTTP client, if one was ever created."""
        if self._client:
            await self._client.aclose()
            self._client = None

    async def _refresh_token(self) -> str:
        """Return a valid IAM bearer token, refreshing it when near expiry.

        Returns:
            The cached or freshly obtained bearer token.

        Raises:
            Exception: If the IAM token endpoint returns a non-200 status.
        """
        async with self._token_lock:
            # Reuse the cached token while it is comfortably inside its lifetime.
            if self._bearer_token and self._token_expiry:
                if time.time() < self._token_expiry - self.TOKEN_REFRESH_BUFFER:
                    return self._bearer_token

            logger.info("Refreshing IBM Cloud bearer token...")

            client = await self._get_client()
            # Send the form fields as a dict so httpx URL-encodes the API key;
            # interpolating it into a raw string breaks on '&', '=', '+', etc.
            response = await client.post(
                "https://iam.cloud.ibm.com/identity/token",
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                data={
                    "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
                    "apikey": self.api_key,
                },
            )

            if response.status_code != 200:
                raise Exception(f"Failed to get bearer token: {response.text}")

            data = response.json()
            self._bearer_token = data["access_token"]
            expires_in = data.get("expires_in", 3600)
            self._token_expiry = time.time() + expires_in

            logger.info("Bearer token refreshed. Expires in %s seconds", expires_in)
            return self._bearer_token

    async def _get_headers(self) -> Dict[str, str]:
        """Get JSON request headers carrying a valid bearer token."""
        token = await self._refresh_token()
        return {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

    def _build_chat_params(
        self,
        model_id: str,
        messages: List[Dict],
        temperature: float,
        top_p: float,
        max_tokens: Optional[int],
        stop: Optional[List[str]],
        tools: Optional[List[Dict]],
    ) -> Dict:
        """Build the chat request payload shared by the blocking and streaming paths.

        Optional fields are only included when provided; a bare string ``stop``
        is normalized to a one-element list.
        """
        payload = {
            "model_id": model_id,
            "project_id": self.project_id,
            "messages": messages,
            "parameters": {
                "temperature": temperature,
                "top_p": top_p,
            },
        }

        if max_tokens is not None:
            payload["parameters"]["max_tokens"] = max_tokens

        if stop:
            payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]

        if tools:
            payload["tools"] = tools

        return payload

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Dict],
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        stream: bool = False,
        tools: Optional[List[Dict]] = None,
        **kwargs,
    ) -> Dict:
        """Create a chat completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID
            messages: List of chat messages
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            top_p: Nucleus sampling parameter
            stop: Stop sequences
            stream: Whether to request a streamed response from the API
            tools: Tool/function definitions
            **kwargs: Additional parameters (currently ignored)

        Returns:
            Chat completion response

        Raises:
            Exception: If the watsonx API returns a non-200 status.
        """
        headers = await self._get_headers()
        client = await self._get_client()

        payload = self._build_chat_params(
            model_id, messages, temperature, top_p, max_tokens, stop, tools
        )

        params = {"version": self.API_VERSION}
        if stream:
            params["stream"] = "true"

        response = await client.post(
            f"{self.base_url}/text/chat",
            headers=headers,
            json=payload,
            params=params,
        )

        if response.status_code != 200:
            raise Exception(f"watsonx API error: {response.status_code} - {response.text}")

        return response.json()

    async def chat_completion_stream(
        self,
        model_id: str,
        messages: List[Dict],
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        tools: Optional[List[Dict]] = None,
        **kwargs,
    ) -> AsyncIterator[Dict]:
        """Stream a chat completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID
            messages: List of chat messages
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            top_p: Nucleus sampling parameter
            stop: Stop sequences
            tools: Tool/function definitions
            **kwargs: Additional parameters (currently ignored)

        Yields:
            Chat completion chunks decoded from the server-sent event stream

        Raises:
            Exception: If the watsonx API returns a non-200 status.
        """
        headers = await self._get_headers()
        client = await self._get_client()

        payload = self._build_chat_params(
            model_id, messages, temperature, top_p, max_tokens, stop, tools
        )

        params = {"version": self.API_VERSION}

        async with client.stream(
            "POST",
            f"{self.base_url}/text/chat_stream",
            headers=headers,
            json=payload,
            params=params,
        ) as response:
            if response.status_code != 200:
                body = await response.aread()
                raise Exception(f"watsonx API error: {response.status_code} - {body.decode()}")

            # The endpoint replies with server-sent events: "data: {...}" lines.
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[6:]  # Remove "data: " prefix
                if data.strip() == "[DONE]":
                    break
                try:
                    yield json.loads(data)
                except json.JSONDecodeError:
                    # Skip malformed chunks rather than aborting the stream.
                    continue

    async def text_generation(
        self,
        model_id: str,
        prompt: str,
        temperature: float = 1.0,
        max_tokens: Optional[int] = None,
        top_p: float = 1.0,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> Dict:
        """Generate text completion using watsonx.ai.

        Args:
            model_id: The watsonx model ID
            prompt: The input prompt
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate
            top_p: Nucleus sampling parameter
            stop: Stop sequences
            **kwargs: Additional parameters (currently ignored)

        Returns:
            Text generation response

        Raises:
            Exception: If the watsonx API returns a non-200 status.
        """
        headers = await self._get_headers()
        client = await self._get_client()

        payload = {
            "model_id": model_id,
            "project_id": self.project_id,
            "input": prompt,
            "parameters": {
                "temperature": temperature,
                "top_p": top_p,
            },
        }

        # Note: the generation endpoint names this "max_new_tokens",
        # unlike the chat endpoint's "max_tokens".
        if max_tokens is not None:
            payload["parameters"]["max_new_tokens"] = max_tokens

        if stop:
            payload["parameters"]["stop_sequences"] = stop if isinstance(stop, list) else [stop]

        response = await client.post(
            f"{self.base_url}/text/generation",
            headers=headers,
            json=payload,
            params={"version": self.API_VERSION},
        )

        if response.status_code != 200:
            raise Exception(f"watsonx API error: {response.status_code} - {response.text}")

        return response.json()

    async def embeddings(
        self,
        model_id: str,
        inputs: List[str],
        **kwargs,
    ) -> Dict:
        """Generate embeddings using watsonx.ai.

        Args:
            model_id: The watsonx embedding model ID
            inputs: List of texts to embed
            **kwargs: Additional parameters (currently ignored)

        Returns:
            Embeddings response

        Raises:
            Exception: If the watsonx API returns a non-200 status.
        """
        headers = await self._get_headers()
        client = await self._get_client()

        payload = {
            "model_id": model_id,
            "project_id": self.project_id,
            "inputs": inputs,
        }

        response = await client.post(
            f"{self.base_url}/text/embeddings",
            headers=headers,
            json=payload,
            params={"version": self.API_VERSION},
        )

        if response.status_code != 200:
            raise Exception(f"watsonx API error: {response.status_code} - {response.text}")

        return response.json()
|
||||
|
||||
|
||||
# Global service instance: module-level singleton created at import time.
# Construction only reads settings and creates a lock -- no network I/O
# happens until the first API call.
watsonx_service = WatsonxService()
|
||||
Reference in New Issue
Block a user