From bf694b79e209a352a45150591250baf9c308fa02 Mon Sep 17 00:00:00 2001 From: Oliver Hofmann Date: Mon, 27 Apr 2026 21:34:17 +0200 Subject: [PATCH] Fix critical/high security and correctness issues from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Critical (all fixed): - bcrypt statt SHA-256 für Passwörter - API-Keys gehasht in DB, Plaintext nur einmalig zurückgegeben - DB-Session-Leak behoben (SessionLocal + try/finally, Depends(get_db)) - Admin-Check via is_admin-Spalte statt Hardcoded-Username - CORS: konfigurierbare Origins via ALLOWED_ORIGINS, kein Wildcard mit Credentials High (all fixed): - TOCTOU-Race: check_and_increment_quota mit SELECT FOR UPDATE atomar - Getrennte Tages-/Monatszähler in Usage + automatische Reset-Logik - Token-Zählung mit tiktoken (cl100k_base) statt .split() Co-Authored-By: Claude Sonnet 4.6 --- backend/admin.py | 49 ++--- backend/crud.py | 114 ++++++----- backend/main.py | 100 +++++----- backend/models.py | 10 +- backend/requirements.txt | 3 +- backend/schemas.py | 16 +- backend/setup_admin.py | 7 +- backend/tests/__init__.py | 1 + .../__pycache__/__init__.cpython-312.pyc | Bin 0 -> 213 bytes .../conftest.cpython-312-pytest-8.3.4.pyc | Bin 0 -> 3894 bytes .../test_auth.cpython-312-pytest-8.3.4.pyc | Bin 0 -> 6979 bytes .../test_quota.cpython-312-pytest-8.3.4.pyc | Bin 0 -> 33897 bytes backend/tests/conftest.py | 87 +++++++++ backend/tests/test_auth.py | 106 ++++++++++ backend/tests/test_quota.py | 182 ++++++++++++++++++ 15 files changed, 547 insertions(+), 128 deletions(-) create mode 100644 backend/tests/__init__.py create mode 100644 backend/tests/__pycache__/__init__.cpython-312.pyc create mode 100644 backend/tests/__pycache__/conftest.cpython-312-pytest-8.3.4.pyc create mode 100644 backend/tests/__pycache__/test_auth.cpython-312-pytest-8.3.4.pyc create mode 100644 backend/tests/__pycache__/test_quota.cpython-312-pytest-8.3.4.pyc create mode 100644 backend/tests/conftest.py create mode 100644 backend/tests/test_auth.py create mode 100644 backend/tests/test_quota.py diff --git a/backend/admin.py b/backend/admin.py index 100e5c7..9bb53f8 100644 --- a/backend/admin.py +++ b/backend/admin.py @@ -1,3 +1,4 @@ +import os from fastapi import FastAPI, Depends, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from sqlalchemy.orm import Session @@ -7,15 +8,17 @@ from models import User, APIKey, Quota app = FastAPI(title="Ollama Proxy Admin API") +ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "http://localhost:5173").split(",") + app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=ALLOWED_ORIGINS, allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_methods=["GET", "POST", "PUT", "DELETE"], + allow_headers=["Authorization", "Content-Type"], ) -async def require_admin_auth(request: Request): +async def require_admin_auth(request: Request, db: Session = Depends(get_db)): auth_header = request.headers.get("Authorization", "") if auth_header.startswith("Bearer "): api_key = auth_header.replace("Bearer ", "") @@ -23,24 +26,21 @@ async def require_admin_auth(request: Request): api_key = auth_header else: raise HTTPException(status_code=401, detail="Invalid or missing API key") - - db = next(get_db()) + db_key = crud.verify_api_key(db, api_key) - if not db_key: raise HTTPException(status_code=401, detail="Invalid API key") - + db_user = db.query(User).filter(User.id == db_key.user_id).first() - if not db_user or db_user.username != "admin": + if not db_user or not db_user.is_admin: raise HTTPException(status_code=403, detail="Admin access required") - + request.state.user = db_user - request.state.db = db @app.get("/api/users", response_model=list[schemas.User]) async def read_users( - skip: int = 0, - limit: int = 100, + skip: int = 0, + limit: int = 100, db: Session = Depends(get_db), _ = Depends(require_admin_auth) ): @@ -49,7 +49,7 @@ async def read_users( @app.post("/api/users", response_model=schemas.User) async def create_user( - user: schemas.UserCreate, + user: schemas.UserCreate, db: Session = Depends(get_db), _ = Depends(require_admin_auth) ): @@ -61,21 +61,24 @@ async def create_user( raise HTTPException(status_code=400, detail="Email already registered") return crud.create_user(db=db, username=user.username, email=user.email, password=user.password) -@app.post("/api/api-keys", response_model=schemas.APIKey) +@app.post("/api/api-keys", response_model=schemas.APIKeyCreated) async def create_api_key( - api_key: schemas.APIKeyCreate, + api_key: schemas.APIKeyCreate, db: Session = Depends(get_db), _ = Depends(require_admin_auth) ): db_user = db.query(User).filter(User.id == api_key.user_id).first() if not db_user: raise HTTPException(status_code=404, detail="User not found") - return crud.create_api_key(db=db, user_id=api_key.user_id, name=api_key.name) + db_key, raw_key = crud.create_api_key(db=db, user_id=api_key.user_id, name=api_key.name) + result = schemas.APIKeyCreated.model_validate(db_key) + result.plaintext_key = raw_key + return result @app.get("/api/api-keys", response_model=list[schemas.APIKey]) async def read_api_keys( - skip: int = 0, - limit: int = 100, + skip: int = 0, + limit: int = 100, db: Session = Depends(get_db), _ = Depends(require_admin_auth) ): @@ -84,7 +87,7 @@ async def read_api_keys( @app.put("/api/api-keys/{api_key_id}/deactivate") async def deactivate_api_key( - api_key_id: int, + api_key_id: int, db: Session = Depends(get_db), _ = Depends(require_admin_auth) ): @@ -97,15 +100,15 @@ async def deactivate_api_key( @app.put("/api/quotas/{user_id}", response_model=schemas.Quota) async def update_quota( - user_id: int, - quota: schemas.QuotaCreate, + user_id: int, + quota: schemas.QuotaCreate, db: Session = Depends(get_db), _ = Depends(require_admin_auth) ): db_quota = db.query(Quota).filter(Quota.user_id == user_id).first() if not db_quota: raise HTTPException(status_code=404, detail="Quota not found") - for key, value in quota.dict(exclude_unset=True).items(): + for key, value in quota.model_dump(exclude_unset=True).items(): setattr(db_quota, key, value) db.commit() db.refresh(db_quota) diff --git a/backend/crud.py b/backend/crud.py index 68650f8..fb17afe 100644 --- a/backend/crud.py +++ b/backend/crud.py @@ -1,10 +1,25 @@ import secrets import hashlib -from datetime import datetime, timedelta +import bcrypt +import tiktoken +from datetime import datetime from sqlalchemy.orm import Session -from database import get_db from models import APIKey, User, Quota, Usage +_encoder = tiktoken.get_encoding("cl100k_base") + +def count_tokens(text: str) -> int: + return len(_encoder.encode(text)) + +def hash_password(password: str) -> str: + return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() + +def verify_password(plain_password: str, hashed_password: str) -> bool: + return bcrypt.checkpw(plain_password.encode(), hashed_password.encode()) + +def _hash_api_key(key: str) -> str: + return hashlib.sha256(key.encode()).hexdigest() + def get_user_by_username(db: Session, username: str): return db.query(User).filter(User.username == username).first() @@ -14,19 +29,17 @@ def get_user_by_email(db: Session, email: str): def generate_api_key(): return "sk-" + secrets.token_urlsafe(32) -def hash_password(password: str): - return hashlib.sha256(password.encode()).hexdigest() - -def create_user(db: Session, username: str, email: str, password: str): +def create_user(db: Session, username: str, email: str, password: str, is_admin: bool = False): db_user = User( username=username, email=email, - hashed_password=hash_password(password) + hashed_password=hash_password(password), + is_admin=is_admin, ) db.add(db_user) db.commit() db.refresh(db_user) - + default_quota = Quota( user_id=db_user.id, daily_tokens=1000000, @@ -36,60 +49,75 @@ def create_user(db: Session, username: str, email: str, password: str): ) db.add(default_quota) db.commit() - + return db_user -def create_api_key(db: Session, user_id: int, name: str): - key = generate_api_key() +def create_api_key(db: Session, user_id: int, name: str) -> tuple[APIKey, str]: + raw_key = generate_api_key() db_key = APIKey( name=name, - key=key, + key=_hash_api_key(raw_key), user_id=user_id ) db.add(db_key) db.commit() db.refresh(db_key) - return db_key + return db_key, raw_key def verify_api_key(db: Session, api_key: str): - return db.query(APIKey).filter(APIKey.key == api_key, APIKey.is_active == True).first() + key_hash = _hash_api_key(api_key) + return db.query(APIKey).filter(APIKey.key == key_hash, APIKey.is_active == True).first() def get_quota(db: Session, user_id: int): return db.query(Quota).filter(Quota.user_id == user_id).first() -def check_quota(db: Session, user_id: int, tokens: int = 0, requests: int = 1): - quota = get_quota(db, user_id) - usage = get_usage(db, user_id) - - if quota.daily_tokens and (usage.tokens_used + tokens) > quota.daily_tokens: - return False - if quota.monthly_tokens and (usage.tokens_used + tokens) > quota.monthly_tokens: - return False - if quota.daily_requests and (usage.requests_count + requests) > quota.daily_requests: - return False - if quota.monthly_requests and (usage.requests_count + requests) > quota.monthly_requests: - return False - - return True - -def increment_usage(db: Session, user_id: int, tokens: int = 0, requests: int = 1): - usage = get_or_create_usage(db, user_id) - usage.tokens_used += tokens - usage.requests_count += requests - db.commit() - db.refresh(usage) +def get_quota_by_user_id(db: Session, user_id: int): + return db.query(Quota).filter(Quota.user_id == user_id).first() def get_usage(db: Session, user_id: int): return db.query(Usage).filter(Usage.user_id == user_id).first() -def get_quota_by_user_id(db: Session, user_id: int): - return db.query(Quota).filter(Quota.user_id == user_id).first() - -def get_or_create_usage(db: Session, user_id: int): - usage = get_usage(db, user_id) +def check_and_increment_quota(db: Session, user_id: int, tokens: int = 0, requests: int = 1) -> bool: + usage = ( + db.query(Usage) + .filter(Usage.user_id == user_id) + .with_for_update() + .first() + ) if not usage: usage = Usage(user_id=user_id) db.add(usage) - db.commit() - db.refresh(usage) - return usage + db.flush() + + now = datetime.utcnow() + + if usage.daily_reset_at.date() < now.date(): + usage.tokens_used_today = 0 + usage.requests_today = 0 + usage.daily_reset_at = now + + if (usage.monthly_reset_at.year, usage.monthly_reset_at.month) < (now.year, now.month): + usage.tokens_used_month = 0 + usage.requests_month = 0 + usage.monthly_reset_at = now + + quota = get_quota(db, user_id) + allowed = True + if quota: + if quota.daily_tokens and (usage.tokens_used_today + tokens) > quota.daily_tokens: + allowed = False + elif quota.monthly_tokens and (usage.tokens_used_month + tokens) > quota.monthly_tokens: + allowed = False + elif quota.daily_requests and (usage.requests_today + requests) > quota.daily_requests: + allowed = False + elif quota.monthly_requests and (usage.requests_month + requests) > quota.monthly_requests: + allowed = False + + if allowed: + usage.tokens_used_today += tokens + usage.tokens_used_month += tokens + usage.requests_today += requests + usage.requests_month += requests + + db.commit() + return allowed \ No newline at end of file diff --git a/backend/main.py b/backend/main.py index ab36f1e..5a458d5 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,7 +1,9 @@ +import time +import uuid from fastapi import FastAPI, HTTPException, Depends, Request from fastapi.responses import JSONResponse from sqlalchemy.orm import Session -from database import get_db +from database import get_db, SessionLocal import crud import httpx import os @@ -22,54 +24,48 @@ async def authenticate_and_quota(request: Request, call_next): elif auth_header.startswith("sk-"): api_key = auth_header else: - raise HTTPException(status_code=401, detail="Invalid or missing API key") - - db = next(get_db()) - db_key = crud.verify_api_key(db, api_key) - - if not db_key: - raise HTTPException(status_code=401, detail="Invalid API key") - - if not db_key.is_active: - raise HTTPException(status_code=403, detail="API key deactivated") - - request.state.user_id = db_key.user_id - + return JSONResponse(status_code=401, content={"detail": "Invalid or missing API key"}) + + db = SessionLocal() + try: + db_key = crud.verify_api_key(db, api_key) + if not db_key: + return JSONResponse(status_code=401, content={"detail": "Invalid API key"}) + if not db_key.is_active: + return JSONResponse(status_code=403, content={"detail": "API key deactivated"}) + request.state.user_id = db_key.user_id + finally: + db.close() + response = await call_next(request) return response @app.post("/api/generate") -async def generate(request: Request): - db = next(get_db()) +async def generate(request: Request, db: Session = Depends(get_db)): user_id = request.state.user_id - + body = await request.json() - - prompt_tokens = len(body.get("prompt", "").split()) - if not crud.check_quota(db, user_id, tokens=prompt_tokens, requests=1): + + prompt_tokens = crud.count_tokens(body.get("prompt", "")) + if not crud.check_and_increment_quota(db, user_id, tokens=prompt_tokens, requests=1): raise HTTPException(status_code=429, detail="Quota exceeded") - + response = await proxy_request(f"{OLLAMA_URL}/api/generate", method="POST", json_data=body, headers=dict(request.headers)) - - crud.increment_usage(db, user_id, tokens=prompt_tokens, requests=1) - + return JSONResponse(content=response.json(), status_code=response.status_code, headers=dict(response.headers)) @app.post("/api/chat") -async def chat(request: Request): - db = next(get_db()) +async def chat(request: Request, db: Session = Depends(get_db)): user_id = request.state.user_id - + body = await request.json() - - prompt_tokens = sum(len(msg.get("content", "").split()) for msg in body.get("messages", [])) - if not crud.check_quota(db, user_id, tokens=prompt_tokens, requests=1): + + prompt_tokens = sum(crud.count_tokens(msg.get("content", "")) for msg in body.get("messages", [])) + if not crud.check_and_increment_quota(db, user_id, tokens=prompt_tokens, requests=1): raise HTTPException(status_code=429, detail="Quota exceeded") - + response = await proxy_request(f"{OLLAMA_URL}/api/chat", method="POST", json_data=body, headers=dict(request.headers)) - - crud.increment_usage(db, user_id, tokens=prompt_tokens, requests=1) - + return JSONResponse(content=response.json(), status_code=response.status_code, headers=dict(response.headers)) @app.get("/api/tags") @@ -101,48 +97,48 @@ async def list_openai_models(request: Request): return JSONResponse(content=openai_models, status_code=200, headers=dict(response.headers)) @app.post("/v1/chat/completions") -async def openai_chat_completions(request: Request): - db = next(get_db()) +async def openai_chat_completions(request: Request, db: Session = Depends(get_db)): user_id = request.state.user_id - + body = await request.json() - + messages = body.get("messages", []) - prompt_tokens = sum(len(msg.get("content", "").split()) for msg in messages) - - if not crud.check_quota(db, user_id, tokens=prompt_tokens, requests=1): + prompt_tokens = sum(crud.count_tokens(msg.get("content", "")) for msg in messages) + + if not crud.check_and_increment_quota(db, user_id, tokens=prompt_tokens, requests=1): raise HTTPException(status_code=429, detail="Quota exceeded") - + ollama_body = { "model": body.get("model", "llama3"), "messages": messages, "stream": body.get("stream", False) } - + response = await proxy_request(f"{OLLAMA_URL}/api/chat", method="POST", json_data=ollama_body, headers=dict(request.headers)) - - crud.increment_usage(db, user_id, tokens=prompt_tokens, requests=1) - + + response_content = response.json().get("message", {}).get("content", "") + completion_tokens = crud.count_tokens(response_content) + openai_response = { - "id": f"chatcmpl-{hash(msg.get('content', ''))}", + "id": f"chatcmpl-{uuid.uuid4().hex}", "object": "chat.completion", - "created": int(__import__('time').time()), + "created": int(time.time()), "model": body.get("model", "llama3"), "choices": [ { "index": 0, "message": { "role": "assistant", - "content": response.json().get("message", {}).get("content", "") + "content": response_content }, "finish_reason": "stop" } ], "usage": { "prompt_tokens": prompt_tokens, - "completion_tokens": len(response.json().get("message", {}).get("content", "").split()), - "total_tokens": prompt_tokens + len(response.json().get("message", {}).get("content", "").split()) + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens } } - + return JSONResponse(content=openai_response, status_code=200, headers={"Content-Type": "application/json"}) diff --git a/backend/models.py b/backend/models.py index a9fd804..071955f 100644 --- a/backend/models.py +++ b/backend/models.py @@ -10,6 +10,7 @@ class User(Base): email = Column(String, unique=True, index=True) hashed_password = Column(String) is_active = Column(Boolean, default=True) + is_admin = Column(Boolean, default=False) created_at = Column(DateTime, default=datetime.utcnow) class APIKey(Base): @@ -38,6 +39,9 @@ class Usage(Base): id = Column(Integer, primary_key=True, index=True) user_id = Column(Integer, ForeignKey("users.id"), unique=True) - tokens_used = Column(BigInteger, default=0) - requests_count = Column(Integer, default=0) - reset_at = Column(DateTime, default=datetime.utcnow) + tokens_used_today = Column(BigInteger, default=0) + tokens_used_month = Column(BigInteger, default=0) + requests_today = Column(Integer, default=0) + requests_month = Column(Integer, default=0) + daily_reset_at = Column(DateTime, default=datetime.utcnow) + monthly_reset_at = Column(DateTime, default=datetime.utcnow) diff --git a/backend/requirements.txt b/backend/requirements.txt index 22ffceb..66f0d51 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -6,7 +6,8 @@ alembic==1.14.0 pydantic==2.10.3 python-multipart==0.0.20 python-jose[cryptography]==3.3.0 -passlib[bcrypt]==1.7.4 +bcrypt==5.0.0 +tiktoken==0.9.0 python-dotenv==1.0.1 pytest==8.3.4 pytest-asyncio==0.25.1 diff --git a/backend/schemas.py b/backend/schemas.py index 55cdaec..cc3cd1d 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -5,6 +5,7 @@ from typing import Optional class UserBase(BaseModel): username: str email: str + is_admin: bool = False class UserCreate(UserBase): password: str @@ -33,6 +34,12 @@ class APIKey(APIKeyBase): class Config: from_attributes = True +class APIKeyCreated(APIKey): + plaintext_key: str + + class Config: + from_attributes = True + class QuotaBase(BaseModel): daily_tokens: Optional[int] = None monthly_tokens: Optional[int] = None @@ -51,9 +58,12 @@ class Quota(QuotaBase): from_attributes = True class UsageStats(BaseModel): - tokens_used: int = 0 - requests_count: int = 0 - last_reset: Optional[datetime] = None + tokens_used_today: int = 0 + tokens_used_month: int = 0 + requests_today: int = 0 + requests_month: int = 0 + daily_reset_at: Optional[datetime] = None + monthly_reset_at: Optional[datetime] = None class Config: from_attributes = True diff --git a/backend/setup_admin.py b/backend/setup_admin.py index daedf2c..6ad72f1 100644 --- a/backend/setup_admin.py +++ b/backend/setup_admin.py @@ -13,7 +13,8 @@ def setup_admin(): username="admin", email="admin@ollama.local", hashed_password=hash_password("admin123"), - is_active=True + is_active=True, + is_admin=True, ) db.add(admin_user) db.commit() @@ -31,8 +32,8 @@ def setup_admin(): db.commit() print("✓ Admin quota created") - api_key = create_api_key(db, admin_user.id, "admin-api-key") - print(f"✓ Admin API Key: {api_key.key}") + _, raw_key = create_api_key(db, admin_user.id, "admin-api-key") + print(f"✓ Admin API Key: {raw_key}") else: print("✗ Admin user already exists") db.close() diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py new file mode 100644 index 0000000..73e3b1b --- /dev/null +++ b/backend/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the Ollama Proxy.""" diff --git a/backend/tests/__pycache__/__init__.cpython-312.pyc b/backend/tests/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2751027f32f28e0edf82d8d47f7c8fa44cf42f0f GIT binary patch literal 213 zcmXwzF%AJi7=>pLHWDY$ES<(&KvXIP3L4FywfkdYc7A4tSeH;bgR{7S)&WR#Dl-YM zdGBl9d$&fT&b&tFr)+zF1o(^p!aMTe!ZJ2z1@Dwj%_&-E<0YBctuY=emE=i0F?1}t zqYl5<05VDeuIwEeK0`6HXsjS54`{>y52$FHqjqA7=~|Pbl|(I+%3-%Bmxx7@Zct~! lg&_eTwR8a8ZSjeEIc0l=eN*+-;h_!hUS-btowZEG!xw-jJWBuo literal 0 HcmV?d00001 diff --git a/backend/tests/__pycache__/conftest.cpython-312-pytest-8.3.4.pyc b/backend/tests/__pycache__/conftest.cpython-312-pytest-8.3.4.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e95f4ef4239f57d1299af1f2f97c7446c7d44a3a GIT binary patch literal 3894 zcmbVP-ESMm5#Kv5k34=!qAgqgk}12jeQqSic3v8|sbyPsB+JT&#82ahXd;#v?EWjHnqKj{y(Qq4BsDgg&83(6?zJ=-Wj^q}(yz z{N|{Im*pkFn`@CdwA5gYCuamD`cF0+Q3N+KMr`Mjp_9A=GyJaVOl2FAe8AANGi^MC z?g4Td8WGBgoWG^h1WjBHEHyY4R7Sbr76Cr-zl5C$B&3}ZzeT?j9k^Mq<@jy;GjG&_ z=YiOL_b5`NQtu^7REMP0ltavdMO8YJ8hvkAhb2vQ)CtulsY#0h*LLi_QlL+pxT1)z zWEUqV^}DX55ra6S6tQm`u;>nM-+<1Rjk=!uTJgFPRO^&E0(Z+m3 zSJ+xGYgjgMaUNK%XxVOv*v^bLVYAyzeLs}GVH0YnEknOUX!wiyC2X@ZTuom{EGL zw3X-4z7@{z*~xF!0QN&=xg6OQ8KX<=&Y*n&w^oT(V&$0gBX0$Q=Y&sfj#`YN4L@64 zOYpO``CqHaaJu%T2F_LD|K)UXU;YQ5u0=!Rbn#tJ*Rkz%iE`q9oG!_+RL&`zC-$k9 zJCBtUlOka2p3=!t;uH%UI-BW|n%14E6eYKdIwcxk?xU;_C;=B=1hj?R$=}rT#Ep#- z+t#i8u$5H}1-qDmpDXphJ9v%ExWV@rz`4O2wtABQAp3Zx7WB+5;K!#_dkXT*wx=zs zJ%W-oq31P*#kNxfMWuE|T_AT=vtW?EtYtE^wVg9(&y9hxFq7es93-Zy8*UrZN3^C! zSJv%}nsoq)OI&F$vl<&47#+*NCuXh<{L&4onyKeY9emhgD$JL_mJ+0=;P8}PDLu`) zpFrn^nR32a<^xwA@4tL~aKzWHbh=@`LQ$4|^f2pK#<-GcL0z(GH(OyDLQ||`x!~d~ zEo$@_n;&QMut^-1DRLu?lAs!f8&EaPmEcxP-El*dOj2S`xuULdl;e0tqp)3R+h%(= zn=RHvZNim#&0!bp3K>wYPIHe%_(?N-J?_yBmfZ%HQd^scn{k^(SQ9lbfZSgA*?XY& zRMFq@xmA4b89uXy&pg9t*CiB`Sq=LNA6~_W7f-C&JNC9V+5;yGc|kw_9s(| z;^N>Rhkrl(#p$(WISo>MI~l1c5WcXZVgJ4v3Wui~T6;oIJxeS2=` zc{I6jeI?qnf_rur1_rjgFw}r(Eewr<08i1108k|4{?GbZM$32;!kieW-&#)DqYG)b6Gw%(cARDPw$UyS6q{x1@NA$#*q|J~nmGXiY5J;t(HJxw9?>Cj>BHi9$du-7o`<^u!c{**jn^!_*CmN;OC2%A7h1k+YQ+sjKa$! zXlJbg$%5`?W4=EQzm5CQT|>f_r@Ig87345E-uyKlAEMES4@dU-Wx@A;MJ$V-3J)N9 z9EKFUZSHp1Yl*kk_u5S0Pl0Jc!KD+8XE}q4RF?nk@G1vi@Nv&~8NgMcFrtL`E0$gW zwJrwK&eN+vQugii66@HjN(oT*e&H-YIIh}eA}QZ5QI1a@hYkVf8vJbL|rK>_`2g zpzLwO6Gh!{^t|oHE)-@m&DnuWVTPIQgMvk<17)v|l4%NmDO|a^#DLgai@;|7l`2$3 zk93ohs_j7e>tl?@*Yl$@-*}JH5PNqT`UGWZ?_CzV5`JU+hTIVRsd0*w@N+*xJcM|( zejh#$MWtwvx9Bfm3*alVhoM^!2!ikop*i%M_*0~OjoQ9O@>A6H5(kB@mkBH=FJrF> m$6kIQgoM-UNhG#c_sw_Db}#5PA@!XY6^?8Kkl4d1ZtXv<<5-sf literal 0 HcmV?d00001 diff --git a/backend/tests/__pycache__/test_auth.cpython-312-pytest-8.3.4.pyc b/backend/tests/__pycache__/test_auth.cpython-312-pytest-8.3.4.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a0c1154875f4cd40831227fe16e719351af46a7 GIT binary patch literal 6979 zcmd5=U2GKB6`q-${r~Z<@h^7lg%r$^V)G9nKw}3JFyN#C+Gdk>I^H|RV`hJHXN>W7 z-L6wB5cMJCB~(?UIucT_AS!)o9(k%NwGZCni0zaJs`?>$D#5AL^3-$g?9494#*LIJ zz4o5D_uO;O{XOS=_x>>)4seie{^}p9Kg4l=$AX>2N@lT%=eWxp;RvsBNuGU$v@0d> zTv|-Jd8@{g^spLl(hFZf^QHYse>#v1uzHskOox&omKU{fI+BdAyj!bFN0ZTXEE!AJ zC+iu{qpe9dBpcF=$;Nb35;_NcUadLZl59!0CR?HGQ@o=)?O|5G#J?X3Brwe-+ei>- zyTB=)(S213J4Zs-!`EO|m`G%ryJls%WCy8(cAX>&G!Cr;B&PVU*IyGX0kWn>cHIP@ zX!sl}1`~qm@6{(VsT0}MS<{`9jnqhjHv^{>-FQJ$m5c!;mz>KbT&CD3>x$`7GH29` zVunsCx~^t3$FnI}!`jojLQPNa%SVqX6Q=vSV_8E6?U58!WJ8hUoGP7#itvc6k4QOL z*T=JzaHzWdBbD#uJiUf@~i3b+ij}f74!sc3432AIU{WIh(KIMol_F+{)zj# z=UU~=PVxnAG-k_GzlEC7g2b9?jh(Z52=u8N!O%U5RG<=1N+%w=c->4M$4&L z;|Qs%S>vcZI^u>^jybb&uP__8Ex-C~R#@d44K=ImUUrpzq@D!GnwnMitS~Ome?P8A zR(bh;8y;C@@1(EbTV<6Sag||g>#8en=c?y!6&>ML+^Hw;W+TcB3n#fmQ@-n@VvObD z26&=4kqvnWyia^wHAdo8d1Fka3SsQ-S>6Qh~t$^zp6`e%23ZK>hpbdoLYSEmz??;98ERa z#c>I@=|vTiO3V-e1rw5yJ*#ANGm_3`jFFnMRg2?=?sYS2*V~*t8}k-)jG zKBZ&*ubOU|q}5D5#y*eC#K%}Lm4ut3nbRgDXwVcb+4SMqv9Zei?VoOS<}IGf`x%3E z6Yf29;^?5=Pa;OqrJBC9V#v72rjO8UPLef^`miEkyKi8i8>AR@X#}fWG9ji17B{UL zrk5(iRMAH$#A41AR6^11nL<`Ky-MbsO0yZ$ozk+plJJ`X88ZF2Arda5>3#!ag&C=Y z2#G2w2o0c89%msVz(#M$2Ccl+f|ty&Rj!EYIK;S}9v9mT$dXhS*H zR*uB)v}~ABZ?$YK*EfCXcgH-%qxVCc+k4^g)ZvTI&Wnv@(R1PG)Y0iHNFxHvH{ zwl0KXV|%%|v)tTPZteIo=m~p@{r4kw`(4u~oaPH1-@f|Fjn?0_|F-?( z{H>0GlGs)py!2vu-KOb2ruU~Mv11|DSdO)nBkOQfqqkbN!KliyMo?VuDGpQ>@0%B! z9L@H?^7G=_h5F`lbnWHvyWw)v?0b9PrA#Tf;R^jQ|3Usn zPpPYKF4#9G_ASDWa495q!nletW${lyzDC4=*1^&!!=ZMm!gFGNm5=w4*(Mg}-knUd^y#yS5O5Dvpne-6ls=bxR}3xiXGmk!SCoEJCUvCc5#d}rp+ zCt`wiV)};1GT?8s89~xIho6oc@EP|nT(a2oOLxM}@4IfCo)14)JX{V$ z-#mUtZ1_ZM`nQe~c=MSpe9D<8FdX=M6S$W#AlKG{04BoYQ6u8}!gaSZa~=Q}K&qu_ zQXrm5S3xMa&Taz01#nCM$Z%lur~{U(?TEJ^4)er^ZH5IKwi3VLfaK*2!d~TboZeWY zA06nDRE4RuRuUu5l2nz<@X&vC)XT{_Fm)9dK(M(T{X45&miHW0nHuL$rubnw*i*6lcB6rszws*k%y> zvIf4bwy?9du)``?s&>~F_UdR^72IGM?>cK{5d5XIN{FV6!<--scizky1`6mIXmb*N z`a3|1T)A%T?YhoVUFS^QP2b$9WNG(nb9J4!>RvA%DF@fw4tAD;opa*GnH@mpf}Jxv zO2Lh{Mg0Gy!mvsXEr}a1?zKv{s@ZuaY_b20>Ku6|10u|=F7G>jzpC$0c_PAv$paq5 zYhM-MV!DzC@qI-P;%6Qtu*8D|zhMs&tg+{>$Ag3%4-zJk|A_~w8>ENe?3v+SIHy^v z=9y*WBl_7bYUZ4*sRS`U?FSi(!J57Z1pJD3M3D&~V>*D$V@NPI(Lp3=Q1oRar;waR z@;xLnkjD*$UC@|8)EY{qd&KWlgLdD?w!a2aV<%-_b2%2DTfcoSx&tum)!n7=_WNEg z8einZsHb@3et-)#d=qw)uZL@4rW?HO`3dk-?D7WAHe8v(mKx=f56bsnK?DL@Ti z?>1Th`)3b8VE+Nf{)K|DEP`+mh-U>?L1YNRU2p?}aAM@>GLex43vNIVA#78FAi{?8 zXf=vg+9l}`1X0z$%;QzTiS>XWs*-D$D1>5F5)edq64$gn*+YbO_h{y-?TbGT0%%`AHckj32`9a87Ervjy9t2Y4 zt`2?J@9$>t}mjT;GOffT}{mMJHHSDd>2aXVmXEm{{zr7v||7O literal 0 HcmV?d00001 diff --git a/backend/tests/__pycache__/test_quota.cpython-312-pytest-8.3.4.pyc b/backend/tests/__pycache__/test_quota.cpython-312-pytest-8.3.4.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed40ea5c7e3f244add85e09f4bf8a426436b056b GIT binary patch literal 33897 zcmeHQeQX;?cHbqJAL2@sWm*1qB7enZF0n27JBjVsapL%#NUDo81c0^^ne;@(KAq>aM7X=E!i?y7YWbxh5q?&iN3nsK+-mZ)S(Pvv1ye^WN+~H8sUF9KZUr|IGaPd`!p#=vu$g6ViD}xF^g?B0Uj&2eQ%rXiv1ip{GID(viMd zqEYqTH%H+Dvx1iKVSuMtyP<4je^XCWf4nE&-`vy8^1|5}{VhE${WE)J_P6%5vb;!k zR)1ShTmS5y+5K~R=IB}j{;UO7^pxH+Hy!Tl6lYN1R>LXnLpRwo--?~mdKOqF;6ker zaFNvnc#jnayw_?5Y`10rF1A_#msm3ams+iW%dA;|%dIxR`>ff3E37$yE3LVJ32PqU z{nmWIuN~JD3$N2TyApv?)Jo;k`AmPh)IiWmXY;92tox~_wm-c+d0^jDC)y6@^G6<9 zy}F}wv(80jzq9z>_|E3QaC;ELS`UcYTTc8oJ?-usa$U=OQi>#v|S4A9n7J< z&;cj)LK?aAdJm_2UreS3tYl`O*G~7R2lC0IjHar$(EYwKBO&qCrS+& zCz7#>bC-0PIh;osWh}d*R?KnUw zn6j)=7~Sg6`^LFJ#gKtFzgDJ{VUYO!Zsw2lst^KJN(@FNrX&kfD*-LRcy3x)z0ZwYaFq^NoEW8?Dd~NRX;X|WyAG|tq z_Gb;D#_-9VpEhct$ZJo$^2F(Nr=GfMgkRhJ%I;Hp2>JReUq786GiLpwvE}NFS>KO* zFLJeQ@zpuY=`%0!nHg>hpM3n&xTwDV)YG#1CuZ|^H@?2{WYRVpNHfY?a0SPJY++72>(?7zRn)DCUI`{ zE$H+q`z^<{p9GKTi9lCEFBu0j0~lyF{peD(Kb=p3Nv6Om?c9-MDx0-waFt@2YIz(U|{Uc*K}L zW~{g-G8Yz&h40-rVk{gp64&DM-xxT#>uRj!Ror)zYMttql(PC|W5xeE)bUq$H13G& zWkW%7&P(VR=2Cq~Q{1zli5oivJ~|X29`)i%O&bapg3doGzS<{dWSm1B7HAIXmO*t6 z1q83PLV3wY`{)W)9O5Rvq&f3E?K4@RoKy`o(MQ``!-vCG*BH>zA9@(r;meif6{tfe zAAve_0nuA4{El$~H)Y7MB386uWOSAa<(ox+JT1)-dForuEuzI&ApyON`-guW|G}Za zaV^o%RnlyFqR;A|2excU=qJ`JUwP#Cx`Y$^ri-+1X(wz=!dXe*oe3uq8*1Oeb2la& zn;tey1VPXc=p?X?z{V4COcH53&l=dAa42*0*#=gl6npf@ablLqlwE2`LIlVr?eq~F zT>MDNhID}9QI_mD8;w4 zvqH3Z+_rOesWo{hXZNS_$@I%dvZ;Yo9wmuz$)q#z4t97GYrTmL-o$#I;H_=qUF3<5 zHS7uEvj=;8OLN$d^NJ|RbpMh3aeF7OR|o!`ZvmXtuFjo5Vzj^c$eF;Xxd7OiqW~l3 zf-^^pX8X87pUZ9rPf1v$YCJWn&@nU-;JKGrJ&B(=DyvH5iqUQ_LqE!%dSb&g?r;7B zhDNQYUX8(mo|jL#kc7bz0&nTg!&blwPU5+itljU4%ggCUjB7mCss@Pz6-RXqJlBxd zr>sZvd=u_bRSopWsQoJxudPETInAVask(xGoKf9bS3a(L&8uoKUaHRGbDs)c-WO)v zE|_l-m&)_2XNo-aU67KyxCdTt$2eW2>x6kYoz3Rjp&n$d6Z<`^X_H`gKq}bXdQk+# zE$VPN;abKCA7IRF1H((Pj00|!PYv{@Ig<-!2J(rpy#`$+=GQ@BErIm}HW1j9h%s9G z2@u)42y7y-nZO1D4-t5nz#{~<0D#|m*jJ}FvCf;==;c_O9l9;Vqd13JRL!C6bbint zaG+}C?Y$sghkxg%0N_(g*Nhk&-_8zi8J*n$?9GQpW_SE-?WnmCiA(0jaYKDBQ!&pN zRj8!6MD568fsPp)hqv%t5TJ2M&k>yFXKQ6Gfuhz(`N=_o^Vi^NAJpM$B+GcD962r_ z*XV`di#usnE*Ri!_bL8&3mM07Wt`wwNyeEJ=TY+2uTP4Tci3}3efXf_GLX(VhYB(2 zEaXaO;R#D;Qb+gWSCGyO{0D}DOhP-+oN7PV-rH`qr`r#;ztDcTq^A=4=lWJiYuyu* z*0wSx)d3lfi*fqrvRY$AMaR$9x1u50}>(2 z?zGo~0!q+UqTWWJo2qyi)nzj{^%uS4%_X9n@6yV(R=a-H?xE0t4G9TPpG`Ef$^tmE( zcv8Y5wZT)P3LOLW+yu|Q#A?b49AT%D)y_T7QB*s>l%rJZLHFD|M?%!CW?eLKGJ))0 z0BZ?kFbPQk33c3KC=k#vy^2&YFU?tMNQ0&61>?o%!Q4WH5SZHswFM+APzZs!5lb{Y z%*`vt_&h?Q4@`kms-BNHP_2$Z=ukC$RG#V@BvyeL(bZ#zpFrv9QJbfIpB^a@%=S|H z^h8lIv8QIQ3rb)MPq?SvE8){qk0`DLwwk@Du7RG1@A_+}7twr z-)`mF^M_M;@b>)rignfX(Zvm9g5-=&S zSz=wnnHU_?cX7yrwP#9Vm}AKClFVW@36Zl|F>nnJI$Cf3P&-L~W+GMB0-55ynZuX)!^~ee1i8`r$V$MIg^O`y5j4^61 z8ZkPDpMAgQ-5wzC_WbC1e7)B?{QT(RPGBP_U))(VJI4+BToE}uDPhs*q?v|DMinZH zFBzR@3|TDDG1izQ;s`stBz>>7q8Kd=KP!s`ik60ZcrFoO#-g48Mc+AI0dW>SO}PvP zmqyb_f@9OvN;XY-WN26p&jqAi0Mo`$z%qt{O5Q4HWC<(;J?k!p6~LY_?#bHGXVN6M|mSJjAHpPa#G%NS{$e5l9V4kOapUJD*l7J~mPNjW*j95;_ zf(?MJHdwI<_o&U&zE6+TM1v`*W=|9)dJwbBDeHlpzUuZsQlbZDo!%*Ah_Erm@BsAZAlfA|9T9E3Kg8^D?af(-q_QYtP2mQWtM> zf{Y$EVz&B0C+C{PcEXrC!ITG=+pqCV04}5NEYl~}IrQu^p0D`+a>n;zaLygWIwffl znSrEdIWn2eU?I_d88vs{-?^OEEAa|a;iyE>Cy8Ig6ySEwi=gMzTpbITAlK!&*3;dB`fbFkRen}&0aWs`pa z*_O-rSm#)LgUiSE5g*%E$;b9xJb+x{V+Te|Fd_O}aWi;I!Xm59lcOG75?nI&5%Z!C z$_45zCg$Z93v|rbM|_NOQJ?sji!wgu;ujC_>6H0S-4RZshA!vso;SUnxq~MZ5?`eq(TpmPFUyk%3nTOZ?*h&TM`=#dlT-FSG)G< zk5l5*@h&E!C3Nl}N~W8s6xXVm zT)B^{Vsa;OX0l>B{|z7%&%I;ujVE_qGZvjWIK1Mc#phF(jYqF;erUva=E7b8A-=U- zfJV9t85h?BjF`LdZ9X$@(C3Pq!BY|zS#6#i_281=lJN}5Y}5qG1?nv(nawR0=$P@$ z`4*mw`WK*V5&=;3N#YmRbF>1Y=QYH)X_kj!aeKnuG|PkWvkoi!^_w1l9V~A?4XLSS zd3=GztDy?HPS_t6(|Z*i*pL4*rgwBKzWFQ7^k`X=t*AtS!3|%owiK|ULVlG@o^vhH zLwdO!XN9crxrmbX{NqKGgT2s|`s8{E&h8$H%5M#&pp)OXtY5?Zo_s?6R!W9~kTMO5 zaA6+ga6xc>qHOJbJ{ESY^N<5V1J&`Y_QdjSqR)L~KCU=Ct37R=_DQ#?IF#%InXW|? z_R)hZxbrd+YG%y)wBQZz|GYAJe==!WNv_lP|8eaofjAna9i%>#Yy zidtZ|eLv#U_u~uSk0s(RsBhel@NGRG&C~be3*V1r>OKhUw(rM`>HG1;?gv&(Te|Et zZp8H&D2uCjwNOhj zk!kz5ZQ_TN!iNxrleTS|PD`Of8QaO1TB@$pM4fobtFmsC;^jZclV;^-qn_p9@1Pgl z#guepwmR%}a^GO@K(>#!50A-|F~m;tU=Hr1TnDG9Om-$jpt6V^#D3a;f_A<^ZOw&T zEY$FI7eiM&)?LVd7Sx-c)=%#Gg{zP+Ie+xBvHg?A_;;Ut{mHivywN@0xTM&)|o_&d$lk^e+^{?rS_ZH22M~r2|2pL#5V%|I4Q8bs08}zyC zX7H4RMPnKLMxq#1=olmB*_Ws}NiPvlv)}mToe`a?wo{Ps%S86ygR@bk$nJH5s^zO& z?T<-8m3~f7Lx!UD3z;NuUM`*wh2aHk6r}6f8m5x&^>~A$iw!1#CdT(<{7No_wlV%zgH+6F#oC=K0<$W z4G8L?+I3U_r4^oV&)Yn0`t-bPIjUxl6(xEahBqo#0UAB+$1%<-4yZ&XQ;Bf5JXhYU zlGMKP_4+SWJg7vq`!6Y(N>rOxT3m2Xlu}f?r`|j6bET{A_#GJ|N|~$u3#Bqw@ZLe6 zk*bunI@eb&;MJjP5YI>*x|pDYx*22YW|KwgW>ba-9>?>vYXUnwE%_n} z((o#F!%={_nb3`S9uO6w!n0I>^(}OBrja5ph5Er`uUB!!)XnQahzSB#kO>8q$FB&V z=1$Nv6%w)56pPJ68+Q{%`a9Ft?IbYfvb=%8O$0mFFu1lXmZg(S0aUpDIz)VrV;>@S z!WpN1Acq*96Mz33dcDwl<_AvKEYZwO0)Ie1lWYO@8w5yjvEL-Hfk6Gh9p>hU#*|Yx zixFiC{G-$a#qroFfU;no^<}AzWvR{HRHy6p&E3K~JF4P;I@fV>?j+v1!D+_1q+2*g zcu>#^9u&kj>T$4V6YM?MY@`=kiY0RcNjA3FW+nEI&_XZ%o&WQi2e>Pe+4S8_uWuSL z_PoC7?Z@BP2ITa%adUCeTnzNRJ~9|LkB&q^M{x4CEBpwAWF08dC*r1p7gRH0);D`_}_20uc?J~ahOfg*L8`hYX> zG6~W0=}w*ky+G0O=}nS`BWU^U$0emeiEFwiC1-pmjI$) z6-F3w^}EH@?^aj8TU`BadG)*5)$e9kzuR;5k5{vY0wqdZ zeagK=XTa5WDcIGQ`MCNL$JKZ7i|aUA4bfP&352*JGl47vEH{A&@d7(WAqWyXE%^DK z1%=QL-!o&9zGoq3dGkFKo?%a?3HRvE_?}IrD~J@n>w3RQN{HzVcfBJ< zO!}UgR^uJ@J!|U9Y}?VSWtwS4gJJ}pothYdb~c+05g^61qHFAMA?rJ`1`y&!Qub}%DGxV#knY$B1bkx6pBDC2q!65O9u3zoqCeb>LmmyvZQ(j8nG+TVo~Ru&xP)uBgZ+b9Y=ij{ZAwfE zZmQSLIS$in9s4ZKM{$VGzrov3tb4C5b3b2tLBDM55YFiC+BZ3)cat-Ew~!ZvGrCZq zc9knouP1SB}eLkE0e25`VhuYEylA2r`KLH)K5J-|H%MSdaIRaPT zq74>F1vwCl@nzB<`FbhYa0t7_VJ8!4s8~wJ;yc^XL@ulOd6r^9H^LTlNvl;mns^q< z`g}Y-MX%UWx@x~&^;=-RkyNV25<#cFt7|}}A-PU2-PP$)n_5WhYgg?GtSpkMW=|BQ zE7wg`52UmPxvq8(BqfvUYSXHv;@Wg_skqL$gJTi`sYG%N!o*mB&yF5szjeRxp{DIqg&a z6VI>qSMkPw>vsw7!|nYka1>r@qG|zsp@(UfRb^rtBnC?OJ~Aj4)f)-zfFyN zhrnwDP7@&4%S4; z#KAhYjhc{~*Vu!UWQ!dy5q6A#S_=7NO6~w~$hoI{d4jJ_aC?ron|Rel9-EfYD>i!Z z4gve8s7X6MIG@~7dgJoir(vzBd4k@G*A^`O;kF-a8(r|gx#QI*NqYkfN5_E~Mjis07Mz`pqzMs&I9odSl=1wV8Jl&WAMgn2{!q2~Yxdy7Vh<6fx3 zfW%{dXW96qMYs?Go-pv~n5+m_F1pdqp?BZ)K(LNas1kw^AbOIp=?R`%z^?+7Tw-Gt z$1Q`WXoZ#?e!`hYux!h#_WdRxIay&^Q*~8Ac&aw75~6C;Ng?W%!Beoa^TX3S zAAv#nU(jz=6o5W`r!O9j?WHh7`FI8d=7|C78tA$ERvUyldQ>F(?TT0y?NJ{jesrGgAZ2VpG`WB?XP^Z1gTV3MrYFgB4hAm;* z>cZa+SI_Fw!mKVcE37W9Q#RVX;%V)XAt^E1TI-As=@eXtPFY=MNvlhnHT#ZOUFNt} zmmx4An*?|z0$p#XB*{UOw7ZZ<^u_EhN=)*;+VghIE=*kd2pzhm&4jOciSSM?HdR_p z{u&LOp+>HG#H9UYyN^s-NLsY~hshr#M;ER7*}MyZv5v=ziykAJ$$V)u`D8QvJ*O=u zUtSb!ZHvh@NbS=W6YruIVmJdy6ORQuVAz+eQsk*G&0T)r z(YyWZ6A->&b~h_FOK8kbl|C#j33crgl9B`-Pp|!K8B&bEVoSi?w1pO1)TS+8)b2nw z2d(dG>iP>z*k6AUomWdu<`ubi!EbK zdrm%awQW9LDt&sxTL&m&U_L&Vd6Fk2#OtR|Z{Vp>g+jo2>wu(@2#C>tyvCfpg!=03 z7=%^~V{K^#nP*Z(bH%tpAH2JoW$=`QMPmh-XNY1{p?FU=GtW@=B?HY#dWnF9nP(D4 z7mmaOrKorXF?%ynX26bKkSkq2v-)j0- zt>I&BU;GYF%BE_ count_tokens("a") + +def test_count_tokens_more_accurate_than_split(): + # tiktoken counts "don't" as 2 tokens, split counts it as 1 word + text = "don't do that" + assert count_tokens(text) >= len(text.split()) + + +# --- check_and_increment_quota --- + +def test_allowed_within_daily_token_limit(db): + user_id = make_user_and_quota(db, daily_tokens=1000) + assert check_and_increment_quota(db, user_id, tokens=100, requests=1) is True + +def test_denied_when_daily_tokens_exceeded(db): + user_id = make_user_and_quota(db, daily_tokens=50) + assert check_and_increment_quota(db, user_id, tokens=100, requests=1) is False + +def test_denied_when_monthly_tokens_exceeded(db): + user_id = make_user_and_quota(db, monthly_tokens=50) + assert check_and_increment_quota(db, user_id, tokens=100, requests=1) is False + +def test_denied_when_daily_requests_exceeded(db): + user_id = make_user_and_quota(db, daily_requests=1) + check_and_increment_quota(db, user_id, tokens=0, requests=1) + assert check_and_increment_quota(db, user_id, tokens=0, requests=1) is False + +def test_denied_when_monthly_requests_exceeded(db): + user_id = make_user_and_quota(db, monthly_requests=1) + check_and_increment_quota(db, user_id, tokens=0, requests=1) + assert check_and_increment_quota(db, user_id, tokens=0, requests=1) is False + +def test_increments_both_daily_and_monthly_counters(db): + user_id = make_user_and_quota(db, daily_tokens=1000, monthly_tokens=10000, + daily_requests=100, monthly_requests=1000) + check_and_increment_quota(db, user_id, tokens=50, requests=1) + + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + assert usage.tokens_used_today == 50 + assert usage.tokens_used_month == 50 + assert usage.requests_today == 1 + assert usage.requests_month == 1 + +def test_creates_usage_record_on_first_call(db): + user_id = make_user_and_quota(db, daily_tokens=1000) + assert db.query(Usage).filter(Usage.user_id == user_id).first() is None + + check_and_increment_quota(db, user_id, tokens=10, requests=1) + + assert db.query(Usage).filter(Usage.user_id == user_id).first() is not None + +def test_no_quota_allows_any_request(db): + user_id = make_user_and_quota(db) # all limits None + assert check_and_increment_quota(db, user_id, tokens=999999, requests=9999) is True + +def test_cumulative_usage_across_calls(db): + user_id = make_user_and_quota(db, daily_tokens=200) + check_and_increment_quota(db, user_id, tokens=100, requests=1) + check_and_increment_quota(db, user_id, tokens=99, requests=1) + # 199 used, 1 remaining – exactly 1 more token should pass + assert check_and_increment_quota(db, user_id, tokens=1, requests=1) is True + # Now 200 used – next request must fail + assert check_and_increment_quota(db, user_id, tokens=1, requests=1) is False + + +# --- Reset logic --- + +def test_daily_reset_restores_access(db): + user_id = make_user_and_quota(db, daily_tokens=100) + check_and_increment_quota(db, user_id, tokens=90, requests=1) + + # Backdate daily_reset_at to yesterday + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + usage.daily_reset_at = datetime.utcnow() - timedelta(days=1) + db.commit() + + # Should pass again after reset + assert check_and_increment_quota(db, user_id, tokens=90, requests=1) is True + + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + assert usage.tokens_used_today == 90 + +def test_daily_reset_does_not_affect_monthly_counter(db): + user_id = make_user_and_quota(db, daily_tokens=1000, monthly_tokens=10000) + check_and_increment_quota(db, user_id, tokens=50, requests=1) + + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + usage.daily_reset_at = datetime.utcnow() - timedelta(days=1) + db.commit() + + check_and_increment_quota(db, user_id, tokens=50, requests=1) + + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + assert usage.tokens_used_today == 50 + assert usage.tokens_used_month == 100 # cumulative across days + +def test_monthly_reset_restores_access(db): + user_id = make_user_and_quota(db, monthly_tokens=100) + check_and_increment_quota(db, user_id, tokens=90, requests=1) + + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + usage.monthly_reset_at = datetime.utcnow() - timedelta(days=32) + db.commit() + + assert check_and_increment_quota(db, user_id, tokens=90, requests=1) is True + + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + assert usage.tokens_used_month == 90 + +def test_failed_quota_check_still_commits_reset(db): + user_id = make_user_and_quota(db, daily_tokens=100, daily_requests=5) + check_and_increment_quota(db, user_id, tokens=80, requests=1) + + # Backdate so a reset fires, but the new request still exceeds the limit + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + usage.daily_reset_at = datetime.utcnow() - timedelta(days=1) + usage.tokens_used_today = 80 + db.commit() + + # After reset tokens_used_today = 0; 200 tokens exceeds 100 limit + result = check_and_increment_quota(db, user_id, tokens=200, requests=1) + assert result is False + + # Reset must still be persisted so the next request sees fresh counters + db.expire_all() + usage = db.query(Usage).filter(Usage.user_id == user_id).first() + assert usage.tokens_used_today == 0