"""User, API-key and usage-quota helpers operating on a SQLAlchemy session."""

import hashlib
import secrets
from datetime import datetime, timezone

import bcrypt
import tiktoken
from sqlalchemy.orm import Session

from models import APIKey, Quota, Usage, User

# Loaded once at import time so every count_tokens call reuses the same encoder.
_encoder = tiktoken.get_encoding("cl100k_base")

# Limits given to every user created through create_user.
_DEFAULT_QUOTA = {
    "daily_tokens": 1_000_000,
    "monthly_tokens": 10_000_000,
    "daily_requests": 1_000,
    "monthly_requests": 10_000,
}


def count_tokens(text: str) -> int:
    """Return the number of cl100k_base tokens in *text*.

    ``disallowed_special=()`` makes special-token strings such as
    ``"<|endoftext|>"`` be counted as ordinary text; with tiktoken's default
    (``disallowed_special="all"``) such input raises ``ValueError``.
    """
    return len(_encoder.encode(text, disallowed_special=()))


def hash_password(password: str) -> str:
    """Return a bcrypt hash (with a freshly generated salt) of *password*."""
    return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if *plain_password* matches the bcrypt *hashed_password*.

    ``bcrypt.checkpw`` raises ``ValueError`` when the stored hash is
    malformed; that is treated as a non-match so verification fails closed.
    """
    try:
        return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
    except ValueError:
        return False


def _hash_api_key(key: str) -> str:
    """Return the hex SHA-256 digest of *key*; only this digest is stored."""
    return hashlib.sha256(key.encode()).hexdigest()


def get_user_by_username(db: Session, username: str):
    """Return the User with *username*, or None if there is none."""
    return db.query(User).filter(User.username == username).first()


def get_user_by_email(db: Session, email: str):
    """Return the User with *email*, or None if there is none."""
    return db.query(User).filter(User.email == email).first()


def generate_api_key() -> str:
    """Return a new random API key of the form ``"sk-" + token_urlsafe(32)``."""
    return "sk-" + secrets.token_urlsafe(32)


def create_user(db: Session, username: str, email: str, password: str, is_admin: bool = False):
    """Create a user together with its default Quota row and return the user.

    The user and the quota are committed in one transaction, so a failure
    cannot leave a user without a quota.
    """
    db_user = User(
        username=username,
        email=email,
        hashed_password=hash_password(password),
        is_admin=is_admin,
    )
    db.add(db_user)
    # flush (not commit) emits the INSERT so db_user.id is assigned while the
    # transaction stays open for the quota row below.
    db.flush()
    db.add(Quota(user_id=db_user.id, **_DEFAULT_QUOTA))
    db.commit()
    db.refresh(db_user)
    return db_user


def create_api_key(db: Session, user_id: int, name: str) -> tuple[APIKey, str]:
    """Create an API key for *user_id*.

    Returns the persisted APIKey row (which stores only the SHA-256 hash) and
    the raw key; the raw key is not stored and cannot be recovered later.
    """
    raw_key = generate_api_key()
    db_key = APIKey(name=name, key=_hash_api_key(raw_key), user_id=user_id)
    db.add(db_key)
    db.commit()
    db.refresh(db_key)
    return db_key, raw_key


def verify_api_key(db: Session, api_key: str):
    """Return the active APIKey row matching *api_key*, or None.

    An empty or missing key returns None without querying.
    """
    if not api_key:
        return None
    key_hash = _hash_api_key(api_key)
    return (
        db.query(APIKey)
        .filter(APIKey.key == key_hash, APIKey.is_active == True)  # noqa: E712 (SQL expression)
        .first()
    )


def get_quota(db: Session, user_id: int):
    """Return the Quota row for *user_id*, or None if it has none."""
    return db.query(Quota).filter(Quota.user_id == user_id).first()


def get_quota_by_user_id(db: Session, user_id: int):
    """Alias of get_quota, kept for existing callers."""
    return get_quota(db, user_id)


def get_usage(db: Session, user_id: int):
    """Return the Usage row for *user_id*, or None if it has none."""
    return db.query(Usage).filter(Usage.user_id == user_id).first()


def _as_utc(moment):
    """Return *moment* as an aware UTC datetime (None stays None).

    NOTE(review): naive values are assumed to already be in UTC, which is
    what the previous date comparison implicitly assumed -- confirm against
    how Usage timestamps are written/read by the database driver.
    """
    if moment is None:
        return None
    if moment.tzinfo is None:
        return moment.replace(tzinfo=timezone.utc)
    return moment.astimezone(timezone.utc)


def _over_limit(limit, used, delta) -> bool:
    """Return True when adding *delta* to *used* would exceed *limit*.

    A falsy limit (None or 0) means "no limit"; a None counter counts as 0.
    """
    return bool(limit) and (used or 0) + delta > limit


def check_and_increment_quota(db: Session, user_id: int, tokens: int = 0, requests: int = 1) -> bool:
    """Atomically check *user_id*'s quota and, if allowed, record the usage.

    The Usage row is locked with ``SELECT ... FOR UPDATE`` (created if
    missing), daily/monthly counters are reset when a new UTC day/month has
    started, and the counters are incremented only if no configured limit
    would be exceeded. Users without a Quota row are always allowed.

    Returns True if the request is allowed (and counted), False otherwise.
    The transaction is committed in both cases so counter resets persist and
    the row lock is released.
    """
    usage = (
        db.query(Usage)
        .filter(Usage.user_id == user_id)
        .with_for_update()
        .first()
    )
    if usage is None:
        usage = Usage(user_id=user_id)
        db.add(usage)
        db.flush()

    now = datetime.now(timezone.utc)

    daily_reset_at = _as_utc(usage.daily_reset_at)
    if daily_reset_at is None or daily_reset_at.date() < now.date():
        usage.tokens_used_today = 0
        usage.requests_today = 0
        usage.daily_reset_at = now

    monthly_reset_at = _as_utc(usage.monthly_reset_at)
    if monthly_reset_at is None or (monthly_reset_at.year, monthly_reset_at.month) < (now.year, now.month):
        usage.tokens_used_month = 0
        usage.requests_month = 0
        usage.monthly_reset_at = now

    quota = get_quota(db, user_id)
    allowed = True
    if quota is not None:
        allowed = not (
            _over_limit(quota.daily_tokens, usage.tokens_used_today, tokens)
            or _over_limit(quota.monthly_tokens, usage.tokens_used_month, tokens)
            or _over_limit(quota.daily_requests, usage.requests_today, requests)
            or _over_limit(quota.monthly_requests, usage.requests_month, requests)
        )

    if allowed:
        usage.tokens_used_today = (usage.tokens_used_today or 0) + tokens
        usage.tokens_used_month = (usage.tokens_used_month or 0) + tokens
        usage.requests_today = (usage.requests_today or 0) + requests
        usage.requests_month = (usage.requests_month or 0) + requests

    db.commit()
    return allowed