Fix critical/high security and correctness issues from code review
Critical (all fixed):
- bcrypt statt SHA-256 für Passwörter
- API-Keys gehasht in DB, Plaintext nur einmalig zurückgegeben
- DB-Session-Leak behoben (SessionLocal + try/finally, Depends(get_db))
- Admin-Check via is_admin-Spalte statt Hardcoded-Username
- CORS: konfigurierbare Origins via ALLOWED_ORIGINS, kein Wildcard mit Credentials

High (all fixed):
- TOCTOU-Race: check_and_increment_quota mit SELECT FOR UPDATE atomar
- Getrennte Tages-/Monatszähler in Usage + automatische Reset-Logik
- Token-Zählung mit tiktoken (cl100k_base) statt .split()

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
562f6ecd9c
commit
bf694b79e2
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from fastapi import FastAPI, Depends, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqlalchemy.orm import Session
|
||||
@ -7,15 +8,17 @@ from models import User, APIKey, Quota
|
||||
|
||||
app = FastAPI(title="Ollama Proxy Admin API")
|
||||
|
||||
# Origins permitted by the CORS middleware, read from the comma-separated
# ALLOWED_ORIGINS environment variable (default: the local dev frontend).
# Entries are stripped and blanks dropped so that values such as
# "http://a, http://b" or a trailing comma do not produce origins that can
# never match a browser's Origin header.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGINS", "http://localhost:5173").split(",")
    if origin.strip()
]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE"],
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
)
|
||||
|
||||
async def require_admin_auth(request: Request):
|
||||
async def require_admin_auth(request: Request, db: Session = Depends(get_db)):
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
api_key = auth_header.replace("Bearer ", "")
|
||||
@ -23,24 +26,21 @@ async def require_admin_auth(request: Request):
|
||||
api_key = auth_header
|
||||
else:
|
||||
raise HTTPException(status_code=401, detail="Invalid or missing API key")
|
||||
|
||||
db = next(get_db())
|
||||
|
||||
db_key = crud.verify_api_key(db, api_key)
|
||||
|
||||
if not db_key:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
|
||||
db_user = db.query(User).filter(User.id == db_key.user_id).first()
|
||||
if not db_user or db_user.username != "admin":
|
||||
if not db_user or not db_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
|
||||
request.state.user = db_user
|
||||
request.state.db = db
|
||||
|
||||
@app.get("/api/users", response_model=list[schemas.User])
|
||||
async def read_users(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
_ = Depends(require_admin_auth)
|
||||
):
|
||||
@ -49,7 +49,7 @@ async def read_users(
|
||||
|
||||
@app.post("/api/users", response_model=schemas.User)
|
||||
async def create_user(
|
||||
user: schemas.UserCreate,
|
||||
user: schemas.UserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
_ = Depends(require_admin_auth)
|
||||
):
|
||||
@ -61,21 +61,24 @@ async def create_user(
|
||||
raise HTTPException(status_code=400, detail="Email already registered")
|
||||
return crud.create_user(db=db, username=user.username, email=user.email, password=user.password)
|
||||
|
||||
@app.post("/api/api-keys", response_model=schemas.APIKeyCreated)
async def create_api_key(
    api_key: schemas.APIKeyCreate,
    db: Session = Depends(get_db),
    _ = Depends(require_admin_auth)
):
    """Create an API key for an existing user and return it with its plaintext.

    Only a hash of the key is stored, so this response is the single place
    where the plaintext key is available to the caller.

    Raises:
        HTTPException: 404 if no user with ``api_key.user_id`` exists.
    """
    db_user = db.query(User).filter(User.id == api_key.user_id).first()
    if not db_user:
        raise HTTPException(status_code=404, detail="User not found")

    db_key, raw_key = crud.create_api_key(db=db, user_id=api_key.user_id, name=api_key.name)

    # The database row does not carry the plaintext key, so validating it
    # directly as APIKeyCreated would fail on the required `plaintext_key`
    # field. Validate the stored fields with the base schema first, then add
    # the plaintext when constructing the response model.
    stored_fields = schemas.APIKey.model_validate(db_key).model_dump()
    return schemas.APIKeyCreated(**stored_fields, plaintext_key=raw_key)
|
||||
|
||||
@app.get("/api/api-keys", response_model=list[schemas.APIKey])
|
||||
async def read_api_keys(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
_ = Depends(require_admin_auth)
|
||||
):
|
||||
@ -84,7 +87,7 @@ async def read_api_keys(
|
||||
|
||||
@app.put("/api/api-keys/{api_key_id}/deactivate")
|
||||
async def deactivate_api_key(
|
||||
api_key_id: int,
|
||||
api_key_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_ = Depends(require_admin_auth)
|
||||
):
|
||||
@ -97,15 +100,15 @@ async def deactivate_api_key(
|
||||
|
||||
@app.put("/api/quotas/{user_id}", response_model=schemas.Quota)
|
||||
async def update_quota(
|
||||
user_id: int,
|
||||
quota: schemas.QuotaCreate,
|
||||
user_id: int,
|
||||
quota: schemas.QuotaCreate,
|
||||
db: Session = Depends(get_db),
|
||||
_ = Depends(require_admin_auth)
|
||||
):
|
||||
db_quota = db.query(Quota).filter(Quota.user_id == user_id).first()
|
||||
if not db_quota:
|
||||
raise HTTPException(status_code=404, detail="Quota not found")
|
||||
for key, value in quota.dict(exclude_unset=True).items():
|
||||
for key, value in quota.model_dump(exclude_unset=True).items():
|
||||
setattr(db_quota, key, value)
|
||||
db.commit()
|
||||
db.refresh(db_quota)
|
||||
|
||||
114
backend/crud.py
114
backend/crud.py
@ -1,10 +1,25 @@
|
||||
import secrets
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
import bcrypt
|
||||
import tiktoken
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from database import get_db
|
||||
from models import APIKey, User, Quota, Usage
|
||||
|
||||
_encoder = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
def count_tokens(text: str) -> int:
    """Return the number of cl100k_base tokens in *text*.

    Empty or missing text counts as zero tokens. Special-token markers that
    appear literally in the text (e.g. ``<|endoftext|>``) are counted as
    ordinary text instead of raising, because the input comes from
    user-supplied prompts.
    """
    if not text:
        return 0
    # disallowed_special=() turns off tiktoken's default check that raises
    # ValueError when the text contains a special-token string.
    return len(_encoder.encode(text, disallowed_special=()))
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if *plain_password* matches the bcrypt *hashed_password*.

    Returns False instead of raising when bcrypt rejects the input with
    ValueError, e.g. because the stored value is not a valid bcrypt hash
    (such as a digest written by the previous SHA-256 scheme).
    """
    try:
        return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
    except ValueError:
        # A value bcrypt cannot process can never be a successful match;
        # treat it as a failed verification rather than a server error.
        return False
|
||||
|
||||
def _hash_api_key(key: str) -> str:
|
||||
return hashlib.sha256(key.encode()).hexdigest()
|
||||
|
||||
def get_user_by_username(db: Session, username: str):
|
||||
return db.query(User).filter(User.username == username).first()
|
||||
|
||||
@ -14,19 +29,17 @@ def get_user_by_email(db: Session, email: str):
|
||||
def generate_api_key():
|
||||
return "sk-" + secrets.token_urlsafe(32)
|
||||
|
||||
def hash_password(password: str):
|
||||
return hashlib.sha256(password.encode()).hexdigest()
|
||||
|
||||
def create_user(db: Session, username: str, email: str, password: str):
|
||||
def create_user(db: Session, username: str, email: str, password: str, is_admin: bool = False):
|
||||
db_user = User(
|
||||
username=username,
|
||||
email=email,
|
||||
hashed_password=hash_password(password)
|
||||
hashed_password=hash_password(password),
|
||||
is_admin=is_admin,
|
||||
)
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
|
||||
|
||||
default_quota = Quota(
|
||||
user_id=db_user.id,
|
||||
daily_tokens=1000000,
|
||||
@ -36,60 +49,75 @@ def create_user(db: Session, username: str, email: str, password: str):
|
||||
)
|
||||
db.add(default_quota)
|
||||
db.commit()
|
||||
|
||||
|
||||
return db_user
|
||||
|
||||
def create_api_key(db: Session, user_id: int, name: str):
|
||||
key = generate_api_key()
|
||||
def create_api_key(db: Session, user_id: int, name: str) -> tuple[APIKey, str]:
|
||||
raw_key = generate_api_key()
|
||||
db_key = APIKey(
|
||||
name=name,
|
||||
key=key,
|
||||
key=_hash_api_key(raw_key),
|
||||
user_id=user_id
|
||||
)
|
||||
db.add(db_key)
|
||||
db.commit()
|
||||
db.refresh(db_key)
|
||||
return db_key
|
||||
return db_key, raw_key
|
||||
|
||||
def verify_api_key(db: Session, api_key: str):
|
||||
return db.query(APIKey).filter(APIKey.key == api_key, APIKey.is_active == True).first()
|
||||
key_hash = _hash_api_key(api_key)
|
||||
return db.query(APIKey).filter(APIKey.key == key_hash, APIKey.is_active == True).first()
|
||||
|
||||
def get_quota(db: Session, user_id: int):
|
||||
return db.query(Quota).filter(Quota.user_id == user_id).first()
|
||||
|
||||
def check_quota(db: Session, user_id: int, tokens: int = 0, requests: int = 1):
|
||||
quota = get_quota(db, user_id)
|
||||
usage = get_usage(db, user_id)
|
||||
|
||||
if quota.daily_tokens and (usage.tokens_used + tokens) > quota.daily_tokens:
|
||||
return False
|
||||
if quota.monthly_tokens and (usage.tokens_used + tokens) > quota.monthly_tokens:
|
||||
return False
|
||||
if quota.daily_requests and (usage.requests_count + requests) > quota.daily_requests:
|
||||
return False
|
||||
if quota.monthly_requests and (usage.requests_count + requests) > quota.monthly_requests:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def increment_usage(db: Session, user_id: int, tokens: int = 0, requests: int = 1):
|
||||
usage = get_or_create_usage(db, user_id)
|
||||
usage.tokens_used += tokens
|
||||
usage.requests_count += requests
|
||||
db.commit()
|
||||
db.refresh(usage)
|
||||
def get_quota_by_user_id(db: Session, user_id: int):
|
||||
return db.query(Quota).filter(Quota.user_id == user_id).first()
|
||||
|
||||
def get_usage(db: Session, user_id: int):
|
||||
return db.query(Usage).filter(Usage.user_id == user_id).first()
|
||||
|
||||
def get_quota_by_user_id(db: Session, user_id: int):
|
||||
return db.query(Quota).filter(Quota.user_id == user_id).first()
|
||||
|
||||
def get_or_create_usage(db: Session, user_id: int):
|
||||
usage = get_usage(db, user_id)
|
||||
def check_and_increment_quota(db: Session, user_id: int, tokens: int = 0, requests: int = 1) -> bool:
|
||||
usage = (
|
||||
db.query(Usage)
|
||||
.filter(Usage.user_id == user_id)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
if not usage:
|
||||
usage = Usage(user_id=user_id)
|
||||
db.add(usage)
|
||||
db.commit()
|
||||
db.refresh(usage)
|
||||
return usage
|
||||
db.flush()
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
if usage.daily_reset_at.date() < now.date():
|
||||
usage.tokens_used_today = 0
|
||||
usage.requests_today = 0
|
||||
usage.daily_reset_at = now
|
||||
|
||||
if (usage.monthly_reset_at.year, usage.monthly_reset_at.month) < (now.year, now.month):
|
||||
usage.tokens_used_month = 0
|
||||
usage.requests_month = 0
|
||||
usage.monthly_reset_at = now
|
||||
|
||||
quota = get_quota(db, user_id)
|
||||
allowed = True
|
||||
if quota:
|
||||
if quota.daily_tokens and (usage.tokens_used_today + tokens) > quota.daily_tokens:
|
||||
allowed = False
|
||||
elif quota.monthly_tokens and (usage.tokens_used_month + tokens) > quota.monthly_tokens:
|
||||
allowed = False
|
||||
elif quota.daily_requests and (usage.requests_today + requests) > quota.daily_requests:
|
||||
allowed = False
|
||||
elif quota.monthly_requests and (usage.requests_month + requests) > quota.monthly_requests:
|
||||
allowed = False
|
||||
|
||||
if allowed:
|
||||
usage.tokens_used_today += tokens
|
||||
usage.tokens_used_month += tokens
|
||||
usage.requests_today += requests
|
||||
usage.requests_month += requests
|
||||
|
||||
db.commit()
|
||||
return allowed
|
||||
100
backend/main.py
100
backend/main.py
@ -1,7 +1,9 @@
|
||||
import time
|
||||
import uuid
|
||||
from fastapi import FastAPI, HTTPException, Depends, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from database import get_db
|
||||
from database import get_db, SessionLocal
|
||||
import crud
|
||||
import httpx
|
||||
import os
|
||||
@ -22,54 +24,48 @@ async def authenticate_and_quota(request: Request, call_next):
|
||||
elif auth_header.startswith("sk-"):
|
||||
api_key = auth_header
|
||||
else:
|
||||
raise HTTPException(status_code=401, detail="Invalid or missing API key")
|
||||
|
||||
db = next(get_db())
|
||||
db_key = crud.verify_api_key(db, api_key)
|
||||
|
||||
if not db_key:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
if not db_key.is_active:
|
||||
raise HTTPException(status_code=403, detail="API key deactivated")
|
||||
|
||||
request.state.user_id = db_key.user_id
|
||||
|
||||
return JSONResponse(status_code=401, content={"detail": "Invalid or missing API key"})
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_key = crud.verify_api_key(db, api_key)
|
||||
if not db_key:
|
||||
return JSONResponse(status_code=401, content={"detail": "Invalid API key"})
|
||||
if not db_key.is_active:
|
||||
return JSONResponse(status_code=403, content={"detail": "API key deactivated"})
|
||||
request.state.user_id = db_key.user_id
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
@app.post("/api/generate")
|
||||
async def generate(request: Request):
|
||||
db = next(get_db())
|
||||
async def generate(request: Request, db: Session = Depends(get_db)):
|
||||
user_id = request.state.user_id
|
||||
|
||||
|
||||
body = await request.json()
|
||||
|
||||
prompt_tokens = len(body.get("prompt", "").split())
|
||||
if not crud.check_quota(db, user_id, tokens=prompt_tokens, requests=1):
|
||||
|
||||
prompt_tokens = crud.count_tokens(body.get("prompt", ""))
|
||||
if not crud.check_and_increment_quota(db, user_id, tokens=prompt_tokens, requests=1):
|
||||
raise HTTPException(status_code=429, detail="Quota exceeded")
|
||||
|
||||
|
||||
response = await proxy_request(f"{OLLAMA_URL}/api/generate", method="POST", json_data=body, headers=dict(request.headers))
|
||||
|
||||
crud.increment_usage(db, user_id, tokens=prompt_tokens, requests=1)
|
||||
|
||||
|
||||
return JSONResponse(content=response.json(), status_code=response.status_code, headers=dict(response.headers))
|
||||
|
||||
@app.post("/api/chat")
|
||||
async def chat(request: Request):
|
||||
db = next(get_db())
|
||||
async def chat(request: Request, db: Session = Depends(get_db)):
|
||||
user_id = request.state.user_id
|
||||
|
||||
|
||||
body = await request.json()
|
||||
|
||||
prompt_tokens = sum(len(msg.get("content", "").split()) for msg in body.get("messages", []))
|
||||
if not crud.check_quota(db, user_id, tokens=prompt_tokens, requests=1):
|
||||
|
||||
prompt_tokens = sum(crud.count_tokens(msg.get("content", "")) for msg in body.get("messages", []))
|
||||
if not crud.check_and_increment_quota(db, user_id, tokens=prompt_tokens, requests=1):
|
||||
raise HTTPException(status_code=429, detail="Quota exceeded")
|
||||
|
||||
|
||||
response = await proxy_request(f"{OLLAMA_URL}/api/chat", method="POST", json_data=body, headers=dict(request.headers))
|
||||
|
||||
crud.increment_usage(db, user_id, tokens=prompt_tokens, requests=1)
|
||||
|
||||
|
||||
return JSONResponse(content=response.json(), status_code=response.status_code, headers=dict(response.headers))
|
||||
|
||||
@app.get("/api/tags")
|
||||
@ -101,48 +97,48 @@ async def list_openai_models(request: Request):
|
||||
return JSONResponse(content=openai_models, status_code=200, headers=dict(response.headers))
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def openai_chat_completions(request: Request):
|
||||
db = next(get_db())
|
||||
async def openai_chat_completions(request: Request, db: Session = Depends(get_db)):
|
||||
user_id = request.state.user_id
|
||||
|
||||
|
||||
body = await request.json()
|
||||
|
||||
|
||||
messages = body.get("messages", [])
|
||||
prompt_tokens = sum(len(msg.get("content", "").split()) for msg in messages)
|
||||
|
||||
if not crud.check_quota(db, user_id, tokens=prompt_tokens, requests=1):
|
||||
prompt_tokens = sum(crud.count_tokens(msg.get("content", "")) for msg in messages)
|
||||
|
||||
if not crud.check_and_increment_quota(db, user_id, tokens=prompt_tokens, requests=1):
|
||||
raise HTTPException(status_code=429, detail="Quota exceeded")
|
||||
|
||||
|
||||
ollama_body = {
|
||||
"model": body.get("model", "llama3"),
|
||||
"messages": messages,
|
||||
"stream": body.get("stream", False)
|
||||
}
|
||||
|
||||
|
||||
response = await proxy_request(f"{OLLAMA_URL}/api/chat", method="POST", json_data=ollama_body, headers=dict(request.headers))
|
||||
|
||||
crud.increment_usage(db, user_id, tokens=prompt_tokens, requests=1)
|
||||
|
||||
|
||||
response_content = response.json().get("message", {}).get("content", "")
|
||||
completion_tokens = crud.count_tokens(response_content)
|
||||
|
||||
openai_response = {
|
||||
"id": f"chatcmpl-{hash(msg.get('content', ''))}",
|
||||
"id": f"chatcmpl-{uuid.uuid4().hex}",
|
||||
"object": "chat.completion",
|
||||
"created": int(__import__('time').time()),
|
||||
"created": int(time.time()),
|
||||
"model": body.get("model", "llama3"),
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response.json().get("message", {}).get("content", "")
|
||||
"content": response_content
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": len(response.json().get("message", {}).get("content", "").split()),
|
||||
"total_tokens": prompt_tokens + len(response.json().get("message", {}).get("content", "").split())
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return JSONResponse(content=openai_response, status_code=200, headers={"Content-Type": "application/json"})
|
||||
|
||||
@ -10,6 +10,7 @@ class User(Base):
|
||||
email = Column(String, unique=True, index=True)
|
||||
hashed_password = Column(String)
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_admin = Column(Boolean, default=False)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
class APIKey(Base):
|
||||
@ -38,6 +39,9 @@ class Usage(Base):
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), unique=True)
|
||||
tokens_used = Column(BigInteger, default=0)
|
||||
requests_count = Column(Integer, default=0)
|
||||
reset_at = Column(DateTime, default=datetime.utcnow)
|
||||
tokens_used_today = Column(BigInteger, default=0)
|
||||
tokens_used_month = Column(BigInteger, default=0)
|
||||
requests_today = Column(Integer, default=0)
|
||||
requests_month = Column(Integer, default=0)
|
||||
daily_reset_at = Column(DateTime, default=datetime.utcnow)
|
||||
monthly_reset_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
@ -6,7 +6,8 @@ alembic==1.14.0
|
||||
pydantic==2.10.3
|
||||
python-multipart==0.0.20
|
||||
python-jose[cryptography]==3.3.0
|
||||
passlib[bcrypt]==1.7.4
|
||||
bcrypt==5.0.0
|
||||
tiktoken==0.9.0
|
||||
python-dotenv==1.0.1
|
||||
pytest==8.3.4
|
||||
pytest-asyncio==0.25.1
|
||||
|
||||
@ -5,6 +5,7 @@ from typing import Optional
|
||||
class UserBase(BaseModel):
|
||||
username: str
|
||||
email: str
|
||||
is_admin: bool = False
|
||||
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
@ -33,6 +34,12 @@ class APIKey(APIKeyBase):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class APIKeyCreated(APIKey):
|
||||
plaintext_key: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class QuotaBase(BaseModel):
|
||||
daily_tokens: Optional[int] = None
|
||||
monthly_tokens: Optional[int] = None
|
||||
@ -51,9 +58,12 @@ class Quota(QuotaBase):
|
||||
from_attributes = True
|
||||
|
||||
class UsageStats(BaseModel):
|
||||
tokens_used: int = 0
|
||||
requests_count: int = 0
|
||||
last_reset: Optional[datetime] = None
|
||||
tokens_used_today: int = 0
|
||||
tokens_used_month: int = 0
|
||||
requests_today: int = 0
|
||||
requests_month: int = 0
|
||||
daily_reset_at: Optional[datetime] = None
|
||||
monthly_reset_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
@ -13,7 +13,8 @@ def setup_admin():
|
||||
username="admin",
|
||||
email="admin@ollama.local",
|
||||
hashed_password=hash_password("admin123"),
|
||||
is_active=True
|
||||
is_active=True,
|
||||
is_admin=True,
|
||||
)
|
||||
db.add(admin_user)
|
||||
db.commit()
|
||||
@ -31,8 +32,8 @@ def setup_admin():
|
||||
db.commit()
|
||||
print("✓ Admin quota created")
|
||||
|
||||
api_key = create_api_key(db, admin_user.id, "admin-api-key")
|
||||
print(f"✓ Admin API Key: {api_key.key}")
|
||||
_, raw_key = create_api_key(db, admin_user.id, "admin-api-key")
|
||||
print(f"✓ Admin API Key: {raw_key}")
|
||||
else:
|
||||
print("✗ Admin user already exists")
|
||||
db.close()
|
||||
|
||||
1
backend/tests/__init__.py
Normal file
1
backend/tests/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Tests for the Ollama Proxy."""
|
||||
BIN
backend/tests/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
backend/tests/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
backend/tests/__pycache__/conftest.cpython-312-pytest-8.3.4.pyc
Normal file
BIN
backend/tests/__pycache__/conftest.cpython-312-pytest-8.3.4.pyc
Normal file
Binary file not shown.
BIN
backend/tests/__pycache__/test_auth.cpython-312-pytest-8.3.4.pyc
Normal file
BIN
backend/tests/__pycache__/test_auth.cpython-312-pytest-8.3.4.pyc
Normal file
Binary file not shown.
Binary file not shown.
87
backend/tests/conftest.py
Normal file
87
backend/tests/conftest.py
Normal file
@ -0,0 +1,87 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
import tempfile
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
def create_test_db():
|
||||
"""Create a temporary SQLite database for tests."""
|
||||
temp_db = tempfile.NamedTemporaryFile(suffix='.db', delete=False)
|
||||
temp_db.close()
|
||||
os.environ["DATABASE_URL"] = f"sqlite:///{temp_db.name}"
|
||||
return temp_db.name
|
||||
|
||||
def cleanup_test_db(db_path):
|
||||
"""Remove the temporary database."""
|
||||
if os.path.exists(db_path):
|
||||
os.unlink(db_path)
|
||||
os.environ.pop("DATABASE_URL", None)
|
||||
|
||||
def setup_test_db():
|
||||
"""Setup test database with required data."""
|
||||
from database import Base, engine, SessionLocal
|
||||
from models import User, APIKey, Quota, Usage
|
||||
from crud import create_api_key, hash_password
|
||||
|
||||
# Create tables
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
db = SessionLocal()
|
||||
|
||||
# Create test user
|
||||
test_user = User(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
hashed_password=hash_password("test123"),
|
||||
is_active=True
|
||||
)
|
||||
db.add(test_user)
|
||||
db.commit()
|
||||
db.refresh(test_user)
|
||||
|
||||
# Create API key for test user
|
||||
_, raw_key = create_api_key(db, test_user.id, "test-key")
|
||||
os.environ["TEST_API_KEY"] = raw_key
|
||||
|
||||
# Create admin user
|
||||
admin_user = User(
|
||||
username="admin",
|
||||
email="admin@example.com",
|
||||
hashed_password=hash_password("admin123"),
|
||||
is_active=True,
|
||||
is_admin=True,
|
||||
)
|
||||
db.add(admin_user)
|
||||
db.commit()
|
||||
db.refresh(admin_user)
|
||||
|
||||
# Create admin API key
|
||||
_, admin_raw_key = create_api_key(db, admin_user.id, "admin-key")
|
||||
os.environ["ADMIN_API_KEY"] = admin_raw_key
|
||||
|
||||
db.close()
|
||||
|
||||
return raw_key, admin_raw_key
|
||||
|
||||
def teardown_test_db():
|
||||
"""Clean up test database and environment."""
|
||||
from database import engine
|
||||
from models import Base
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
os.environ.pop("TEST_API_KEY", None)
|
||||
os.environ.pop("ADMIN_API_KEY", None)
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_client():
|
||||
"""Create test client with test database."""
|
||||
db_path = create_test_db()
|
||||
setup_test_db()
|
||||
|
||||
from main import app
|
||||
client = TestClient(app)
|
||||
|
||||
yield client
|
||||
|
||||
teardown_test_db()
|
||||
cleanup_test_db(db_path)
|
||||
106
backend/tests/test_auth.py
Normal file
106
backend/tests/test_auth.py
Normal file
@ -0,0 +1,106 @@
|
||||
import pytest
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
from main import app
|
||||
from database import Base, engine, SessionLocal
|
||||
from models import User, APIKey, Quota
|
||||
from crud import create_api_key, hash_password
|
||||
|
||||
os.environ["OLLAMA_URL"] = "http://127.0.0.1:9999"
|
||||
|
||||
def setup_test_db():
|
||||
"""Setup test database with required data."""
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
db = SessionLocal()
|
||||
|
||||
test_user = User(
|
||||
username="testuser",
|
||||
email="test@example.com",
|
||||
hashed_password=hash_password("test123"),
|
||||
is_active=True
|
||||
)
|
||||
db.add(test_user)
|
||||
db.commit()
|
||||
db.refresh(test_user)
|
||||
|
||||
quota = Quota(
|
||||
user_id=test_user.id,
|
||||
daily_tokens=1000000,
|
||||
monthly_tokens=10000000,
|
||||
daily_requests=1000,
|
||||
monthly_requests=10000
|
||||
)
|
||||
db.add(quota)
|
||||
db.commit()
|
||||
|
||||
api_key_record, raw_key = create_api_key(db, test_user.id, "test-key")
|
||||
os.environ["TEST_API_KEY"] = raw_key
|
||||
|
||||
admin_user = User(
|
||||
username="admin",
|
||||
email="admin@example.com",
|
||||
hashed_password=hash_password("admin123"),
|
||||
is_active=True,
|
||||
is_admin=True,
|
||||
)
|
||||
db.add(admin_user)
|
||||
db.commit()
|
||||
db.refresh(admin_user)
|
||||
|
||||
admin_quota = Quota(
|
||||
user_id=admin_user.id,
|
||||
daily_tokens=10000000,
|
||||
monthly_tokens=100000000,
|
||||
daily_requests=10000,
|
||||
monthly_requests=100000
|
||||
)
|
||||
db.add(admin_quota)
|
||||
db.commit()
|
||||
|
||||
_, admin_raw_key = create_api_key(db, admin_user.id, "admin-key")
|
||||
os.environ["ADMIN_API_KEY"] = admin_raw_key
|
||||
|
||||
db.close()
|
||||
|
||||
return os.environ["TEST_API_KEY"], os.environ["ADMIN_API_KEY"]
|
||||
|
||||
def teardown_test_db():
|
||||
"""Clean up test database and environment."""
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
os.environ.pop("TEST_API_KEY", None)
|
||||
os.environ.pop("ADMIN_API_KEY", None)
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def test_client():
|
||||
setup_test_db()
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
yield client
|
||||
teardown_test_db()
|
||||
|
||||
def test_auth_middleware_missing_auth(test_client):
|
||||
response = test_client.post("/api/generate", json={"model": "llama3", "prompt": "test"})
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_auth_middleware_invalid_key(test_client):
|
||||
response = test_client.post(
|
||||
"/api/generate",
|
||||
headers={"Authorization": "sk-invalid-key"},
|
||||
json={"model": "llama3", "prompt": "test"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
@patch("main.proxy_request", new_callable=AsyncMock)
|
||||
def test_auth_middleware_valid_key(mock_proxy, test_client):
|
||||
mock_proxy.return_value.status_code = 200
|
||||
mock_proxy.return_value.json = lambda: {"response": "success"}
|
||||
mock_proxy.return_value.headers = {}
|
||||
|
||||
response = test_client.post(
|
||||
"/api/generate",
|
||||
headers={"Authorization": os.environ.get("TEST_API_KEY", "")},
|
||||
json={"model": "llama3", "prompt": "test"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
182
backend/tests/test_quota.py
Normal file
182
backend/tests/test_quota.py
Normal file
@ -0,0 +1,182 @@
|
||||
import pytest
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
os.environ.setdefault("OLLAMA_URL", "http://127.0.0.1:9999")
|
||||
|
||||
from database import Base, engine, SessionLocal
|
||||
from models import User, Quota, Usage
|
||||
from crud import check_and_increment_quota, count_tokens, hash_password
|
||||
|
||||
|
||||
def make_user_and_quota(db, daily_tokens=None, monthly_tokens=None,
|
||||
daily_requests=None, monthly_requests=None):
|
||||
user = User(
|
||||
username="quotauser",
|
||||
email="quota@example.com",
|
||||
hashed_password=hash_password("pass"),
|
||||
is_active=True,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
quota = Quota(
|
||||
user_id=user.id,
|
||||
daily_tokens=daily_tokens,
|
||||
monthly_tokens=monthly_tokens,
|
||||
daily_requests=daily_requests,
|
||||
monthly_requests=monthly_requests,
|
||||
)
|
||||
db.add(quota)
|
||||
db.commit()
|
||||
|
||||
return user.id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db():
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
session = SessionLocal()
|
||||
yield session
|
||||
session.close()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
# --- count_tokens ---
|
||||
|
||||
def test_count_tokens_empty():
|
||||
assert count_tokens("") == 0
|
||||
|
||||
def test_count_tokens_returns_int():
|
||||
assert isinstance(count_tokens("hello world"), int)
|
||||
|
||||
def test_count_tokens_scales_with_length():
|
||||
assert count_tokens("a b c d e f g h") > count_tokens("a")
|
||||
|
||||
def test_count_tokens_more_accurate_than_split():
|
||||
# tiktoken counts "don't" as 2 tokens, split counts it as 1 word
|
||||
text = "don't do that"
|
||||
assert count_tokens(text) >= len(text.split())
|
||||
|
||||
|
||||
# --- check_and_increment_quota ---
|
||||
|
||||
def test_allowed_within_daily_token_limit(db):
|
||||
user_id = make_user_and_quota(db, daily_tokens=1000)
|
||||
assert check_and_increment_quota(db, user_id, tokens=100, requests=1) is True
|
||||
|
||||
def test_denied_when_daily_tokens_exceeded(db):
|
||||
user_id = make_user_and_quota(db, daily_tokens=50)
|
||||
assert check_and_increment_quota(db, user_id, tokens=100, requests=1) is False
|
||||
|
||||
def test_denied_when_monthly_tokens_exceeded(db):
|
||||
user_id = make_user_and_quota(db, monthly_tokens=50)
|
||||
assert check_and_increment_quota(db, user_id, tokens=100, requests=1) is False
|
||||
|
||||
def test_denied_when_daily_requests_exceeded(db):
|
||||
user_id = make_user_and_quota(db, daily_requests=1)
|
||||
check_and_increment_quota(db, user_id, tokens=0, requests=1)
|
||||
assert check_and_increment_quota(db, user_id, tokens=0, requests=1) is False
|
||||
|
||||
def test_denied_when_monthly_requests_exceeded(db):
|
||||
user_id = make_user_and_quota(db, monthly_requests=1)
|
||||
check_and_increment_quota(db, user_id, tokens=0, requests=1)
|
||||
assert check_and_increment_quota(db, user_id, tokens=0, requests=1) is False
|
||||
|
||||
def test_increments_both_daily_and_monthly_counters(db):
|
||||
user_id = make_user_and_quota(db, daily_tokens=1000, monthly_tokens=10000,
|
||||
daily_requests=100, monthly_requests=1000)
|
||||
check_and_increment_quota(db, user_id, tokens=50, requests=1)
|
||||
|
||||
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
|
||||
assert usage.tokens_used_today == 50
|
||||
assert usage.tokens_used_month == 50
|
||||
assert usage.requests_today == 1
|
||||
assert usage.requests_month == 1
|
||||
|
||||
def test_creates_usage_record_on_first_call(db):
|
||||
user_id = make_user_and_quota(db, daily_tokens=1000)
|
||||
assert db.query(Usage).filter(Usage.user_id == user_id).first() is None
|
||||
|
||||
check_and_increment_quota(db, user_id, tokens=10, requests=1)
|
||||
|
||||
assert db.query(Usage).filter(Usage.user_id == user_id).first() is not None
|
||||
|
||||
def test_no_quota_allows_any_request(db):
|
||||
user_id = make_user_and_quota(db) # all limits None
|
||||
assert check_and_increment_quota(db, user_id, tokens=999999, requests=9999) is True
|
||||
|
||||
def test_cumulative_usage_across_calls(db):
|
||||
user_id = make_user_and_quota(db, daily_tokens=200)
|
||||
check_and_increment_quota(db, user_id, tokens=100, requests=1)
|
||||
check_and_increment_quota(db, user_id, tokens=99, requests=1)
|
||||
# 199 used, 1 remaining – exactly 1 more token should pass
|
||||
assert check_and_increment_quota(db, user_id, tokens=1, requests=1) is True
|
||||
# Now 200 used – next request must fail
|
||||
assert check_and_increment_quota(db, user_id, tokens=1, requests=1) is False
|
||||
|
||||
|
||||
# --- Reset logic ---
|
||||
|
||||
def test_daily_reset_restores_access(db):
|
||||
user_id = make_user_and_quota(db, daily_tokens=100)
|
||||
check_and_increment_quota(db, user_id, tokens=90, requests=1)
|
||||
|
||||
# Backdate daily_reset_at to yesterday
|
||||
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
|
||||
usage.daily_reset_at = datetime.utcnow() - timedelta(days=1)
|
||||
db.commit()
|
||||
|
||||
# Should pass again after reset
|
||||
assert check_and_increment_quota(db, user_id, tokens=90, requests=1) is True
|
||||
|
||||
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
|
||||
assert usage.tokens_used_today == 90
|
||||
|
||||
def test_daily_reset_does_not_affect_monthly_counter(db):
|
||||
user_id = make_user_and_quota(db, daily_tokens=1000, monthly_tokens=10000)
|
||||
check_and_increment_quota(db, user_id, tokens=50, requests=1)
|
||||
|
||||
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
|
||||
usage.daily_reset_at = datetime.utcnow() - timedelta(days=1)
|
||||
db.commit()
|
||||
|
||||
check_and_increment_quota(db, user_id, tokens=50, requests=1)
|
||||
|
||||
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
|
||||
assert usage.tokens_used_today == 50
|
||||
assert usage.tokens_used_month == 100 # cumulative across days
|
||||
|
||||
def test_monthly_reset_restores_access(db):
|
||||
user_id = make_user_and_quota(db, monthly_tokens=100)
|
||||
check_and_increment_quota(db, user_id, tokens=90, requests=1)
|
||||
|
||||
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
|
||||
usage.monthly_reset_at = datetime.utcnow() - timedelta(days=32)
|
||||
db.commit()
|
||||
|
||||
assert check_and_increment_quota(db, user_id, tokens=90, requests=1) is True
|
||||
|
||||
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
|
||||
assert usage.tokens_used_month == 90
|
||||
|
||||
def test_failed_quota_check_still_commits_reset(db):
|
||||
user_id = make_user_and_quota(db, daily_tokens=100, daily_requests=5)
|
||||
check_and_increment_quota(db, user_id, tokens=80, requests=1)
|
||||
|
||||
# Backdate so a reset fires, but the new request still exceeds the limit
|
||||
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
|
||||
usage.daily_reset_at = datetime.utcnow() - timedelta(days=1)
|
||||
usage.tokens_used_today = 80
|
||||
db.commit()
|
||||
|
||||
# After reset tokens_used_today = 0; 200 tokens exceeds 100 limit
|
||||
result = check_and_increment_quota(db, user_id, tokens=200, requests=1)
|
||||
assert result is False
|
||||
|
||||
# Reset must still be persisted so the next request sees fresh counters
|
||||
db.expire_all()
|
||||
usage = db.query(Usage).filter(Usage.user_id == user_id).first()
|
||||
assert usage.tokens_used_today == 0
|
||||
Loading…
x
Reference in New Issue
Block a user