96 lines
2.8 KiB
Python
96 lines
2.8 KiB
Python
import secrets
|
|
import hashlib
|
|
from datetime import datetime, timedelta
|
|
from sqlalchemy.orm import Session
|
|
from database import get_db
|
|
from models import APIKey, User, Quota, Usage
|
|
|
|
def get_user_by_username(db: Session, username: str):
    """Look up a single User whose username matches *username* exactly.

    Returns the first matching User, or None when no row matches.
    """
    matches = db.query(User).filter_by(username=username)
    return matches.first()
|
|
|
|
def get_user_by_email(db: Session, email: str):
    """Look up a single User whose email matches *email* exactly.

    Returns the first matching User, or None when no row matches.
    """
    matches = db.query(User).filter_by(email=email)
    return matches.first()
|
|
|
|
def generate_api_key():
    """Return a fresh random API key of the form ``sk-<token>``.

    The token part is ``secrets.token_urlsafe(32)``: 32 random bytes rendered
    as URL-safe base64 text, without padding.
    """
    prefix = "sk-"
    token = secrets.token_urlsafe(32)
    return f"{prefix}{token}"
|
|
|
|
def hash_password(password: str):
    """Return the hexadecimal SHA-256 digest of *password* encoded as UTF-8.

    NOTE(review): this is a single, unsalted, fast hash, which is weak for
    password storage. A salted KDF (``hashlib.scrypt`` / ``pbkdf2_hmac``)
    would need a matching verify step wherever stored hashes are compared.
    Presumably callers elsewhere recompute this function and compare
    strings, so confirm those call sites before changing the scheme here.
    """
    hasher = hashlib.sha256()
    hasher.update(password.encode("utf-8"))
    return hasher.hexdigest()
|
|
|
|
def create_user(
    db: Session,
    username: str,
    email: str,
    password: str,
    *,
    daily_tokens: int = 1000000,
    monthly_tokens: int = 10000000,
    daily_requests: int = 1000,
    monthly_requests: int = 10000,
):
    """Create a User and its default Quota in a single transaction.

    Args:
        db: Active SQLAlchemy session; committed on success, rolled back on
            any error.
        username: Username for the new user.
        email: Email address for the new user.
        password: Plain-text password; only ``hash_password(password)`` is
            stored.
        daily_tokens, monthly_tokens, daily_requests, monthly_requests:
            Limits for the user's initial Quota row. The defaults match the
            previously hard-coded values.

    Returns:
        The newly created User, refreshed from the database.

    Raises:
        Any exception from the flush/commit, re-raised after the session has
        been rolled back, so neither the user nor the quota is persisted.
    """
    db_user = User(
        username=username,
        email=email,
        hashed_password=hash_password(password),
    )
    try:
        db.add(db_user)
        # flush (not commit) runs the INSERT and populates db_user.id while
        # keeping the user and its quota in the same transaction; previously a
        # failed quota insert left a committed user without a quota.
        db.flush()
        db.add(
            Quota(
                user_id=db_user.id,
                daily_tokens=daily_tokens,
                monthly_tokens=monthly_tokens,
                daily_requests=daily_requests,
                monthly_requests=monthly_requests,
            )
        )
        db.commit()
    except Exception:
        # Discard the half-finished inserts so the session remains usable.
        db.rollback()
        raise
    db.refresh(db_user)
    return db_user
|
|
|
|
def create_api_key(db: Session, user_id: int, name: str):
    """Create, commit and return a new APIKey owned by *user_id*.

    The key string comes from ``generate_api_key()`` and is stored as-is in
    the ``key`` column. ``verify_api_key`` later matches on that column by
    equality. The returned object is refreshed after the commit.
    """
    api_key = APIKey(name=name, key=generate_api_key(), user_id=user_id)
    db.add(api_key)
    db.commit()
    db.refresh(api_key)
    return api_key
|
|
|
|
def verify_api_key(db: Session, api_key: str):
    """Return the active APIKey whose ``key`` equals *api_key*, or None."""
    candidates = db.query(APIKey).filter(APIKey.key == api_key)
    # ``== True`` builds an SQL comparison; ``is True`` would not.
    active = candidates.filter(APIKey.is_active == True)  # noqa: E712
    return active.first()
|
|
|
|
def get_quota(db: Session, user_id: int):
    """Return the Quota row for *user_id*, or None if the user has none."""
    rows = db.query(Quota).filter_by(user_id=user_id)
    return rows.first()
|
|
|
|
def check_quota(db: Session, user_id: int, tokens: int = 0, requests: int = 1):
    """Return True if consuming *tokens* and *requests* stays within quota.

    Args:
        db: Active SQLAlchemy session (read-only here).
        user_id: User whose Quota and Usage rows are checked.
        tokens: Tokens the pending operation would consume.
        requests: Requests the pending operation would consume.

    Returns:
        False if the projected usage would exceed any configured limit, or if
        the user has no Quota row. Otherwise True. A falsy limit (None or 0)
        means "no limit" for that dimension. A user with no Usage row yet
        (``increment_usage`` has never run for them) counts as zero usage.

    NOTE(review): daily and monthly limits are compared against the same
    Usage counters. Unless those counters are reset per period elsewhere,
    the daily limit effectively caps cumulative usage. Confirm the intended
    reset mechanism.
    """
    quota = get_quota(db, user_id)
    if quota is None:
        # Fail closed: with no Quota row there is nothing to enforce against,
        # and the old code crashed here with AttributeError.
        return False

    usage = get_usage(db, user_id)
    # get_usage returns None until the first increment_usage() call, so a
    # brand-new user previously crashed here on their first request.
    used_tokens = (usage.tokens_used or 0) if usage is not None else 0
    used_requests = (usage.requests_count or 0) if usage is not None else 0

    projected_tokens = used_tokens + tokens
    projected_requests = used_requests + requests

    checks = (
        (quota.daily_tokens, projected_tokens),
        (quota.monthly_tokens, projected_tokens),
        (quota.daily_requests, projected_requests),
        (quota.monthly_requests, projected_requests),
    )
    return not any(limit and projected > limit for limit, projected in checks)
|
|
|
|
def increment_usage(db: Session, user_id: int, tokens: int = 0, requests: int = 1):
    """Add *tokens* and *requests* to the user's Usage counters and commit.

    Creates the Usage row first if needed, via ``get_or_create_usage``.

    Returns:
        The refreshed Usage row. Callers that ignore the result are
        unaffected.

    NOTE(review): the increment now runs in SQL. If a counter column can be
    NULL, ``NULL + n`` stays NULL instead of raising as Python ``+=`` did.
    Confirm the columns have non-NULL defaults.
    """
    usage = get_or_create_usage(db, user_id)
    # Assigning a column expression makes the flush emit
    # "SET tokens_used = tokens_used + :n", so the addition happens inside
    # the database. A Python-side "+=" reads a possibly stale value, which
    # lets concurrent requests overwrite each other's increments.
    usage.tokens_used = Usage.tokens_used + tokens
    usage.requests_count = Usage.requests_count + requests
    db.commit()
    # Until refreshed, the attributes hold SQL expressions; reload the values
    # the database actually computed.
    db.refresh(usage)
    return usage
|
|
|
|
def get_usage(db: Session, user_id: int):
    """Return the Usage row for *user_id*, or None if none exists yet."""
    rows = db.query(Usage).filter_by(user_id=user_id)
    return rows.first()
|
|
|
|
def get_quota_by_user_id(db: Session, user_id: int):
    """Return the Quota row for *user_id*, or None; same as ``get_quota``.

    Kept for callers that use this name. It delegates to ``get_quota`` so the
    two lookups cannot drift apart (the query was previously duplicated).
    """
    return get_quota(db, user_id)
|
|
|
|
def get_or_create_usage(db: Session, user_id: int):
    """Return the user's Usage row, creating and committing one if absent.

    A newly created row is refreshed after the commit, so any server-side
    defaults are loaded before it is returned.
    """
    existing = get_usage(db, user_id)
    if existing:
        return existing

    created = Usage(user_id=user_id)
    db.add(created)
    db.commit()
    db.refresh(created)
    return created
|