"""Tests for token counting and per-user quota enforcement."""

import pytest
import os
from datetime import datetime, timedelta, timezone

# Set before the project imports below; presumably one of those modules reads
# OLLAMA_URL at import time -- TODO confirm. Keep this ordering.
os.environ.setdefault("OLLAMA_URL", "http://127.0.0.1:9999")

from database import Base, engine, SessionLocal  # noqa: E402
from models import User, Quota, Usage  # noqa: E402
from crud import check_and_increment_quota, count_tokens, hash_password  # noqa: E402


def _naive_utcnow():
    """Return the current UTC time as a naive ``datetime``.

    Produces the same value as the deprecated ``datetime.utcnow()`` (naive,
    UTC wall-clock) without triggering the Python 3.12+ DeprecationWarning.
    The result is kept naive so it matches what the tests previously wrote
    into ``daily_reset_at`` / ``monthly_reset_at``.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _get_usage(db, user_id):
    """Return the first ``Usage`` row for ``user_id``, or ``None`` if absent."""
    return db.query(Usage).filter(Usage.user_id == user_id).first()


def make_user_and_quota(
    db,
    daily_tokens=None,
    monthly_tokens=None,
    daily_requests=None,
    monthly_requests=None,
):
    """Insert a user plus a ``Quota`` row and return the new user's id.

    Args:
        db: Open database session used for the inserts and commits.
        daily_tokens: Value stored in ``Quota.daily_tokens`` (``None`` by default).
        monthly_tokens: Value stored in ``Quota.monthly_tokens``.
        daily_requests: Value stored in ``Quota.daily_requests``.
        monthly_requests: Value stored in ``Quota.monthly_requests``.

    Returns:
        The primary key of the created user.
    """
    user = User(
        username="quotauser",
        email="quota@example.com",
        hashed_password=hash_password("pass"),
        is_active=True,
    )
    db.add(user)
    db.commit()
    # Refresh so the generated primary key is loaded onto the instance.
    db.refresh(user)

    quota = Quota(
        user_id=user.id,
        daily_tokens=daily_tokens,
        monthly_tokens=monthly_tokens,
        daily_requests=daily_requests,
        monthly_requests=monthly_requests,
    )
    db.add(quota)
    db.commit()
    return user.id


@pytest.fixture
def db():
    """Yield a session on a freshly re-created schema; drop all tables afterwards."""
    # Drop first so leftovers from an earlier aborted run cannot leak in.
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)
    session = SessionLocal()
    yield session
    session.close()
    Base.metadata.drop_all(bind=engine)


# --- count_tokens ---

def test_count_tokens_empty():
    """An empty string counts as zero tokens."""
    assert count_tokens("") == 0


def test_count_tokens_returns_int():
    """The token count is an ``int``."""
    assert isinstance(count_tokens("hello world"), int)


def test_count_tokens_scales_with_length():
    """A longer text yields a larger count than a single character."""
    assert count_tokens("a b c d e f g h") > count_tokens("a")


def test_count_tokens_more_accurate_than_split():
    """The token count is never below the whitespace-split word count."""
    # tiktoken counts "don't" as 2 tokens, split counts it as 1 word
    text = "don't do that"
    assert count_tokens(text) >= len(text.split())


# --- check_and_increment_quota ---

def test_allowed_within_daily_token_limit(db):
    """A request below the daily token limit is allowed."""
    user_id = make_user_and_quota(db, daily_tokens=1000)
    assert check_and_increment_quota(db, user_id, tokens=100, requests=1) is True


def test_denied_when_daily_tokens_exceeded(db):
    """A request larger than the daily token limit is denied."""
    user_id = make_user_and_quota(db, daily_tokens=50)
    assert check_and_increment_quota(db, user_id, tokens=100, requests=1) is False


def test_denied_when_monthly_tokens_exceeded(db):
    """A request larger than the monthly token limit is denied."""
    user_id = make_user_and_quota(db, monthly_tokens=50)
    assert check_and_increment_quota(db, user_id, tokens=100, requests=1) is False


def test_denied_when_daily_requests_exceeded(db):
    """The second request is denied when the daily request limit is 1."""
    user_id = make_user_and_quota(db, daily_requests=1)
    check_and_increment_quota(db, user_id, tokens=0, requests=1)
    assert check_and_increment_quota(db, user_id, tokens=0, requests=1) is False


def test_denied_when_monthly_requests_exceeded(db):
    """The second request is denied when the monthly request limit is 1."""
    user_id = make_user_and_quota(db, monthly_requests=1)
    check_and_increment_quota(db, user_id, tokens=0, requests=1)
    assert check_and_increment_quota(db, user_id, tokens=0, requests=1) is False


def test_increments_both_daily_and_monthly_counters(db):
    """One call bumps the daily and monthly token and request counters."""
    user_id = make_user_and_quota(
        db,
        daily_tokens=1000,
        monthly_tokens=10000,
        daily_requests=100,
        monthly_requests=1000,
    )
    check_and_increment_quota(db, user_id, tokens=50, requests=1)

    usage = _get_usage(db, user_id)
    assert usage.tokens_used_today == 50
    assert usage.tokens_used_month == 50
    assert usage.requests_today == 1
    assert usage.requests_month == 1


def test_creates_usage_record_on_first_call(db):
    """No ``Usage`` row exists until the first quota check creates it."""
    user_id = make_user_and_quota(db, daily_tokens=1000)
    assert _get_usage(db, user_id) is None
    check_and_increment_quota(db, user_id, tokens=10, requests=1)
    assert _get_usage(db, user_id) is not None


def test_no_quota_allows_any_request(db):
    """With every limit left as ``None``, even a huge request is allowed."""
    user_id = make_user_and_quota(db)  # all limits None
    assert check_and_increment_quota(db, user_id, tokens=999999, requests=9999) is True


def test_cumulative_usage_across_calls(db):
    """Usage accumulates across calls and the limit itself is inclusive."""
    user_id = make_user_and_quota(db, daily_tokens=200)
    check_and_increment_quota(db, user_id, tokens=100, requests=1)
    check_and_increment_quota(db, user_id, tokens=99, requests=1)
    # 199 used, 1 remaining – exactly 1 more token should pass
    assert check_and_increment_quota(db, user_id, tokens=1, requests=1) is True
    # Now 200 used – next request must fail
    assert check_and_increment_quota(db, user_id, tokens=1, requests=1) is False


# --- Reset logic ---

def test_daily_reset_restores_access(db):
    """Backdating ``daily_reset_at`` by a day lets a new request through."""
    user_id = make_user_and_quota(db, daily_tokens=100)
    check_and_increment_quota(db, user_id, tokens=90, requests=1)

    # Backdate daily_reset_at to yesterday
    usage = _get_usage(db, user_id)
    usage.daily_reset_at = _naive_utcnow() - timedelta(days=1)
    db.commit()

    # Should pass again after reset
    assert check_and_increment_quota(db, user_id, tokens=90, requests=1) is True
    usage = _get_usage(db, user_id)
    assert usage.tokens_used_today == 90


def test_daily_reset_does_not_affect_monthly_counter(db):
    """A daily reset clears the daily counter but keeps the monthly total."""
    user_id = make_user_and_quota(db, daily_tokens=1000, monthly_tokens=10000)
    check_and_increment_quota(db, user_id, tokens=50, requests=1)

    usage = _get_usage(db, user_id)
    usage.daily_reset_at = _naive_utcnow() - timedelta(days=1)
    db.commit()

    check_and_increment_quota(db, user_id, tokens=50, requests=1)
    usage = _get_usage(db, user_id)
    assert usage.tokens_used_today == 50
    assert usage.tokens_used_month == 100  # cumulative across days


def test_monthly_reset_restores_access(db):
    """Backdating ``monthly_reset_at`` by 32 days lets a new request through."""
    user_id = make_user_and_quota(db, monthly_tokens=100)
    check_and_increment_quota(db, user_id, tokens=90, requests=1)

    usage = _get_usage(db, user_id)
    usage.monthly_reset_at = _naive_utcnow() - timedelta(days=32)
    db.commit()

    assert check_and_increment_quota(db, user_id, tokens=90, requests=1) is True
    usage = _get_usage(db, user_id)
    assert usage.tokens_used_month == 90


def test_failed_quota_check_still_commits_reset(db):
    """A denied request must still persist the counter reset it triggered."""
    user_id = make_user_and_quota(db, daily_tokens=100, daily_requests=5)
    check_and_increment_quota(db, user_id, tokens=80, requests=1)

    # Backdate so a reset fires, but the new request still exceeds the limit
    usage = _get_usage(db, user_id)
    usage.daily_reset_at = _naive_utcnow() - timedelta(days=1)
    usage.tokens_used_today = 80
    db.commit()

    # After reset tokens_used_today = 0; 200 tokens exceeds 100 limit
    result = check_and_increment_quota(db, user_id, tokens=200, requests=1)
    assert result is False

    # Reset must still be persisted so the next request sees fresh counters
    db.expire_all()
    usage = _get_usage(db, user_id)
    assert usage.tokens_used_today == 0