import hashlib
import os
import secrets
from datetime import datetime, timezone
from typing import Optional
from zoneinfo import ZoneInfo

import tiktoken
from sqlalchemy.orm import Session

from models import APIKey, Setting, Usage

# Tokenizer used for all token accounting, loaded once at import time.
_encoder = tiktoken.get_encoding("cl100k_base")

# Time zone whose calendar days/months define the quota windows
# (overridable through the APP_TZ environment variable).
_tz = ZoneInfo(os.getenv("APP_TZ", "Europe/Berlin"))


def _now_local() -> datetime:
    """Return the current time as an aware datetime in the application time zone."""
    return datetime.now(_tz)


def _to_local(dt: datetime) -> datetime:
    """Convert *dt* to the application time zone.

    Naive datetimes are interpreted as UTC, which is the convention this
    module uses for stored timestamps.
    """
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt.astimezone(_tz)


def count_tokens(text: str) -> int:
    """Return the number of cl100k_base tokens in *text*.

    ``disallowed_special=()`` makes tiktoken encode special-token markers
    such as ``<|endoftext|>`` as ordinary text instead of raising
    ``ValueError``, so arbitrary user-supplied input can always be counted.
    """
    return len(_encoder.encode(text, disallowed_special=()))


def _hash_api_key(key: str) -> str:
    """Return the hex SHA-256 digest under which an API key is stored."""
    return hashlib.sha256(key.encode()).hexdigest()


def generate_api_key() -> str:
    """Return a new random API key: ``"sk-"`` followed by a URL-safe random token."""
    return "sk-" + secrets.token_urlsafe(32)


def create_api_key(
    db: Session,
    name: str,
    expires_at: Optional[datetime] = None,
    daily_tokens: Optional[int] = None,
    monthly_tokens: Optional[int] = None,
    daily_requests: Optional[int] = None,
    monthly_requests: Optional[int] = None,
) -> tuple[APIKey, str]:
    """Create, persist and return a new API key.

    Only the SHA-256 hash and a short display prefix of the key are stored;
    the plaintext key is returned here and cannot be recovered later.

    Args:
        db: Active session; it is committed by this function.
        name: Human-readable label for the key.
        expires_at: Optional expiry; naive values are treated as UTC when
            checked by ``verify_api_key``.
        daily_tokens, monthly_tokens, daily_requests, monthly_requests:
            Optional quota limits. A falsy value (None or 0) means no limit.

    Returns:
        ``(db_key, raw_key)``: the refreshed ORM row and the plaintext key.
    """
    raw_key = generate_api_key()
    db_key = APIKey(
        name=name,
        key=_hash_api_key(raw_key),
        key_prefix=raw_key[:12],  # "sk-" plus the first 9 random characters
        expires_at=expires_at,
        daily_tokens=daily_tokens,
        monthly_tokens=monthly_tokens,
        daily_requests=daily_requests,
        monthly_requests=monthly_requests,
    )
    db.add(db_key)
    db.commit()
    db.refresh(db_key)
    return db_key, raw_key


def get_setting(db: Session, key: str, default: Optional[str] = None) -> Optional[str]:
    """Return the stored value of setting *key*, or *default* if it is absent."""
    row = db.query(Setting).filter(Setting.key == key).first()
    return row.value if row else default


def set_setting(db: Session, key: str, value: str) -> None:
    """Insert or update setting *key* with *value* and commit the session."""
    row = db.query(Setting).filter(Setting.key == key).first()
    if row:
        row.value = value
    else:
        db.add(Setting(key=key, value=value))
    db.commit()


def verify_api_key(db: Session, api_key: str) -> Optional[APIKey]:
    """Return the active, unexpired ``APIKey`` matching *api_key*, else None.

    The plaintext key is hashed and looked up among active keys. A key whose
    ``expires_at`` lies in the past is rejected; naive ``expires_at`` values
    are interpreted as UTC.
    """
    key_hash = _hash_api_key(api_key)
    db_key = (
        db.query(APIKey)
        # "== True" builds a SQL expression; it is not a Python comparison.
        .filter(APIKey.key == key_hash, APIKey.is_active == True)  # noqa: E712
        .first()
    )
    if db_key and db_key.expires_at:
        expires = db_key.expires_at
        if expires.tzinfo is None:
            expires = expires.replace(tzinfo=timezone.utc)
        if expires < datetime.now(timezone.utc):
            return None
    return db_key


def _reset_expired_windows(usage: Usage) -> None:
    """Zero the daily/monthly counters of *usage* whose calendar window has passed."""
    now = _now_local()
    # Reset markers are persisted in UTC. _to_local() reads naive values back
    # as UTC, and naive DB columns drop the offset on write. Storing local
    # wall-clock time would therefore shift the marker forward by the UTC
    # offset and could make the next day's reset be skipped.
    now_utc = now.astimezone(timezone.utc)

    daily_reset = usage.daily_reset_at
    # A missing marker is treated as an expired window.
    if daily_reset is None or _to_local(daily_reset).date() < now.date():
        usage.tokens_used_today = 0
        usage.requests_today = 0
        usage.daily_reset_at = now_utc

    monthly_reset = usage.monthly_reset_at
    if monthly_reset is None:
        month_expired = True
    else:
        monthly_local = _to_local(monthly_reset)
        month_expired = (monthly_local.year, monthly_local.month) < (now.year, now.month)
    if month_expired:
        usage.tokens_used_month = 0
        usage.requests_month = 0
        usage.monthly_reset_at = now_utc


def _quota_allows(api_key: Optional[APIKey], usage: Usage, tokens: int, requests: int) -> bool:
    """Return True if adding *tokens*/*requests* keeps *usage* within *api_key*'s limits."""
    if api_key is None:
        return True
    # A falsy limit (None or 0) means "no limit" for that dimension.
    if api_key.daily_tokens and usage.tokens_used_today + tokens > api_key.daily_tokens:
        return False
    if api_key.monthly_tokens and usage.tokens_used_month + tokens > api_key.monthly_tokens:
        return False
    if api_key.daily_requests and usage.requests_today + requests > api_key.daily_requests:
        return False
    if api_key.monthly_requests and usage.requests_month + requests > api_key.monthly_requests:
        return False
    return True


def check_and_increment_quota(db: Session, api_key_id: int, tokens: int = 0, requests: int = 1) -> bool:
    """Atomically check and record usage against an API key's quotas.

    The key's ``Usage`` row is locked with ``SELECT ... FOR UPDATE``, or
    created if missing. Counters whose day/month (in the application time
    zone) has rolled over are reset. If adding *tokens* and *requests* stays
    within every configured limit, the counters are incremented.

    Args:
        db: Active session; it is committed by this function in all cases.
        api_key_id: Primary key of the ``APIKey`` being charged.
        tokens: Tokens to add to the daily and monthly token counters.
        requests: Requests to add to the daily and monthly request counters.

    Returns:
        True if the usage was within limits and has been recorded; False if a
        limit would be exceeded, in which case the counters are not
        incremented.
    """
    usage = (
        db.query(Usage)
        .filter(Usage.api_key_id == api_key_id)
        .with_for_update()
        .first()
    )
    if not usage:
        # NOTE(review): a row lock cannot cover a row that does not exist yet,
        # so two concurrent first requests may both try to insert. Presumably
        # a unique constraint on api_key_id guards this; confirm in models.
        usage = Usage(api_key_id=api_key_id)
        db.add(usage)
        # Presumably flushed so model-level defaults (counters, reset
        # timestamps) are populated; confirm against models.
        db.flush()

    _reset_expired_windows(usage)

    api_key = db.query(APIKey).filter(APIKey.id == api_key_id).first()
    allowed = _quota_allows(api_key, usage, tokens, requests)
    if allowed:
        usage.tokens_used_today += tokens
        usage.tokens_used_month += tokens
        usage.requests_today += requests
        usage.requests_month += requests

    # Commit even on rejection so window resets are persisted and the row
    # lock is released by ending the transaction.
    db.commit()
    return allowed