88 lines
2.4 KiB
Python
88 lines
2.4 KiB
Python
"""Basic tests for watsonx-openai-proxy."""
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
from app.main import app
|
|
|
|
# Module-level test client wrapping the imported application object; the
# test functions below issue their HTTP requests through it.
client = TestClient(app)
|
|
|
|
|
|
def test_health_check():
    """GET /health must answer 200 and report a healthy status plus a ``cluster`` key."""
    resp = client.get("/health")
    assert resp.status_code == 200

    payload = resp.json()
    assert payload["status"] == "healthy"
    assert "cluster" in payload
|
|
|
|
|
|
def test_root_endpoint():
    """GET / must answer 200, name the service, and expose an ``endpoints`` key."""
    resp = client.get("/")
    assert resp.status_code == 200

    payload = resp.json()
    assert payload["service"] == "watsonx-openai-proxy"
    assert "endpoints" in payload
|
|
|
|
|
|
def test_list_models():
    """GET /v1/models must answer 200 with a non-empty list of ``model`` objects."""
    resp = client.get("/v1/models")
    assert resp.status_code == 200

    listing = resp.json()
    assert listing["object"] == "list"

    # Every entry in the listing is expected to be tagged as a model object.
    models = listing["data"]
    assert len(models) > 0
    assert all(entry["object"] == "model" for entry in models)
|
|
|
|
|
|
def test_retrieve_model():
    """GET /v1/models/<id> for a known model must answer 200 and echo that id."""
    expected_id = "ibm/granite-3-8b-instruct"

    resp = client.get("/v1/models/ibm/granite-3-8b-instruct")
    assert resp.status_code == 200

    model = resp.json()
    assert model["id"] == expected_id
    assert model["object"] == "model"
|
|
|
|
|
|
def test_retrieve_nonexistent_model():
    """GET /v1/models/<id> for an unknown model id must answer 404."""
    resp = client.get("/v1/models/nonexistent-model")
    assert resp.status_code == 404
|
|
|
|
|
|
# NOTE: The following tests require valid IBM Cloud credentials. They are
# skipped by default and should be run selectively (e.g. via pytest markers)
# or as part of an integration test suite.
|
|
|
|
@pytest.mark.skip(reason="Requires valid IBM Cloud credentials")
def test_chat_completion():
    """POST /v1/chat/completions must answer 200 with at least one choice."""
    request_body = {
        "model": "ibm/granite-3-8b-instruct",
        "messages": [
            {"role": "user", "content": "Hello!"},
        ],
    }

    resp = client.post("/v1/chat/completions", json=request_body)
    assert resp.status_code == 200

    payload = resp.json()
    assert "choices" in payload
    assert len(payload["choices"]) > 0
|
|
|
|
|
|
@pytest.mark.skip(reason="Requires valid IBM Cloud credentials")
def test_embeddings():
    """POST /v1/embeddings must answer 200 with a non-empty ``data`` array."""
    request_body = {
        "model": "ibm/slate-125m-english-rtrvr",
        "input": "Test text",
    }

    resp = client.post("/v1/embeddings", json=request_body)
    assert resp.status_code == 200

    payload = resp.json()
    assert "data" in payload
    assert len(payload["data"]) > 0
|