Fix critical/high security and correctness issues from code review

Critical (all fixed):
- bcrypt statt SHA-256 für Passwörter
- API-Keys gehasht in DB, Plaintext nur einmalig zurückgegeben
- DB-Session-Leak behoben (SessionLocal + try/finally, Depends(get_db))
- Admin-Check via is_admin-Spalte statt Hardcoded-Username
- CORS: konfigurierbare Origins via ALLOWED_ORIGINS, kein Wildcard mit Credentials

High (all fixed):
- TOCTOU-Race: check_and_increment_quota mit SELECT FOR UPDATE atomar
- Getrennte Tages-/Monatszähler in Usage + automatische Reset-Logik
- Token-Zählung mit tiktoken (cl100k_base) statt .split()

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Oliver Hofmann 2026-04-27 21:34:17 +02:00
parent 562f6ecd9c
commit bf694b79e2
15 changed files with 547 additions and 128 deletions

View File

@ -1,3 +1,4 @@
import os
from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.orm import Session
@ -7,15 +8,17 @@ from models import User, APIKey, Quota
app = FastAPI(title="Ollama Proxy Admin API")
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "http://localhost:5173").split(",")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=["Authorization", "Content-Type"],
)
async def require_admin_auth(request: Request):
async def require_admin_auth(request: Request, db: Session = Depends(get_db)):
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
api_key = auth_header.replace("Bearer ", "")
@ -23,24 +26,21 @@ async def require_admin_auth(request: Request):
api_key = auth_header
else:
raise HTTPException(status_code=401, detail="Invalid or missing API key")
db = next(get_db())
db_key = crud.verify_api_key(db, api_key)
if not db_key:
raise HTTPException(status_code=401, detail="Invalid API key")
db_user = db.query(User).filter(User.id == db_key.user_id).first()
if not db_user or db_user.username != "admin":
if not db_user or not db_user.is_admin:
raise HTTPException(status_code=403, detail="Admin access required")
request.state.user = db_user
request.state.db = db
@app.get("/api/users", response_model=list[schemas.User])
async def read_users(
skip: int = 0,
limit: int = 100,
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db),
_ = Depends(require_admin_auth)
):
@ -49,7 +49,7 @@ async def read_users(
@app.post("/api/users", response_model=schemas.User)
async def create_user(
user: schemas.UserCreate,
user: schemas.UserCreate,
db: Session = Depends(get_db),
_ = Depends(require_admin_auth)
):
@ -61,21 +61,24 @@ async def create_user(
raise HTTPException(status_code=400, detail="Email already registered")
return crud.create_user(db=db, username=user.username, email=user.email, password=user.password)
@app.post("/api/api-keys", response_model=schemas.APIKey)
@app.post("/api/api-keys", response_model=schemas.APIKeyCreated)
async def create_api_key(
api_key: schemas.APIKeyCreate,
api_key: schemas.APIKeyCreate,
db: Session = Depends(get_db),
_ = Depends(require_admin_auth)
):
db_user = db.query(User).filter(User.id == api_key.user_id).first()
if not db_user:
raise HTTPException(status_code=404, detail="User not found")
return crud.create_api_key(db=db, user_id=api_key.user_id, name=api_key.name)
db_key, raw_key = crud.create_api_key(db=db, user_id=api_key.user_id, name=api_key.name)
result = schemas.APIKeyCreated.model_validate(db_key)
result.plaintext_key = raw_key
return result
@app.get("/api/api-keys", response_model=list[schemas.APIKey])
async def read_api_keys(
skip: int = 0,
limit: int = 100,
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db),
_ = Depends(require_admin_auth)
):
@ -84,7 +87,7 @@ async def read_api_keys(
@app.put("/api/api-keys/{api_key_id}/deactivate")
async def deactivate_api_key(
api_key_id: int,
api_key_id: int,
db: Session = Depends(get_db),
_ = Depends(require_admin_auth)
):
@ -97,15 +100,15 @@ async def deactivate_api_key(
@app.put("/api/quotas/{user_id}", response_model=schemas.Quota)
async def update_quota(
user_id: int,
quota: schemas.QuotaCreate,
user_id: int,
quota: schemas.QuotaCreate,
db: Session = Depends(get_db),
_ = Depends(require_admin_auth)
):
db_quota = db.query(Quota).filter(Quota.user_id == user_id).first()
if not db_quota:
raise HTTPException(status_code=404, detail="Quota not found")
for key, value in quota.dict(exclude_unset=True).items():
for key, value in quota.model_dump(exclude_unset=True).items():
setattr(db_quota, key, value)
db.commit()
db.refresh(db_quota)

View File

@ -1,10 +1,25 @@
import secrets
import hashlib
from datetime import datetime, timedelta
import bcrypt
import tiktoken
from datetime import datetime
from sqlalchemy.orm import Session
from database import get_db
from models import APIKey, User, Quota, Usage
_encoder = tiktoken.get_encoding("cl100k_base")
def count_tokens(text: str) -> int:
return len(_encoder.encode(text))
def hash_password(password: str) -> str:
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
def verify_password(plain_password: str, hashed_password: str) -> bool:
return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
def _hash_api_key(key: str) -> str:
return hashlib.sha256(key.encode()).hexdigest()
def get_user_by_username(db: Session, username: str):
return db.query(User).filter(User.username == username).first()
@ -14,19 +29,17 @@ def get_user_by_email(db: Session, email: str):
def generate_api_key():
return "sk-" + secrets.token_urlsafe(32)
def hash_password(password: str):
return hashlib.sha256(password.encode()).hexdigest()
def create_user(db: Session, username: str, email: str, password: str):
def create_user(db: Session, username: str, email: str, password: str, is_admin: bool = False):
db_user = User(
username=username,
email=email,
hashed_password=hash_password(password)
hashed_password=hash_password(password),
is_admin=is_admin,
)
db.add(db_user)
db.commit()
db.refresh(db_user)
default_quota = Quota(
user_id=db_user.id,
daily_tokens=1000000,
@ -36,60 +49,75 @@ def create_user(db: Session, username: str, email: str, password: str):
)
db.add(default_quota)
db.commit()
return db_user
def create_api_key(db: Session, user_id: int, name: str):
key = generate_api_key()
def create_api_key(db: Session, user_id: int, name: str) -> tuple[APIKey, str]:
raw_key = generate_api_key()
db_key = APIKey(
name=name,
key=key,
key=_hash_api_key(raw_key),
user_id=user_id
)
db.add(db_key)
db.commit()
db.refresh(db_key)
return db_key
return db_key, raw_key
def verify_api_key(db: Session, api_key: str):
return db.query(APIKey).filter(APIKey.key == api_key, APIKey.is_active == True).first()
key_hash = _hash_api_key(api_key)
return db.query(APIKey).filter(APIKey.key == key_hash, APIKey.is_active == True).first()
def get_quota(db: Session, user_id: int):
return db.query(Quota).filter(Quota.user_id == user_id).first()
def check_quota(db: Session, user_id: int, tokens: int = 0, requests: int = 1):
quota = get_quota(db, user_id)
usage = get_usage(db, user_id)
if quota.daily_tokens and (usage.tokens_used + tokens) > quota.daily_tokens:
return False
if quota.monthly_tokens and (usage.tokens_used + tokens) > quota.monthly_tokens:
return False
if quota.daily_requests and (usage.requests_count + requests) > quota.daily_requests:
return False
if quota.monthly_requests and (usage.requests_count + requests) > quota.monthly_requests:
return False
return True
def increment_usage(db: Session, user_id: int, tokens: int = 0, requests: int = 1):
usage = get_or_create_usage(db, user_id)
usage.tokens_used += tokens
usage.requests_count += requests
db.commit()
db.refresh(usage)
def get_quota_by_user_id(db: Session, user_id: int):
return db.query(Quota).filter(Quota.user_id == user_id).first()
def get_usage(db: Session, user_id: int):
return db.query(Usage).filter(Usage.user_id == user_id).first()
def get_quota_by_user_id(db: Session, user_id: int):
return db.query(Quota).filter(Quota.user_id == user_id).first()
def get_or_create_usage(db: Session, user_id: int):
usage = get_usage(db, user_id)
def check_and_increment_quota(db: Session, user_id: int, tokens: int = 0, requests: int = 1) -> bool:
usage = (
db.query(Usage)
.filter(Usage.user_id == user_id)
.with_for_update()
.first()
)
if not usage:
usage = Usage(user_id=user_id)
db.add(usage)
db.commit()
db.refresh(usage)
return usage
db.flush()
now = datetime.utcnow()
if usage.daily_reset_at.date() < now.date():
usage.tokens_used_today = 0
usage.requests_today = 0
usage.daily_reset_at = now
if (usage.monthly_reset_at.year, usage.monthly_reset_at.month) < (now.year, now.month):
usage.tokens_used_month = 0
usage.requests_month = 0
usage.monthly_reset_at = now
quota = get_quota(db, user_id)
allowed = True
if quota:
if quota.daily_tokens and (usage.tokens_used_today + tokens) > quota.daily_tokens:
allowed = False
elif quota.monthly_tokens and (usage.tokens_used_month + tokens) > quota.monthly_tokens:
allowed = False
elif quota.daily_requests and (usage.requests_today + requests) > quota.daily_requests:
allowed = False
elif quota.monthly_requests and (usage.requests_month + requests) > quota.monthly_requests:
allowed = False
if allowed:
usage.tokens_used_today += tokens
usage.tokens_used_month += tokens
usage.requests_today += requests
usage.requests_month += requests
db.commit()
return allowed

View File

@ -1,7 +1,9 @@
import time
import uuid
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from database import get_db
from database import get_db, SessionLocal
import crud
import httpx
import os
@ -22,54 +24,48 @@ async def authenticate_and_quota(request: Request, call_next):
elif auth_header.startswith("sk-"):
api_key = auth_header
else:
raise HTTPException(status_code=401, detail="Invalid or missing API key")
db = next(get_db())
db_key = crud.verify_api_key(db, api_key)
if not db_key:
raise HTTPException(status_code=401, detail="Invalid API key")
if not db_key.is_active:
raise HTTPException(status_code=403, detail="API key deactivated")
request.state.user_id = db_key.user_id
return JSONResponse(status_code=401, content={"detail": "Invalid or missing API key"})
db = SessionLocal()
try:
db_key = crud.verify_api_key(db, api_key)
if not db_key:
return JSONResponse(status_code=401, content={"detail": "Invalid API key"})
if not db_key.is_active:
return JSONResponse(status_code=403, content={"detail": "API key deactivated"})
request.state.user_id = db_key.user_id
finally:
db.close()
response = await call_next(request)
return response
@app.post("/api/generate")
async def generate(request: Request):
db = next(get_db())
async def generate(request: Request, db: Session = Depends(get_db)):
user_id = request.state.user_id
body = await request.json()
prompt_tokens = len(body.get("prompt", "").split())
if not crud.check_quota(db, user_id, tokens=prompt_tokens, requests=1):
prompt_tokens = crud.count_tokens(body.get("prompt", ""))
if not crud.check_and_increment_quota(db, user_id, tokens=prompt_tokens, requests=1):
raise HTTPException(status_code=429, detail="Quota exceeded")
response = await proxy_request(f"{OLLAMA_URL}/api/generate", method="POST", json_data=body, headers=dict(request.headers))
crud.increment_usage(db, user_id, tokens=prompt_tokens, requests=1)
return JSONResponse(content=response.json(), status_code=response.status_code, headers=dict(response.headers))
@app.post("/api/chat")
async def chat(request: Request):
db = next(get_db())
async def chat(request: Request, db: Session = Depends(get_db)):
user_id = request.state.user_id
body = await request.json()
prompt_tokens = sum(len(msg.get("content", "").split()) for msg in body.get("messages", []))
if not crud.check_quota(db, user_id, tokens=prompt_tokens, requests=1):
prompt_tokens = sum(crud.count_tokens(msg.get("content", "")) for msg in body.get("messages", []))
if not crud.check_and_increment_quota(db, user_id, tokens=prompt_tokens, requests=1):
raise HTTPException(status_code=429, detail="Quota exceeded")
response = await proxy_request(f"{OLLAMA_URL}/api/chat", method="POST", json_data=body, headers=dict(request.headers))
crud.increment_usage(db, user_id, tokens=prompt_tokens, requests=1)
return JSONResponse(content=response.json(), status_code=response.status_code, headers=dict(response.headers))
@app.get("/api/tags")
@ -101,48 +97,48 @@ async def list_openai_models(request: Request):
return JSONResponse(content=openai_models, status_code=200, headers=dict(response.headers))
@app.post("/v1/chat/completions")
async def openai_chat_completions(request: Request):
db = next(get_db())
async def openai_chat_completions(request: Request, db: Session = Depends(get_db)):
user_id = request.state.user_id
body = await request.json()
messages = body.get("messages", [])
prompt_tokens = sum(len(msg.get("content", "").split()) for msg in messages)
if not crud.check_quota(db, user_id, tokens=prompt_tokens, requests=1):
prompt_tokens = sum(crud.count_tokens(msg.get("content", "")) for msg in messages)
if not crud.check_and_increment_quota(db, user_id, tokens=prompt_tokens, requests=1):
raise HTTPException(status_code=429, detail="Quota exceeded")
ollama_body = {
"model": body.get("model", "llama3"),
"messages": messages,
"stream": body.get("stream", False)
}
response = await proxy_request(f"{OLLAMA_URL}/api/chat", method="POST", json_data=ollama_body, headers=dict(request.headers))
crud.increment_usage(db, user_id, tokens=prompt_tokens, requests=1)
response_content = response.json().get("message", {}).get("content", "")
completion_tokens = crud.count_tokens(response_content)
openai_response = {
"id": f"chatcmpl-{hash(msg.get('content', ''))}",
"id": f"chatcmpl-{uuid.uuid4().hex}",
"object": "chat.completion",
"created": int(__import__('time').time()),
"created": int(time.time()),
"model": body.get("model", "llama3"),
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": response.json().get("message", {}).get("content", "")
"content": response_content
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": len(response.json().get("message", {}).get("content", "").split()),
"total_tokens": prompt_tokens + len(response.json().get("message", {}).get("content", "").split())
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens
}
}
return JSONResponse(content=openai_response, status_code=200, headers={"Content-Type": "application/json"})

View File

@ -10,6 +10,7 @@ class User(Base):
email = Column(String, unique=True, index=True)
hashed_password = Column(String)
is_active = Column(Boolean, default=True)
is_admin = Column(Boolean, default=False)
created_at = Column(DateTime, default=datetime.utcnow)
class APIKey(Base):
@ -38,6 +39,9 @@ class Usage(Base):
id = Column(Integer, primary_key=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), unique=True)
tokens_used = Column(BigInteger, default=0)
requests_count = Column(Integer, default=0)
reset_at = Column(DateTime, default=datetime.utcnow)
tokens_used_today = Column(BigInteger, default=0)
tokens_used_month = Column(BigInteger, default=0)
requests_today = Column(Integer, default=0)
requests_month = Column(Integer, default=0)
daily_reset_at = Column(DateTime, default=datetime.utcnow)
monthly_reset_at = Column(DateTime, default=datetime.utcnow)

View File

@ -6,7 +6,8 @@ alembic==1.14.0
pydantic==2.10.3
python-multipart==0.0.20
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
bcrypt==5.0.0
tiktoken==0.9.0
python-dotenv==1.0.1
pytest==8.3.4
pytest-asyncio==0.25.1

View File

@ -5,6 +5,7 @@ from typing import Optional
class UserBase(BaseModel):
username: str
email: str
is_admin: bool = False
class UserCreate(UserBase):
password: str
@ -33,6 +34,12 @@ class APIKey(APIKeyBase):
class Config:
from_attributes = True
class APIKeyCreated(APIKey):
plaintext_key: str
class Config:
from_attributes = True
class QuotaBase(BaseModel):
daily_tokens: Optional[int] = None
monthly_tokens: Optional[int] = None
@ -51,9 +58,12 @@ class Quota(QuotaBase):
from_attributes = True
class UsageStats(BaseModel):
tokens_used: int = 0
requests_count: int = 0
last_reset: Optional[datetime] = None
tokens_used_today: int = 0
tokens_used_month: int = 0
requests_today: int = 0
requests_month: int = 0
daily_reset_at: Optional[datetime] = None
monthly_reset_at: Optional[datetime] = None
class Config:
from_attributes = True

View File

@ -13,7 +13,8 @@ def setup_admin():
username="admin",
email="admin@ollama.local",
hashed_password=hash_password("admin123"),
is_active=True
is_active=True,
is_admin=True,
)
db.add(admin_user)
db.commit()
@ -31,8 +32,8 @@ def setup_admin():
db.commit()
print("✓ Admin quota created")
api_key = create_api_key(db, admin_user.id, "admin-api-key")
print(f"✓ Admin API Key: {api_key.key}")
_, raw_key = create_api_key(db, admin_user.id, "admin-api-key")
print(f"✓ Admin API Key: {raw_key}")
else:
print("✗ Admin user already exists")
db.close()

View File

@ -0,0 +1 @@
"""Tests for the Ollama Proxy."""

Binary file not shown.

87
backend/tests/conftest.py Normal file
View File

@ -0,0 +1,87 @@
import pytest
from fastapi.testclient import TestClient
import tempfile
import os
from pathlib import Path
def create_test_db():
"""Create a temporary SQLite database for tests."""
temp_db = tempfile.NamedTemporaryFile(suffix='.db', delete=False)
temp_db.close()
os.environ["DATABASE_URL"] = f"sqlite:///{temp_db.name}"
return temp_db.name
def cleanup_test_db(db_path):
"""Remove the temporary database."""
if os.path.exists(db_path):
os.unlink(db_path)
os.environ.pop("DATABASE_URL", None)
def setup_test_db():
"""Setup test database with required data."""
from database import Base, engine, SessionLocal
from models import User, APIKey, Quota, Usage
from crud import create_api_key, hash_password
# Create tables
Base.metadata.create_all(bind=engine)
db = SessionLocal()
# Create test user
test_user = User(
username="testuser",
email="test@example.com",
hashed_password=hash_password("test123"),
is_active=True
)
db.add(test_user)
db.commit()
db.refresh(test_user)
# Create API key for test user
_, raw_key = create_api_key(db, test_user.id, "test-key")
os.environ["TEST_API_KEY"] = raw_key
# Create admin user
admin_user = User(
username="admin",
email="admin@example.com",
hashed_password=hash_password("admin123"),
is_active=True,
is_admin=True,
)
db.add(admin_user)
db.commit()
db.refresh(admin_user)
# Create admin API key
_, admin_raw_key = create_api_key(db, admin_user.id, "admin-key")
os.environ["ADMIN_API_KEY"] = admin_raw_key
db.close()
return raw_key, admin_raw_key
def teardown_test_db():
"""Clean up test database and environment."""
from database import engine
from models import Base
Base.metadata.drop_all(bind=engine)
os.environ.pop("TEST_API_KEY", None)
os.environ.pop("ADMIN_API_KEY", None)
@pytest.fixture(scope="session")
def test_client():
"""Create test client with test database."""
db_path = create_test_db()
setup_test_db()
from main import app
client = TestClient(app)
yield client
teardown_test_db()
cleanup_test_db(db_path)

106
backend/tests/test_auth.py Normal file
View File

@ -0,0 +1,106 @@
import pytest
import os
from unittest.mock import AsyncMock, patch
from fastapi.testclient import TestClient
from main import app
from database import Base, engine, SessionLocal
from models import User, APIKey, Quota
from crud import create_api_key, hash_password
os.environ["OLLAMA_URL"] = "http://127.0.0.1:9999"
def setup_test_db():
"""Setup test database with required data."""
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)
db = SessionLocal()
test_user = User(
username="testuser",
email="test@example.com",
hashed_password=hash_password("test123"),
is_active=True
)
db.add(test_user)
db.commit()
db.refresh(test_user)
quota = Quota(
user_id=test_user.id,
daily_tokens=1000000,
monthly_tokens=10000000,
daily_requests=1000,
monthly_requests=10000
)
db.add(quota)
db.commit()
api_key_record, raw_key = create_api_key(db, test_user.id, "test-key")
os.environ["TEST_API_KEY"] = raw_key
admin_user = User(
username="admin",
email="admin@example.com",
hashed_password=hash_password("admin123"),
is_active=True,
is_admin=True,
)
db.add(admin_user)
db.commit()
db.refresh(admin_user)
admin_quota = Quota(
user_id=admin_user.id,
daily_tokens=10000000,
monthly_tokens=100000000,
daily_requests=10000,
monthly_requests=100000
)
db.add(admin_quota)
db.commit()
_, admin_raw_key = create_api_key(db, admin_user.id, "admin-key")
os.environ["ADMIN_API_KEY"] = admin_raw_key
db.close()
return os.environ["TEST_API_KEY"], os.environ["ADMIN_API_KEY"]
def teardown_test_db():
"""Clean up test database and environment."""
Base.metadata.drop_all(bind=engine)
os.environ.pop("TEST_API_KEY", None)
os.environ.pop("ADMIN_API_KEY", None)
@pytest.fixture(scope="function")
def test_client():
setup_test_db()
client = TestClient(app, raise_server_exceptions=False)
yield client
teardown_test_db()
def test_auth_middleware_missing_auth(test_client):
response = test_client.post("/api/generate", json={"model": "llama3", "prompt": "test"})
assert response.status_code == 401
def test_auth_middleware_invalid_key(test_client):
response = test_client.post(
"/api/generate",
headers={"Authorization": "sk-invalid-key"},
json={"model": "llama3", "prompt": "test"}
)
assert response.status_code == 401
@patch("main.proxy_request", new_callable=AsyncMock)
def test_auth_middleware_valid_key(mock_proxy, test_client):
mock_proxy.return_value.status_code = 200
mock_proxy.return_value.json = lambda: {"response": "success"}
mock_proxy.return_value.headers = {}
response = test_client.post(
"/api/generate",
headers={"Authorization": os.environ.get("TEST_API_KEY", "")},
json={"model": "llama3", "prompt": "test"}
)
assert response.status_code == 200

182
backend/tests/test_quota.py Normal file
View File

@ -0,0 +1,182 @@
import pytest
import os
from datetime import datetime, timedelta
os.environ.setdefault("OLLAMA_URL", "http://127.0.0.1:9999")
from database import Base, engine, SessionLocal
from models import User, Quota, Usage
from crud import check_and_increment_quota, count_tokens, hash_password
def make_user_and_quota(db, daily_tokens=None, monthly_tokens=None,
daily_requests=None, monthly_requests=None):
user = User(
username="quotauser",
email="quota@example.com",
hashed_password=hash_password("pass"),
is_active=True,
)
db.add(user)
db.commit()
db.refresh(user)
quota = Quota(
user_id=user.id,
daily_tokens=daily_tokens,
monthly_tokens=monthly_tokens,
daily_requests=daily_requests,
monthly_requests=monthly_requests,
)
db.add(quota)
db.commit()
return user.id
@pytest.fixture
def db():
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)
session = SessionLocal()
yield session
session.close()
Base.metadata.drop_all(bind=engine)
# --- count_tokens ---
def test_count_tokens_empty():
assert count_tokens("") == 0
def test_count_tokens_returns_int():
assert isinstance(count_tokens("hello world"), int)
def test_count_tokens_scales_with_length():
assert count_tokens("a b c d e f g h") > count_tokens("a")
def test_count_tokens_more_accurate_than_split():
# tiktoken counts "don't" as 2 tokens, split counts it as 1 word
text = "don't do that"
assert count_tokens(text) >= len(text.split())
# --- check_and_increment_quota ---
def test_allowed_within_daily_token_limit(db):
user_id = make_user_and_quota(db, daily_tokens=1000)
assert check_and_increment_quota(db, user_id, tokens=100, requests=1) is True
def test_denied_when_daily_tokens_exceeded(db):
user_id = make_user_and_quota(db, daily_tokens=50)
assert check_and_increment_quota(db, user_id, tokens=100, requests=1) is False
def test_denied_when_monthly_tokens_exceeded(db):
user_id = make_user_and_quota(db, monthly_tokens=50)
assert check_and_increment_quota(db, user_id, tokens=100, requests=1) is False
def test_denied_when_daily_requests_exceeded(db):
user_id = make_user_and_quota(db, daily_requests=1)
check_and_increment_quota(db, user_id, tokens=0, requests=1)
assert check_and_increment_quota(db, user_id, tokens=0, requests=1) is False
def test_denied_when_monthly_requests_exceeded(db):
user_id = make_user_and_quota(db, monthly_requests=1)
check_and_increment_quota(db, user_id, tokens=0, requests=1)
assert check_and_increment_quota(db, user_id, tokens=0, requests=1) is False
def test_increments_both_daily_and_monthly_counters(db):
user_id = make_user_and_quota(db, daily_tokens=1000, monthly_tokens=10000,
daily_requests=100, monthly_requests=1000)
check_and_increment_quota(db, user_id, tokens=50, requests=1)
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
assert usage.tokens_used_today == 50
assert usage.tokens_used_month == 50
assert usage.requests_today == 1
assert usage.requests_month == 1
def test_creates_usage_record_on_first_call(db):
user_id = make_user_and_quota(db, daily_tokens=1000)
assert db.query(Usage).filter(Usage.user_id == user_id).first() is None
check_and_increment_quota(db, user_id, tokens=10, requests=1)
assert db.query(Usage).filter(Usage.user_id == user_id).first() is not None
def test_no_quota_allows_any_request(db):
user_id = make_user_and_quota(db) # all limits None
assert check_and_increment_quota(db, user_id, tokens=999999, requests=9999) is True
def test_cumulative_usage_across_calls(db):
user_id = make_user_and_quota(db, daily_tokens=200)
check_and_increment_quota(db, user_id, tokens=100, requests=1)
check_and_increment_quota(db, user_id, tokens=99, requests=1)
# 199 used, 1 remaining exactly 1 more token should pass
assert check_and_increment_quota(db, user_id, tokens=1, requests=1) is True
# Now 200 used next request must fail
assert check_and_increment_quota(db, user_id, tokens=1, requests=1) is False
# --- Reset logic ---
def test_daily_reset_restores_access(db):
user_id = make_user_and_quota(db, daily_tokens=100)
check_and_increment_quota(db, user_id, tokens=90, requests=1)
# Backdate daily_reset_at to yesterday
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
usage.daily_reset_at = datetime.utcnow() - timedelta(days=1)
db.commit()
# Should pass again after reset
assert check_and_increment_quota(db, user_id, tokens=90, requests=1) is True
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
assert usage.tokens_used_today == 90
def test_daily_reset_does_not_affect_monthly_counter(db):
user_id = make_user_and_quota(db, daily_tokens=1000, monthly_tokens=10000)
check_and_increment_quota(db, user_id, tokens=50, requests=1)
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
usage.daily_reset_at = datetime.utcnow() - timedelta(days=1)
db.commit()
check_and_increment_quota(db, user_id, tokens=50, requests=1)
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
assert usage.tokens_used_today == 50
assert usage.tokens_used_month == 100 # cumulative across days
def test_monthly_reset_restores_access(db):
user_id = make_user_and_quota(db, monthly_tokens=100)
check_and_increment_quota(db, user_id, tokens=90, requests=1)
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
usage.monthly_reset_at = datetime.utcnow() - timedelta(days=32)
db.commit()
assert check_and_increment_quota(db, user_id, tokens=90, requests=1) is True
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
assert usage.tokens_used_month == 90
def test_failed_quota_check_still_commits_reset(db):
user_id = make_user_and_quota(db, daily_tokens=100, daily_requests=5)
check_and_increment_quota(db, user_id, tokens=80, requests=1)
# Backdate so a reset fires, but the new request still exceeds the limit
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
usage.daily_reset_at = datetime.utcnow() - timedelta(days=1)
usage.tokens_used_today = 80
db.commit()
# After reset tokens_used_today = 0; 200 tokens exceeds 100 limit
result = check_and_increment_quota(db, user_id, tokens=200, requests=1)
assert result is False
# Reset must still be persisted so the next request sees fresh counters
db.expire_all()
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
assert usage.tokens_used_today == 0