- POST /v1/messages endpoint with full quota enforcement and auth
- Accepts x-api-key and anthropic-auth-token headers (for Claude Code)
- Transforms Anthropic request/response format ↔ Ollama /api/chat
- Streaming support via Anthropic SSE format
- Tool use support (request and response transformation)
- ANTHROPIC_DEFAULT_MODEL env var for model selection without admin UI
- BACKEND_API_KEY env var for forwarding auth to upstream proxies
- Fix SQLite path always resolved relative to database.py location
- start.sh and start_claude.sh load .env relative to script location
273 lines
9.8 KiB
Python
273 lines
9.8 KiB
Python
import json
|
|
import os
|
|
from unittest.mock import AsyncMock, MagicMock, patch, call
|
|
|
|
|
|
def _make_body(model="llama3", messages=None, stream=False, **kwargs):
|
|
body = {
|
|
"model": model,
|
|
"messages": messages or [{"role": "user", "content": "Hello"}],
|
|
"max_tokens": 100,
|
|
}
|
|
if stream:
|
|
body["stream"] = True
|
|
body.update(kwargs)
|
|
return body
|
|
|
|
|
|
def _ollama_chat_response(content="Hi!", input_tokens=5, output_tokens=3):
|
|
return {
|
|
"model": "llama3",
|
|
"message": {"role": "assistant", "content": content},
|
|
"prompt_eval_count": input_tokens,
|
|
"eval_count": output_tokens,
|
|
"done": True,
|
|
}
|
|
|
|
|
|
# --- Auth ---
|
|
|
|
def test_messages_missing_auth_returns_401(test_client):
    """POST to ``/v1/messages`` with no credential header; expect 401."""
    payload = _make_body()
    resp = test_client.post("/v1/messages", json=payload)
    assert resp.status_code == 401
|
|
|
|
|
|
def test_messages_invalid_key_returns_401(test_client):
    """POST with ``x-api-key: sk-invalid``; expect 401."""
    bad_headers = {"x-api-key": "sk-invalid"}
    resp = test_client.post("/v1/messages", headers=bad_headers, json=_make_body())
    assert resp.status_code == 401
|
|
|
|
|
|
@patch("main.proxy_request", new_callable=AsyncMock)
def test_messages_accepts_anthropic_auth_token_header(mock_proxy, test_client):
    """Send the ``TEST_API_KEY`` value as ``anthropic-auth-token``; expect 200."""
    # Stub the upstream reply handed back by the patched proxy call.
    upstream = mock_proxy.return_value
    upstream.status_code = 200
    upstream.json = lambda: _ollama_chat_response()

    api_key = os.environ.get("TEST_API_KEY", "")
    resp = test_client.post(
        "/v1/messages",
        headers={"anthropic-auth-token": api_key},
        json=_make_body(),
    )

    assert resp.status_code == 200
|
|
|
|
|
|
@patch("main.proxy_request", new_callable=AsyncMock)
def test_messages_accepts_x_api_key_header(mock_proxy, test_client):
    """Send the ``TEST_API_KEY`` value as ``x-api-key``; expect 200."""
    # Stub the upstream reply handed back by the patched proxy call.
    upstream = mock_proxy.return_value
    upstream.status_code = 200
    upstream.json = lambda: _ollama_chat_response()

    api_key = os.environ.get("TEST_API_KEY", "")
    resp = test_client.post(
        "/v1/messages",
        headers={"x-api-key": api_key},
        json=_make_body(),
    )

    assert resp.status_code == 200
|
|
|
|
|
|
# --- Validation ---
|
|
|
|
def test_messages_missing_model_returns_422(test_client):
    """Omit ``model`` while ``ANTHROPIC_DEFAULT_MODEL`` is unset; expect 422."""
    # Copy of the current environment with only the default-model variable removed.
    scrubbed_env = dict(os.environ)
    scrubbed_env.pop("ANTHROPIC_DEFAULT_MODEL", None)

    payload = {"messages": [{"role": "user", "content": "Hi"}], "max_tokens": 100}

    with patch.dict(os.environ, scrubbed_env, clear=True):
        api_key = os.environ.get("TEST_API_KEY", "")
        resp = test_client.post(
            "/v1/messages",
            headers={"Authorization": f"Bearer {api_key}"},
            json=payload,
        )

    assert resp.status_code == 422
|
|
|
|
|
|
@patch("main.proxy_request", new_callable=AsyncMock)
def test_messages_anthropic_default_model_used_when_no_model_in_request(mock_proxy, test_client):
    """Omit ``model`` with ``ANTHROPIC_DEFAULT_MODEL`` set; the forwarded body carries that model."""
    upstream = mock_proxy.return_value
    upstream.status_code = 200
    upstream.json = lambda: _ollama_chat_response()

    payload = {"messages": [{"role": "user", "content": "Hi"}], "max_tokens": 100}

    with patch.dict(os.environ, {"ANTHROPIC_DEFAULT_MODEL": "qwen3-coder:q8_0"}):
        api_key = os.environ.get("TEST_API_KEY", "")
        test_client.post(
            "/v1/messages",
            headers={"Authorization": f"Bearer {api_key}"},
            json=payload,
        )

    # Inspect the keyword arguments of the most recent proxy call.
    _, proxy_kwargs = mock_proxy.call_args
    forwarded = proxy_kwargs["json_data"]
    assert forwarded["model"] == "qwen3-coder:q8_0"
|
|
|
|
|
|
# --- Quota ---
|
|
|
|
def test_messages_quota_exceeded_returns_429(test_client):
    """Patch ``check_and_increment_quota`` to return ``False``; expect 429."""
    api_key = os.environ.get("TEST_API_KEY", "")
    quota_denied = patch("main.crud.check_and_increment_quota", return_value=False)

    with quota_denied:
        resp = test_client.post(
            "/v1/messages",
            headers={"Authorization": f"Bearer {api_key}"},
            json=_make_body(),
        )

    assert resp.status_code == 429
|
|
|
|
|
|
# --- Response format ---
|
|
|
|
@patch("main.proxy_request", new_callable=AsyncMock)
def test_messages_returns_anthropic_format(mock_proxy, test_client):
    """Check the JSON reply: message type, assistant role, text block, and token usage."""
    upstream = mock_proxy.return_value
    upstream.status_code = 200
    upstream.json = lambda: _ollama_chat_response("Hello!")

    api_key = os.environ.get("TEST_API_KEY", "")
    resp = test_client.post(
        "/v1/messages",
        headers={"Authorization": f"Bearer {api_key}"},
        json=_make_body(),
    )
    assert resp.status_code == 200

    data = resp.json()
    assert data["type"] == "message"
    assert data["role"] == "assistant"
    assert isinstance(data["content"], list)

    first_block = data["content"][0]
    assert first_block["type"] == "text"
    assert first_block["text"] == "Hello!"

    # 5 and 3 are the helper's default prompt/eval token counts.
    usage = data["usage"]
    assert usage["input_tokens"] == 5
    assert usage["output_tokens"] == 3
|
|
|
|
|
|
# --- Request transformation ---
|
|
|
|
@patch("main.proxy_request", new_callable=AsyncMock)
def test_messages_system_prompt_becomes_first_system_message(mock_proxy, test_client):
    """A top-level ``system`` field must arrive upstream as the first message, role ``system``."""
    upstream = mock_proxy.return_value
    upstream.status_code = 200
    upstream.json = lambda: _ollama_chat_response()

    api_key = os.environ.get("TEST_API_KEY", "")
    test_client.post(
        "/v1/messages",
        headers={"Authorization": f"Bearer {api_key}"},
        json=_make_body(system="You are helpful"),
    )

    # Inspect the keyword arguments of the most recent proxy call.
    _, proxy_kwargs = mock_proxy.call_args
    leading_message = proxy_kwargs["json_data"]["messages"][0]
    assert leading_message["role"] == "system"
    assert leading_message["content"] == "You are helpful"
|
|
|
|
|
|
@patch("main.proxy_request", new_callable=AsyncMock)
def test_messages_tools_transformed_to_ollama_function_format(mock_proxy, test_client):
    """A request tool must be forwarded as a ``function`` entry with name and ``parameters``."""
    upstream = mock_proxy.return_value
    upstream.status_code = 200
    upstream.json = lambda: _ollama_chat_response()

    bash_tool = {
        "name": "bash",
        "description": "Run bash",
        "input_schema": {"type": "object", "properties": {"command": {"type": "string"}}},
    }
    api_key = os.environ.get("TEST_API_KEY", "")
    test_client.post(
        "/v1/messages",
        headers={"Authorization": f"Bearer {api_key}"},
        json=_make_body(tools=[bash_tool]),
    )

    # Inspect the keyword arguments of the most recent proxy call.
    _, proxy_kwargs = mock_proxy.call_args
    forwarded_tool = proxy_kwargs["json_data"]["tools"][0]
    assert forwarded_tool["type"] == "function"
    assert forwarded_tool["function"]["name"] == "bash"
    assert "parameters" in forwarded_tool["function"]
|
|
|
|
|
|
@patch("main.proxy_request", new_callable=AsyncMock)
def test_messages_tool_call_response_transformed_to_anthropic(mock_proxy, test_client):
    """An upstream ``tool_calls`` entry must surface as a ``tool_use`` block with ``stop_reason`` ``tool_use``."""
    upstream_payload = {
        "model": "llama3",
        "message": {
            "role": "assistant",
            "content": "",
            "tool_calls": [{"function": {"name": "bash", "arguments": {"command": "ls"}}}],
        },
        "prompt_eval_count": 10,
        "eval_count": 5,
        "done": True,
    }
    upstream = mock_proxy.return_value
    upstream.status_code = 200
    upstream.json = lambda: upstream_payload

    api_key = os.environ.get("TEST_API_KEY", "")
    resp = test_client.post(
        "/v1/messages",
        headers={"Authorization": f"Bearer {api_key}"},
        json=_make_body(),
    )

    data = resp.json()
    assert data["stop_reason"] == "tool_use"

    # First content block of type "tool_use"; raises StopIteration if none exists.
    tool_block = next(block for block in data["content"] if block["type"] == "tool_use")
    assert tool_block["name"] == "bash"
    assert tool_block["input"] == {"command": "ls"}
|
|
|
|
|
|
# --- Streaming ---
|
|
|
|
@patch("main.proxy_request", new_callable=AsyncMock)
def test_messages_streaming_returns_anthropic_sse_events(mock_proxy, test_client):
    """With ``stream=True`` the reply must be SSE ``data:`` lines carrying the expected events.

    Asserts that ``message_start``, ``content_block_start``,
    ``content_block_delta`` and ``message_stop`` events are all present, and
    that the concatenated delta texts equal the upstream content ``"Hi!"``.
    """
    upstream = mock_proxy.return_value
    upstream.status_code = 200
    # Reuse the module's shared fixture builder instead of repeating the dict;
    # its defaults (model "llama3", 5 prompt / 3 eval tokens) give the same payload.
    upstream.json = lambda: _ollama_chat_response("Hi!")

    api_key = os.environ.get("TEST_API_KEY", "")
    response = test_client.post(
        "/v1/messages",
        headers={"Authorization": f"Bearer {api_key}"},
        json=_make_body(stream=True),
    )
    assert response.status_code == 200

    # Only "data: " lines are decoded; every other line of the body is skipped.
    sse_prefix = "data: "
    events = [
        json.loads(line[len(sse_prefix):])
        for line in response.text.splitlines()
        if line.startswith(sse_prefix)
    ]

    event_types = [event["type"] for event in events]
    assert "message_start" in event_types
    assert "content_block_start" in event_types
    assert "content_block_delta" in event_types
    assert "message_stop" in event_types

    deltas = [event for event in events if event["type"] == "content_block_delta"]
    text = "".join(delta["delta"]["text"] for delta in deltas)
    assert text == "Hi!"
|
|
|
|
|
|
# --- Backend-Auth (BACKEND_API_KEY) ---
|
|
|
|
def test_proxy_request_forwards_backend_api_key(test_client):
    """With ``BACKEND_API_KEY`` set, the outgoing request carries it as a Bearer ``Authorization`` header."""
    upstream_response = MagicMock()
    upstream_response.status_code = 200
    upstream_response.json.return_value = {"result": "ok"}

    # Mocked client usable as an async context manager that yields itself.
    fake_client = AsyncMock()
    fake_client.__aenter__ = AsyncMock(return_value=fake_client)
    fake_client.__aexit__ = AsyncMock(return_value=False)
    fake_client.request = AsyncMock(return_value=upstream_response)

    with patch("main.httpx.AsyncClient") as client_cls:
        client_cls.return_value = fake_client

        with patch.dict(os.environ, {"BACKEND_API_KEY": "sk-backend-secret"}):
            api_key = os.environ.get("TEST_API_KEY", "")
            test_client.post(
                "/api/generate",
                headers={"Authorization": f"Bearer {api_key}"},
                json={"model": "llama3", "prompt": "hi"},
            )

        _, request_kwargs = fake_client.request.call_args
        sent_headers = request_kwargs.get("headers", {})
        assert sent_headers.get("Authorization") == "Bearer sk-backend-secret"
|
|
|
|
|
|
def test_proxy_request_omits_auth_header_when_no_backend_key(test_client):
    """With ``BACKEND_API_KEY`` unset, the outgoing request has no ``Authorization`` header."""
    upstream_response = MagicMock()
    upstream_response.status_code = 200
    upstream_response.json.return_value = {"result": "ok"}

    # Mocked client usable as an async context manager that yields itself.
    fake_client = AsyncMock()
    fake_client.__aenter__ = AsyncMock(return_value=fake_client)
    fake_client.__aexit__ = AsyncMock(return_value=False)
    fake_client.request = AsyncMock(return_value=upstream_response)

    # Copy of the current environment with only the backend key removed.
    env_without_key = dict(os.environ)
    env_without_key.pop("BACKEND_API_KEY", None)

    with patch("main.httpx.AsyncClient") as client_cls:
        client_cls.return_value = fake_client

        with patch.dict(os.environ, env_without_key, clear=True):
            api_key = os.environ.get("TEST_API_KEY", "")
            test_client.post(
                "/api/generate",
                headers={"Authorization": f"Bearer {api_key}"},
                json={"model": "llama3", "prompt": "hi"},
            )

        _, request_kwargs = fake_client.request.call_args
        assert "Authorization" not in request_kwargs.get("headers", {})
|