"""Tests for token counting, quota enforcement and API-key expiry in ``crud``.

Every test that touches the database uses the ``db`` fixture, which builds a
fresh schema, so tests do not share state.
"""
import os
from datetime import datetime, timedelta, timezone

import pytest

# Must be set before the project imports below. Presumably one of them reads
# OLLAMA_URL at import time (NOTE(review): confirm which module needs it).
# setdefault keeps any value already provided by the environment.
os.environ.setdefault("OLLAMA_URL", "http://127.0.0.1:9999")

from database import Base, engine, SessionLocal  # noqa: E402
from models import APIKey, Usage  # noqa: E402
from crud import (  # noqa: E402
    check_and_increment_quota,
    count_tokens,
    create_api_key,
    verify_api_key,
)


def make_api_key(db, daily_tokens=None, monthly_tokens=None,
                 daily_requests=None, monthly_requests=None):
    """Create an API key with the given limits and return its ``id``.

    A limit left as ``None`` is passed through unchanged; with every limit
    ``None`` the key is expected to be unrestricted (see
    ``test_no_quota_allows_any_request``).
    """
    db_key, _ = create_api_key(
        db,
        name="test-key",
        daily_tokens=daily_tokens,
        monthly_tokens=monthly_tokens,
        daily_requests=daily_requests,
        monthly_requests=monthly_requests,
    )
    return db_key.id


def _usage_for(db, api_key_id):
    """Return the ``Usage`` row for *api_key_id*, or ``None`` if none exists."""
    return db.query(Usage).filter(Usage.api_key_id == api_key_id).first()


@pytest.fixture
def db():
    """Yield a session bound to a freshly created schema.

    Tables are dropped before and after each test so every test starts from
    an empty database. Cleanup sits in ``finally`` so it also runs if the
    generator is closed instead of resumed.
    """
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
        Base.metadata.drop_all(bind=engine)


# --- count_tokens ---

def test_count_tokens_empty():
    assert count_tokens("") == 0


def test_count_tokens_returns_int():
    assert isinstance(count_tokens("hello world"), int)


def test_count_tokens_scales_with_length():
    assert count_tokens("a b c d e f g h") > count_tokens("a")


def test_count_tokens_more_accurate_than_split():
    # Assuming count_tokens uses a subword tokenizer (e.g. tiktoken), it
    # never yields fewer tokens than whitespace-separated words and may split
    # contractions such as "don't" further, so it is at least the word count.
    text = "don't do that"
    assert count_tokens(text) >= len(text.split())


# --- check_and_increment_quota ---

def test_allowed_within_daily_token_limit(db):
    api_key_id = make_api_key(db, daily_tokens=1000)
    assert check_and_increment_quota(db, api_key_id, tokens=100, requests=1) is True


def test_denied_when_daily_tokens_exceeded(db):
    api_key_id = make_api_key(db, daily_tokens=50)
    assert check_and_increment_quota(db, api_key_id, tokens=100, requests=1) is False


def test_denied_when_monthly_tokens_exceeded(db):
    api_key_id = make_api_key(db, monthly_tokens=50)
    assert check_and_increment_quota(db, api_key_id, tokens=100, requests=1) is False


def test_denied_when_daily_requests_exceeded(db):
    api_key_id = make_api_key(db, daily_requests=1)
    # The first request uses up the limit; asserting it was allowed keeps the
    # denial below from passing vacuously.
    assert check_and_increment_quota(db, api_key_id, tokens=0, requests=1) is True
    assert check_and_increment_quota(db, api_key_id, tokens=0, requests=1) is False


def test_denied_when_monthly_requests_exceeded(db):
    api_key_id = make_api_key(db, monthly_requests=1)
    assert check_and_increment_quota(db, api_key_id, tokens=0, requests=1) is True
    assert check_and_increment_quota(db, api_key_id, tokens=0, requests=1) is False


def test_increments_both_daily_and_monthly_counters(db):
    api_key_id = make_api_key(
        db,
        daily_tokens=1000,
        monthly_tokens=10000,
        daily_requests=100,
        monthly_requests=1000,
    )
    check_and_increment_quota(db, api_key_id, tokens=50, requests=1)
    usage = _usage_for(db, api_key_id)
    assert usage.tokens_used_today == 50
    assert usage.tokens_used_month == 50
    assert usage.requests_today == 1
    assert usage.requests_month == 1


def test_creates_usage_record_on_first_call(db):
    api_key_id = make_api_key(db, daily_tokens=1000)
    assert _usage_for(db, api_key_id) is None
    check_and_increment_quota(db, api_key_id, tokens=10, requests=1)
    assert _usage_for(db, api_key_id) is not None


def test_no_quota_allows_any_request(db):
    api_key_id = make_api_key(db)  # all limits None
    assert check_and_increment_quota(db, api_key_id, tokens=999999, requests=9999) is True


def test_cumulative_usage_across_calls(db):
    api_key_id = make_api_key(db, daily_tokens=200)
    assert check_and_increment_quota(db, api_key_id, tokens=100, requests=1) is True
    assert check_and_increment_quota(db, api_key_id, tokens=99, requests=1) is True
    # 199 + 1 reaches the limit exactly and is allowed (the limit is inclusive);
    # one more token goes over it.
    assert check_and_increment_quota(db, api_key_id, tokens=1, requests=1) is True
    assert check_and_increment_quota(db, api_key_id, tokens=1, requests=1) is False


# --- Reset logic ---

def test_daily_reset_restores_access(db):
    api_key_id = make_api_key(db, daily_tokens=100)
    assert check_and_increment_quota(db, api_key_id, tokens=90, requests=1) is True
    # Backdate the reset timestamp so the next call sees the day as elapsed.
    usage = _usage_for(db, api_key_id)
    usage.daily_reset_at = datetime.now(timezone.utc) - timedelta(days=1)
    db.commit()

    assert check_and_increment_quota(db, api_key_id, tokens=90, requests=1) is True
    usage = _usage_for(db, api_key_id)
    assert usage.tokens_used_today == 90


def test_daily_reset_does_not_affect_monthly_counter(db):
    api_key_id = make_api_key(db, daily_tokens=1000, monthly_tokens=10000)
    assert check_and_increment_quota(db, api_key_id, tokens=50, requests=1) is True
    usage = _usage_for(db, api_key_id)
    usage.daily_reset_at = datetime.now(timezone.utc) - timedelta(days=1)
    db.commit()

    check_and_increment_quota(db, api_key_id, tokens=50, requests=1)
    usage = _usage_for(db, api_key_id)
    assert usage.tokens_used_today == 50
    assert usage.tokens_used_month == 100


def test_monthly_reset_restores_access(db):
    api_key_id = make_api_key(db, monthly_tokens=100)
    assert check_and_increment_quota(db, api_key_id, tokens=90, requests=1) is True
    # 32 days back guarantees the monthly window has elapsed for any month length.
    usage = _usage_for(db, api_key_id)
    usage.monthly_reset_at = datetime.now(timezone.utc) - timedelta(days=32)
    db.commit()

    assert check_and_increment_quota(db, api_key_id, tokens=90, requests=1) is True
    usage = _usage_for(db, api_key_id)
    assert usage.tokens_used_month == 90


def test_failed_quota_check_still_commits_reset(db):
    api_key_id = make_api_key(db, daily_tokens=100, daily_requests=5)
    assert check_and_increment_quota(db, api_key_id, tokens=80, requests=1) is True
    usage = _usage_for(db, api_key_id)
    usage.daily_reset_at = datetime.now(timezone.utc) - timedelta(days=1)
    # Pin the pre-reset counter explicitly so the final assertion can only
    # pass if the reset actually ran.
    usage.tokens_used_today = 80
    db.commit()

    # 200 tokens exceed the limit even after the reset, so this call is denied.
    result = check_and_increment_quota(db, api_key_id, tokens=200, requests=1)
    assert result is False

    # Discard cached ORM state so the re-query reflects what was committed.
    db.expire_all()
    usage = _usage_for(db, api_key_id)
    assert usage.tokens_used_today == 0


# --- verify_api_key expiry ---

def _create_raw_key(db, expires_at=None):
    """Create an API key with an optional expiry and return the raw key string."""
    _, raw_key = create_api_key(db, name="expiry-test", expires_at=expires_at)
    return raw_key


def test_key_without_expiry_is_valid(db):
    raw_key = _create_raw_key(db)
    assert verify_api_key(db, raw_key) is not None


def test_key_with_future_expiry_is_valid(db):
    future = datetime.now(timezone.utc) + timedelta(days=30)
    raw_key = _create_raw_key(db, expires_at=future)
    assert verify_api_key(db, raw_key) is not None


def test_key_with_past_expiry_is_rejected(db):
    past = datetime.now(timezone.utc) - timedelta(seconds=1)
    raw_key = _create_raw_key(db, expires_at=past)
    assert verify_api_key(db, raw_key) is None