diff --git a/backend/admin.py b/backend/admin.py index 100e5c7..9bb53f8 100644 --- a/backend/admin.py +++ b/backend/admin.py @@ -1,3 +1,4 @@ +import os from fastapi import FastAPI, Depends, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from sqlalchemy.orm import Session @@ -7,15 +8,17 @@ from models import User, APIKey, Quota app = FastAPI(title="Ollama Proxy Admin API") +ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "http://localhost:5173").split(",") + app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=ALLOWED_ORIGINS, allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_methods=["GET", "POST", "PUT", "DELETE"], + allow_headers=["Authorization", "Content-Type"], ) -async def require_admin_auth(request: Request): +async def require_admin_auth(request: Request, db: Session = Depends(get_db)): auth_header = request.headers.get("Authorization", "") if auth_header.startswith("Bearer "): api_key = auth_header.replace("Bearer ", "") @@ -23,24 +26,21 @@ async def require_admin_auth(request: Request): api_key = auth_header else: raise HTTPException(status_code=401, detail="Invalid or missing API key") - - db = next(get_db()) + db_key = crud.verify_api_key(db, api_key) - if not db_key: raise HTTPException(status_code=401, detail="Invalid API key") - + db_user = db.query(User).filter(User.id == db_key.user_id).first() - if not db_user or db_user.username != "admin": + if not db_user or not db_user.is_admin: raise HTTPException(status_code=403, detail="Admin access required") - + request.state.user = db_user - request.state.db = db @app.get("/api/users", response_model=list[schemas.User]) async def read_users( - skip: int = 0, - limit: int = 100, + skip: int = 0, + limit: int = 100, db: Session = Depends(get_db), _ = Depends(require_admin_auth) ): @@ -49,7 +49,7 @@ async def read_users( @app.post("/api/users", response_model=schemas.User) async def create_user( - user: schemas.UserCreate, + user: schemas.UserCreate, db: Session = Depends(get_db), _ = Depends(require_admin_auth) ): @@ -61,21 +61,24 @@ async def create_user( raise HTTPException(status_code=400, detail="Email already registered") return crud.create_user(db=db, username=user.username, email=user.email, password=user.password) -@app.post("/api/api-keys", response_model=schemas.APIKey) +@app.post("/api/api-keys", response_model=schemas.APIKeyCreated) async def create_api_key( - api_key: schemas.APIKeyCreate, + api_key: schemas.APIKeyCreate, db: Session = Depends(get_db), _ = Depends(require_admin_auth) ): db_user = db.query(User).filter(User.id == api_key.user_id).first() if not db_user: raise HTTPException(status_code=404, detail="User not found") - return crud.create_api_key(db=db, user_id=api_key.user_id, name=api_key.name) + db_key, raw_key = crud.create_api_key(db=db, user_id=api_key.user_id, name=api_key.name) + result = schemas.APIKeyCreated.model_validate(db_key) + result.plaintext_key = raw_key + return result @app.get("/api/api-keys", response_model=list[schemas.APIKey]) async def read_api_keys( - skip: int = 0, - limit: int = 100, + skip: int = 0, + limit: int = 100, db: Session = Depends(get_db), _ = Depends(require_admin_auth) ): @@ -84,7 +87,7 @@ async def read_api_keys( @app.put("/api/api-keys/{api_key_id}/deactivate") async def deactivate_api_key( - api_key_id: int, + api_key_id: int, db: Session = Depends(get_db), _ = Depends(require_admin_auth) ): @@ -97,15 +100,15 @@ async def deactivate_api_key( @app.put("/api/quotas/{user_id}", response_model=schemas.Quota) async def update_quota( - user_id: int, - quota: schemas.QuotaCreate, + user_id: int, + quota: schemas.QuotaCreate, db: Session = Depends(get_db), _ = Depends(require_admin_auth) ): db_quota = db.query(Quota).filter(Quota.user_id == user_id).first() if not db_quota: raise HTTPException(status_code=404, detail="Quota not found") - for key, value in quota.dict(exclude_unset=True).items(): + for key, value in quota.model_dump(exclude_unset=True).items(): setattr(db_quota, key, value) db.commit() db.refresh(db_quota) diff --git a/backend/crud.py b/backend/crud.py index 68650f8..fb17afe 100644 --- a/backend/crud.py +++ b/backend/crud.py @@ -1,10 +1,25 @@ import secrets import hashlib -from datetime import datetime, timedelta +import bcrypt +import tiktoken +from datetime import datetime from sqlalchemy.orm import Session -from database import get_db from models import APIKey, User, Quota, Usage +_encoder = tiktoken.get_encoding("cl100k_base") + +def count_tokens(text: str) -> int: + return len(_encoder.encode(text)) + +def hash_password(password: str) -> str: + return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() + +def verify_password(plain_password: str, hashed_password: str) -> bool: + return bcrypt.checkpw(plain_password.encode(), hashed_password.encode()) + +def _hash_api_key(key: str) -> str: + return hashlib.sha256(key.encode()).hexdigest() + def get_user_by_username(db: Session, username: str): return db.query(User).filter(User.username == username).first() @@ -14,19 +29,17 @@ def get_user_by_email(db: Session, email: str): def generate_api_key(): return "sk-" + secrets.token_urlsafe(32) -def hash_password(password: str): - return hashlib.sha256(password.encode()).hexdigest() - -def create_user(db: Session, username: str, email: str, password: str): +def create_user(db: Session, username: str, email: str, password: str, is_admin: bool = False): db_user = User( username=username, email=email, - hashed_password=hash_password(password) + hashed_password=hash_password(password), + is_admin=is_admin, ) db.add(db_user) db.commit() db.refresh(db_user) - + default_quota = Quota( user_id=db_user.id, daily_tokens=1000000, @@ -36,60 +49,75 @@ def create_user(db: Session, username: str, email: str, password: str): ) db.add(default_quota) db.commit() - + return db_user -def create_api_key(db: Session, user_id: int, name: str): - key = generate_api_key() +def create_api_key(db: Session, user_id: int, name: str) -> tuple[APIKey, str]: + raw_key = generate_api_key() db_key = APIKey( name=name, - key=key, + key=_hash_api_key(raw_key), user_id=user_id ) db.add(db_key) db.commit() db.refresh(db_key) - return db_key + return db_key, raw_key def verify_api_key(db: Session, api_key: str): - return db.query(APIKey).filter(APIKey.key == api_key, APIKey.is_active == True).first() + key_hash = _hash_api_key(api_key) + return db.query(APIKey).filter(APIKey.key == key_hash, APIKey.is_active == True).first() def get_quota(db: Session, user_id: int): return db.query(Quota).filter(Quota.user_id == user_id).first() -def check_quota(db: Session, user_id: int, tokens: int = 0, requests: int = 1): - quota = get_quota(db, user_id) - usage = get_usage(db, user_id) - - if quota.daily_tokens and (usage.tokens_used + tokens) > quota.daily_tokens: - return False - if quota.monthly_tokens and (usage.tokens_used + tokens) > quota.monthly_tokens: - return False - if quota.daily_requests and (usage.requests_count + requests) > quota.daily_requests: - return False - if quota.monthly_requests and (usage.requests_count + requests) > quota.monthly_requests: - return False - - return True - -def increment_usage(db: Session, user_id: int, tokens: int = 0, requests: int = 1): - usage = get_or_create_usage(db, user_id) - usage.tokens_used += tokens - usage.requests_count += requests - db.commit() - db.refresh(usage) +def get_quota_by_user_id(db: Session, user_id: int): + return db.query(Quota).filter(Quota.user_id == user_id).first() def get_usage(db: Session, user_id: int): return db.query(Usage).filter(Usage.user_id == user_id).first() -def get_quota_by_user_id(db: Session, user_id: int): - return db.query(Quota).filter(Quota.user_id == user_id).first() - -def get_or_create_usage(db: Session, user_id: int): - usage = get_usage(db, user_id) +def check_and_increment_quota(db: Session, user_id: int, tokens: int = 0, requests: int = 1) -> bool: + usage = ( + db.query(Usage) + .filter(Usage.user_id == user_id) + .with_for_update() + .first() + ) if not usage: usage = Usage(user_id=user_id) db.add(usage) - db.commit() - db.refresh(usage) - return usage + db.flush() + + now = datetime.utcnow() + + if usage.daily_reset_at.date() < now.date(): + usage.tokens_used_today = 0 + usage.requests_today = 0 + usage.daily_reset_at = now + + if (usage.monthly_reset_at.year, usage.monthly_reset_at.month) < (now.year, now.month): + usage.tokens_used_month = 0 + usage.requests_month = 0 + usage.monthly_reset_at = now + + quota = get_quota(db, user_id) + allowed = True + if quota: + if quota.daily_tokens and (usage.tokens_used_today + tokens) > quota.daily_tokens: + allowed = False + elif quota.monthly_tokens and (usage.tokens_used_month + tokens) > quota.monthly_tokens: + allowed = False + elif quota.daily_requests and (usage.requests_today + requests) > quota.daily_requests: + allowed = False + elif quota.monthly_requests and (usage.requests_month + requests) > quota.monthly_requests: + allowed = False + + if allowed: + usage.tokens_used_today += tokens + usage.tokens_used_month += tokens + usage.requests_today += requests + usage.requests_month += requests + + db.commit() + return allowed \ No newline at end of file diff --git a/backend/main.py b/backend/main.py index ab36f1e..5a458d5 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,7 +1,9 @@ +import time +import uuid from fastapi import FastAPI, HTTPException, Depends, Request from fastapi.responses import JSONResponse from sqlalchemy.orm import Session -from database import get_db +from database import get_db, SessionLocal import crud import httpx import os @@ -22,54 +24,48 @@ async def authenticate_and_quota(request: Request, call_next): elif auth_header.startswith("sk-"): api_key = auth_header else: - raise HTTPException(status_code=401, detail="Invalid or missing API key") - - db = next(get_db()) - db_key = crud.verify_api_key(db, api_key) - - if not db_key: - raise HTTPException(status_code=401, detail="Invalid API key") - - if not db_key.is_active: - raise HTTPException(status_code=403, detail="API key deactivated") - - request.state.user_id = db_key.user_id - + return JSONResponse(status_code=401, content={"detail": "Invalid or missing API key"}) + + db = SessionLocal() + try: + db_key = crud.verify_api_key(db, api_key) + if not db_key: + return JSONResponse(status_code=401, content={"detail": "Invalid API key"}) + if not db_key.is_active: + return JSONResponse(status_code=403, content={"detail": "API key deactivated"}) + request.state.user_id = db_key.user_id + finally: + db.close() + response = await call_next(request) return response @app.post("/api/generate") -async def generate(request: Request): - db = next(get_db()) +async def generate(request: Request, db: Session = Depends(get_db)): user_id = request.state.user_id - + body = await request.json() - - prompt_tokens = len(body.get("prompt", "").split()) - if not crud.check_quota(db, user_id, tokens=prompt_tokens, requests=1): + + prompt_tokens = crud.count_tokens(body.get("prompt", "")) + if not crud.check_and_increment_quota(db, user_id, tokens=prompt_tokens, requests=1): raise HTTPException(status_code=429, detail="Quota exceeded") - + response = await proxy_request(f"{OLLAMA_URL}/api/generate", method="POST", json_data=body, headers=dict(request.headers)) - - crud.increment_usage(db, user_id, tokens=prompt_tokens, requests=1) - + return JSONResponse(content=response.json(), status_code=response.status_code, headers=dict(response.headers)) @app.post("/api/chat") -async def chat(request: Request): - db = next(get_db()) +async def chat(request: Request, db: Session = Depends(get_db)): user_id = request.state.user_id - + body = await request.json() - - prompt_tokens = sum(len(msg.get("content", "").split()) for msg in body.get("messages", [])) - if not crud.check_quota(db, user_id, tokens=prompt_tokens, requests=1): + + prompt_tokens = sum(crud.count_tokens(msg.get("content", "")) for msg in body.get("messages", [])) + if not crud.check_and_increment_quota(db, user_id, tokens=prompt_tokens, requests=1): raise HTTPException(status_code=429, detail="Quota exceeded") - + response = await proxy_request(f"{OLLAMA_URL}/api/chat", method="POST", json_data=body, headers=dict(request.headers)) - - crud.increment_usage(db, user_id, tokens=prompt_tokens, requests=1) - + return JSONResponse(content=response.json(), status_code=response.status_code, headers=dict(response.headers)) @app.get("/api/tags") @@ -101,48 +97,48 @@ async def list_openai_models(request: Request): return JSONResponse(content=openai_models, status_code=200, headers=dict(response.headers)) @app.post("/v1/chat/completions") -async def openai_chat_completions(request: Request): - db = next(get_db()) +async def openai_chat_completions(request: Request, db: Session = Depends(get_db)): user_id = request.state.user_id - + body = await request.json() - + messages = body.get("messages", []) - prompt_tokens = sum(len(msg.get("content", "").split()) for msg in messages) - - if not crud.check_quota(db, user_id, tokens=prompt_tokens, requests=1): + prompt_tokens = sum(crud.count_tokens(msg.get("content", "")) for msg in messages) + + if not crud.check_and_increment_quota(db, user_id, tokens=prompt_tokens, requests=1): raise HTTPException(status_code=429, detail="Quota exceeded") - + ollama_body = { "model": body.get("model", "llama3"), "messages": messages, "stream": body.get("stream", False) } - + response = await proxy_request(f"{OLLAMA_URL}/api/chat", method="POST", json_data=ollama_body, headers=dict(request.headers)) - - crud.increment_usage(db, user_id, tokens=prompt_tokens, requests=1) - + + response_content = response.json().get("message", {}).get("content", "") + completion_tokens = crud.count_tokens(response_content) + openai_response = { - "id": f"chatcmpl-{hash(msg.get('content', ''))}", + "id": f"chatcmpl-{uuid.uuid4().hex}", "object": "chat.completion", - "created": int(__import__('time').time()), + "created": int(time.time()), "model": body.get("model", "llama3"), "choices": [ { "index": 0, "message": { "role": "assistant", - "content": response.json().get("message", {}).get("content", "") + "content": response_content }, "finish_reason": "stop" } ], "usage": { "prompt_tokens": prompt_tokens, - "completion_tokens": len(response.json().get("message", {}).get("content", "").split()), - "total_tokens": prompt_tokens + len(response.json().get("message", {}).get("content", "").split()) + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens } } - + return JSONResponse(content=openai_response, status_code=200, headers={"Content-Type": "application/json"}) diff --git a/backend/models.py b/backend/models.py index a9fd804..071955f 100644 --- a/backend/models.py +++ b/backend/models.py @@ -10,6 +10,7 @@ class User(Base): email = Column(String, unique=True, index=True) hashed_password = Column(String) is_active = Column(Boolean, default=True) + is_admin = Column(Boolean, default=False) created_at = Column(DateTime, default=datetime.utcnow) class APIKey(Base): @@ -38,6 +39,9 @@ class Usage(Base): id = Column(Integer, primary_key=True, index=True) user_id = Column(Integer, ForeignKey("users.id"), unique=True) - tokens_used = Column(BigInteger, default=0) - requests_count = Column(Integer, default=0) - reset_at = Column(DateTime, default=datetime.utcnow) + tokens_used_today = Column(BigInteger, default=0) + tokens_used_month = Column(BigInteger, default=0) + requests_today = Column(Integer, default=0) + requests_month = Column(Integer, default=0) + daily_reset_at = Column(DateTime, default=datetime.utcnow) + monthly_reset_at = Column(DateTime, default=datetime.utcnow) diff --git a/backend/requirements.txt b/backend/requirements.txt index 22ffceb..66f0d51 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -6,7 +6,8 @@ alembic==1.14.0 pydantic==2.10.3 python-multipart==0.0.20 python-jose[cryptography]==3.3.0 -passlib[bcrypt]==1.7.4 +bcrypt==5.0.0 +tiktoken==0.9.0 python-dotenv==1.0.1 pytest==8.3.4 pytest-asyncio==0.25.1 diff --git a/backend/schemas.py b/backend/schemas.py index 55cdaec..cc3cd1d 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -5,6 +5,7 @@ from typing import Optional class UserBase(BaseModel): username: str email: str + is_admin: bool = False class UserCreate(UserBase): password: str @@ -33,6 +34,12 @@ class APIKey(APIKeyBase): class Config: from_attributes = True +class APIKeyCreated(APIKey): + plaintext_key: str + + class Config: + from_attributes = True + class QuotaBase(BaseModel): daily_tokens: Optional[int] = None monthly_tokens: Optional[int] = None @@ -51,9 +58,12 @@ class Quota(QuotaBase): from_attributes = True class UsageStats(BaseModel): - tokens_used: int = 0 - requests_count: int = 0 - last_reset: Optional[datetime] = None + tokens_used_today: int = 0 + tokens_used_month: int = 0 + requests_today: int = 0 + requests_month: int = 0 + daily_reset_at: Optional[datetime] = None + monthly_reset_at: Optional[datetime] = None class Config: from_attributes = True diff --git a/backend/setup_admin.py b/backend/setup_admin.py index daedf2c..6ad72f1 100644 --- a/backend/setup_admin.py +++ b/backend/setup_admin.py @@ -13,7 +13,8 @@ def setup_admin(): username="admin", email="admin@ollama.local", hashed_password=hash_password("admin123"), - is_active=True + is_active=True, + is_admin=True, ) db.add(admin_user) db.commit() @@ -31,8 +32,8 @@ def setup_admin(): db.commit() print("✓ Admin quota created") - api_key = create_api_key(db, admin_user.id, "admin-api-key") - print(f"✓ Admin API Key: {api_key.key}") + _, raw_key = create_api_key(db, admin_user.id, "admin-api-key") + print(f"✓ Admin API Key: {raw_key}") else: print("✗ Admin user already exists") db.close() diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py new file mode 100644 index 0000000..73e3b1b --- /dev/null +++ b/backend/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the Ollama Proxy.""" diff --git a/backend/tests/__pycache__/__init__.cpython-312.pyc b/backend/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..2751027 Binary files /dev/null and b/backend/tests/__pycache__/__init__.cpython-312.pyc differ diff --git a/backend/tests/__pycache__/conftest.cpython-312-pytest-8.3.4.pyc b/backend/tests/__pycache__/conftest.cpython-312-pytest-8.3.4.pyc new file mode 100644 index 0000000..e95f4ef Binary files /dev/null and b/backend/tests/__pycache__/conftest.cpython-312-pytest-8.3.4.pyc differ diff --git a/backend/tests/__pycache__/test_auth.cpython-312-pytest-8.3.4.pyc b/backend/tests/__pycache__/test_auth.cpython-312-pytest-8.3.4.pyc new file mode 100644 index 0000000..4a0c115 Binary files /dev/null and b/backend/tests/__pycache__/test_auth.cpython-312-pytest-8.3.4.pyc differ diff --git a/backend/tests/__pycache__/test_quota.cpython-312-pytest-8.3.4.pyc b/backend/tests/__pycache__/test_quota.cpython-312-pytest-8.3.4.pyc new file mode 100644 index 0000000..ed40ea5 Binary files /dev/null and b/backend/tests/__pycache__/test_quota.cpython-312-pytest-8.3.4.pyc differ diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 0000000..cdb70be --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,87 @@ +import pytest +from fastapi.testclient import TestClient +import tempfile +import os +from pathlib import Path + +def create_test_db(): + """Create a temporary SQLite database for tests.""" + temp_db = tempfile.NamedTemporaryFile(suffix='.db', delete=False) + temp_db.close() + os.environ["DATABASE_URL"] = f"sqlite:///{temp_db.name}" + return temp_db.name + +def cleanup_test_db(db_path): + """Remove the temporary database.""" + if os.path.exists(db_path): + os.unlink(db_path) + os.environ.pop("DATABASE_URL", None) + +def setup_test_db(): + """Setup test database with required data.""" + from database import Base, engine, SessionLocal + from models import User, APIKey, Quota, Usage + from crud import create_api_key, hash_password + + # Create tables + Base.metadata.create_all(bind=engine) + + db = SessionLocal() + + # Create test user + test_user = User( + username="testuser", + email="test@example.com", + hashed_password=hash_password("test123"), + is_active=True + ) + db.add(test_user) + db.commit() + db.refresh(test_user) + + # Create API key for test user + _, raw_key = create_api_key(db, test_user.id, "test-key") + os.environ["TEST_API_KEY"] = raw_key + + # Create admin user + admin_user = User( + username="admin", + email="admin@example.com", + hashed_password=hash_password("admin123"), + is_active=True, + is_admin=True, + ) + db.add(admin_user) + db.commit() + db.refresh(admin_user) + + # Create admin API key + _, admin_raw_key = create_api_key(db, admin_user.id, "admin-key") + os.environ["ADMIN_API_KEY"] = admin_raw_key + + db.close() + + return raw_key, admin_raw_key + +def teardown_test_db(): + """Clean up test database and environment.""" + from database import engine + from models import Base + Base.metadata.drop_all(bind=engine) + + os.environ.pop("TEST_API_KEY", None) + os.environ.pop("ADMIN_API_KEY", None) + +@pytest.fixture(scope="session") +def test_client(): + """Create test client with test database.""" + db_path = create_test_db() + setup_test_db() + + from main import app + client = TestClient(app) + + yield client + + teardown_test_db() + cleanup_test_db(db_path) diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py new file mode 100644 index 0000000..853c863 --- /dev/null +++ b/backend/tests/test_auth.py @@ -0,0 +1,106 @@ +import pytest +import os +from unittest.mock import AsyncMock, patch +from fastapi.testclient import TestClient +from main import app +from database import Base, engine, SessionLocal +from models import User, APIKey, Quota +from crud import create_api_key, hash_password + +os.environ["OLLAMA_URL"] = "http://127.0.0.1:9999" + +def setup_test_db(): + """Setup test database with required data.""" + Base.metadata.drop_all(bind=engine) + Base.metadata.create_all(bind=engine) + + db = SessionLocal() + + test_user = User( + username="testuser", + email="test@example.com", + hashed_password=hash_password("test123"), + is_active=True + ) + db.add(test_user) + db.commit() + db.refresh(test_user) + + quota = Quota( + user_id=test_user.id, + daily_tokens=1000000, + monthly_tokens=10000000, + daily_requests=1000, + monthly_requests=10000 + ) + db.add(quota) + db.commit() + + api_key_record, raw_key = create_api_key(db, test_user.id, "test-key") + os.environ["TEST_API_KEY"] = raw_key + + admin_user = User( + username="admin", + email="admin@example.com", + hashed_password=hash_password("admin123"), + is_active=True, + is_admin=True, + ) + db.add(admin_user) + db.commit() + db.refresh(admin_user) + + admin_quota = Quota( + user_id=admin_user.id, + daily_tokens=10000000, + monthly_tokens=100000000, + daily_requests=10000, + monthly_requests=100000 + ) + db.add(admin_quota) + db.commit() + + _, admin_raw_key = create_api_key(db, admin_user.id, "admin-key") + os.environ["ADMIN_API_KEY"] = admin_raw_key + + db.close() + + return os.environ["TEST_API_KEY"], os.environ["ADMIN_API_KEY"] + +def teardown_test_db(): + """Clean up test database and environment.""" + Base.metadata.drop_all(bind=engine) + os.environ.pop("TEST_API_KEY", None) + os.environ.pop("ADMIN_API_KEY", None) + +@pytest.fixture(scope="function") +def test_client(): + setup_test_db() + client = TestClient(app, raise_server_exceptions=False) + yield client + teardown_test_db() + +def test_auth_middleware_missing_auth(test_client): + response = test_client.post("/api/generate", json={"model": "llama3", "prompt": "test"}) + assert response.status_code == 401 + +def test_auth_middleware_invalid_key(test_client): + response = test_client.post( + "/api/generate", + headers={"Authorization": "sk-invalid-key"}, + json={"model": "llama3", "prompt": "test"} + ) + assert response.status_code == 401 + +@patch("main.proxy_request", new_callable=AsyncMock) +def test_auth_middleware_valid_key(mock_proxy, test_client): + mock_proxy.return_value.status_code = 200 + mock_proxy.return_value.json = lambda: {"response": "success"} + mock_proxy.return_value.headers = {} + + response = test_client.post( + "/api/generate", + headers={"Authorization": os.environ.get("TEST_API_KEY", "")}, + json={"model": "llama3", "prompt": "test"} + ) + assert response.status_code == 200 diff --git a/backend/tests/test_quota.py b/backend/tests/test_quota.py new file mode 100644 index 0000000..f2ab9e6 --- /dev/null +++ b/backend/tests/test_quota.py @@ -0,0 +1,182 @@ +import pytest +import os +from datetime import datetime, timedelta + +os.environ.setdefault("OLLAMA_URL", "http://127.0.0.1:9999") + +from database import Base, engine, SessionLocal +from models import User, Quota, Usage +from crud import check_and_increment_quota, count_tokens, hash_password + + +def make_user_and_quota(db, daily_tokens=None, monthly_tokens=None, + daily_requests=None, monthly_requests=None): + user = User( + username="quotauser", + email="quota@example.com", + hashed_password=hash_password("pass"), + is_active=True, + ) + db.add(user) + db.commit() + db.refresh(user) + + quota = Quota( + user_id=user.id, + daily_tokens=daily_tokens, + monthly_tokens=monthly_tokens, + daily_requests=daily_requests, + monthly_requests=monthly_requests, + ) + db.add(quota) + db.commit() + + return user.id + + +@pytest.fixture +def db(): + Base.metadata.drop_all(bind=engine) + Base.metadata.create_all(bind=engine) + session = SessionLocal() + yield session + session.close() + Base.metadata.drop_all(bind=engine) + + +# --- count_tokens --- + +def test_count_tokens_empty(): + assert count_tokens("") == 0 + +def test_count_tokens_returns_int(): + assert isinstance(count_tokens("hello world"), int) + +def test_count_tokens_scales_with_length(): + assert count_tokens("a b c d e f g h") > count_tokens("a") + +def test_count_tokens_more_accurate_than_split(): + # tiktoken counts "don't" as 2 tokens, split counts it as 1 word + text = "don't do that" + assert count_tokens(text) >= len(text.split()) + + +# --- check_and_increment_quota --- + +def test_allowed_within_daily_token_limit(db): + user_id = make_user_and_quota(db, daily_tokens=1000) + assert check_and_increment_quota(db, user_id, tokens=100, requests=1) is True + +def test_denied_when_daily_tokens_exceeded(db): + user_id = make_user_and_quota(db, daily_tokens=50) + assert check_and_increment_quota(db, user_id, tokens=100, requests=1) is False + +def test_denied_when_monthly_tokens_exceeded(db): + user_id = make_user_and_quota(db, monthly_tokens=50) + assert check_and_increment_quota(db, user_id, tokens=100, requests=1) is False + +def test_denied_when_daily_requests_exceeded(db): + user_id = make_user_and_quota(db, daily_requests=1) + check_and_increment_quota(db, user_id, tokens=0, requests=1) + assert check_and_increment_quota(db, user_id, tokens=0, requests=1) is False + +def test_denied_when_monthly_requests_exceeded(db): + user_id = make_user_and_quota(db, monthly_requests=1) + check_and_increment_quota(db, user_id, tokens=0, requests=1) + assert check_and_increment_quota(db, user_id, tokens=0, requests=1) is False + +def test_increments_both_daily_and_monthly_counters(db): + user_id = make_user_and_quota(db, daily_tokens=1000, monthly_tokens=10000, + daily_requests=100, monthly_requests=1000) + check_and_increment_quota(db, user_id, tokens=50, requests=1) + + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + assert usage.tokens_used_today == 50 + assert usage.tokens_used_month == 50 + assert usage.requests_today == 1 + assert usage.requests_month == 1 + +def test_creates_usage_record_on_first_call(db): + user_id = make_user_and_quota(db, daily_tokens=1000) + assert db.query(Usage).filter(Usage.user_id == user_id).first() is None + + check_and_increment_quota(db, user_id, tokens=10, requests=1) + + assert db.query(Usage).filter(Usage.user_id == user_id).first() is not None + +def test_no_quota_allows_any_request(db): + user_id = make_user_and_quota(db) # all limits None + assert check_and_increment_quota(db, user_id, tokens=999999, requests=9999) is True + +def test_cumulative_usage_across_calls(db): + user_id = make_user_and_quota(db, daily_tokens=200) + check_and_increment_quota(db, user_id, tokens=100, requests=1) + check_and_increment_quota(db, user_id, tokens=99, requests=1) + # 199 used, 1 remaining – exactly 1 more token should pass + assert check_and_increment_quota(db, user_id, tokens=1, requests=1) is True + # Now 200 used – next request must fail + assert check_and_increment_quota(db, user_id, tokens=1, requests=1) is False + + +# --- Reset logic --- + +def test_daily_reset_restores_access(db): + user_id = make_user_and_quota(db, daily_tokens=100) + check_and_increment_quota(db, user_id, tokens=90, requests=1) + + # Backdate daily_reset_at to yesterday + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + usage.daily_reset_at = datetime.utcnow() - timedelta(days=1) + db.commit() + + # Should pass again after reset + assert check_and_increment_quota(db, user_id, tokens=90, requests=1) is True + + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + assert usage.tokens_used_today == 90 + +def test_daily_reset_does_not_affect_monthly_counter(db): + user_id = make_user_and_quota(db, daily_tokens=1000, monthly_tokens=10000) + check_and_increment_quota(db, user_id, tokens=50, requests=1) + + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + usage.daily_reset_at = datetime.utcnow() - timedelta(days=1) + db.commit() + + check_and_increment_quota(db, user_id, tokens=50, requests=1) + + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + assert usage.tokens_used_today == 50 + assert usage.tokens_used_month == 100 # cumulative across days + +def test_monthly_reset_restores_access(db): + user_id = make_user_and_quota(db, monthly_tokens=100) + check_and_increment_quota(db, user_id, tokens=90, requests=1) + + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + usage.monthly_reset_at = datetime.utcnow() - timedelta(days=32) + db.commit() + + assert check_and_increment_quota(db, user_id, tokens=90, requests=1) is True + + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + assert usage.tokens_used_month == 90 + +def test_failed_quota_check_still_commits_reset(db): + user_id = make_user_and_quota(db, daily_tokens=100, daily_requests=5) + check_and_increment_quota(db, user_id, tokens=80, requests=1) + + # Backdate so a reset fires, but the new request still exceeds the limit + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + usage.daily_reset_at = datetime.utcnow() - timedelta(days=1) + usage.tokens_used_today = 80 + db.commit() + + # After reset tokens_used_today = 0; 200 tokens exceeds 100 limit + result = check_and_increment_quota(db, user_id, tokens=200, requests=1) + assert result is False + + # Reset must still be persisted so the next request sees fresh counters + db.expire_all() + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + assert usage.tokens_used_today == 0