"""Tests for the API-key authentication middleware in front of /api/generate."""
import os
from unittest.mock import AsyncMock, patch

import pytest
from fastapi.testclient import TestClient

# Point the upstream Ollama URL at an unreachable local port. This is set
# *before* importing the app so it also takes effect in case `main` (or a
# module it imports) reads OLLAMA_URL at import time.
os.environ["OLLAMA_URL"] = "http://127.0.0.1:9999"

from main import app  # noqa: E402
from database import Base, engine, SessionLocal  # noqa: E402
from models import User, APIKey, Quota  # noqa: E402
from crud import create_api_key, hash_password  # noqa: E402


def setup_test_db():
    """Setup test database with required data.

    Drops and recreates every table registered on ``Base.metadata``, then
    inserts:

    * a regular user ``testuser`` with a quota row and an API key named
      ``"test-key"``;
    * an admin user ``admin`` (``is_admin=True``) with a larger quota row and
      an API key named ``"admin-key"``.

    The raw (unhashed) keys returned by ``create_api_key`` are exported as
    the ``TEST_API_KEY`` and ``ADMIN_API_KEY`` environment variables so the
    tests can read them.

    Returns:
        tuple[str, str]: ``(test_api_key, admin_api_key)`` raw key strings.
    """
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)

    db = SessionLocal()
    # Always release the session, even if an insert or commit fails part-way
    # through; previously an exception here leaked the connection.
    try:
        test_user = User(
            username="testuser",
            email="test@example.com",
            hashed_password=hash_password("test123"),
            is_active=True,
        )
        db.add(test_user)
        db.commit()
        # Refresh so test_user.id (assigned by the database) is populated.
        db.refresh(test_user)

        quota = Quota(
            user_id=test_user.id,
            daily_tokens=1000000,
            monthly_tokens=10000000,
            daily_requests=1000,
            monthly_requests=10000,
        )
        db.add(quota)
        db.commit()

        # create_api_key returns (record, raw_key); only the raw key is needed.
        _, raw_key = create_api_key(db, test_user.id, "test-key")
        os.environ["TEST_API_KEY"] = raw_key

        admin_user = User(
            username="admin",
            email="admin@example.com",
            hashed_password=hash_password("admin123"),
            is_active=True,
            is_admin=True,
        )
        db.add(admin_user)
        db.commit()
        db.refresh(admin_user)

        admin_quota = Quota(
            user_id=admin_user.id,
            daily_tokens=10000000,
            monthly_tokens=100000000,
            daily_requests=10000,
            monthly_requests=100000,
        )
        db.add(admin_quota)
        db.commit()

        _, admin_raw_key = create_api_key(db, admin_user.id, "admin-key")
        os.environ["ADMIN_API_KEY"] = admin_raw_key
    finally:
        db.close()

    return raw_key, admin_raw_key


def teardown_test_db():
    """Clean up test database and environment.

    Drops all tables and removes the ``TEST_API_KEY`` / ``ADMIN_API_KEY``
    environment variables set by :func:`setup_test_db` (missing keys are
    ignored).
    """
    Base.metadata.drop_all(bind=engine)
    os.environ.pop("TEST_API_KEY", None)
    os.environ.pop("ADMIN_API_KEY", None)


@pytest.fixture(scope="function")
def test_client():
    """Yield a TestClient backed by a freshly seeded database.

    ``raise_server_exceptions=False`` makes unhandled server errors come
    back as HTTP responses instead of being re-raised in the test.
    """
    setup_test_db()
    client = TestClient(app, raise_server_exceptions=False)
    yield client
    teardown_test_db()


def test_auth_middleware_missing_auth(test_client):
    """A request without an Authorization header is rejected with 401."""
    response = test_client.post(
        "/api/generate",
        json={"model": "llama3", "prompt": "test"},
    )
    assert response.status_code == 401


def test_auth_middleware_invalid_key(test_client):
    """A request with an unknown API key is rejected with 401."""
    response = test_client.post(
        "/api/generate",
        headers={"Authorization": "sk-invalid-key"},
        json={"model": "llama3", "prompt": "test"},
    )
    assert response.status_code == 401


@patch("main.proxy_request", new_callable=AsyncMock)
def test_auth_middleware_valid_key(mock_proxy, test_client):
    """A request with a valid API key passes auth and gets a 200 response.

    ``main.proxy_request`` is replaced with an AsyncMock so no upstream call
    is made; the mock's awaited return value is given a 200 status, an empty
    header dict and a ``json()`` callable.
    """
    mock_proxy.return_value.status_code = 200
    mock_proxy.return_value.json = lambda: {"response": "success"}
    mock_proxy.return_value.headers = {}

    response = test_client.post(
        "/api/generate",
        headers={"Authorization": os.environ.get("TEST_API_KEY", "")},
        json={"model": "llama3", "prompt": "test"},
    )
    assert response.status_code == 200