Removes DEFAULT_MODEL in favour of a force_model setting configurable via the admin UI. When set, every proxy request's model field is overridden, preventing uncoordinated model switches during lab sessions. Updates schemas, admin API, all three proxy endpoints, frontend, init_db, and docs (README, DOCKERHUB, KURZANLEITUNG).
210 lines
9.8 KiB
Python
210 lines
9.8 KiB
Python
import logging
|
|
import os
|
|
from logging.handlers import RotatingFileHandler
|
|
from pathlib import Path
|
|
|
|
from fastapi import FastAPI, HTTPException, Depends, Request
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from sqlalchemy.orm import Session
|
|
from database import get_db
|
|
import crud
|
|
import httpx
|
|
|
|
# Only the directory part of LOG_FILE (default "logs/usage.log") is used: both
# log files below are created inside it under fixed file names.
_log_dir = Path(os.getenv("LOG_FILE", "logs/usage.log")).parent
_log_dir.mkdir(parents=True, exist_ok=True)

# Shared line layout: "<timestamp> | <message>".
_fmt = logging.Formatter("%(asctime)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")


def _rotating_handler(filename, max_bytes, backups):
    """Build a UTF-8 rotating file handler for ``filename`` in the log directory."""
    handler = RotatingFileHandler(
        str(_log_dir / filename),
        maxBytes=max_bytes,
        backupCount=backups,
        encoding="utf-8",
    )
    handler.setFormatter(_fmt)
    return handler


def _dedicated_logger(name, level, handler):
    """Return logger ``name`` at ``level`` with ``handler`` attached and propagation off."""
    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.addHandler(handler)
    # Keep these records out of the root logger's handlers.
    logger.propagate = False
    return logger


# Usage log: rotates at 8 KiB, keeps 3 backups.
_usage_handler = _rotating_handler("usage.log", 8192, 3)
usage_log = _dedicated_logger("proxy.usage", logging.INFO, _usage_handler)

# Error log: rotates at 64 KiB, keeps 5 backups.
_error_handler = _rotating_handler("error.log", 65536, 5)
error_log = _dedicated_logger("proxy.error", logging.ERROR, _error_handler)
|
|
|
|
def _content_to_str(content) -> str:
|
|
"""Normalize OpenAI content: string or array of content parts → plain string."""
|
|
if isinstance(content, list):
|
|
return " ".join(
|
|
part.get("text", "") if isinstance(part, dict) else str(part)
|
|
for part in content
|
|
)
|
|
return content or ""
|
|
|
|
|
|
def _last_user_msg(messages: list, max_len: int = 120) -> str:
    """Return a one-line preview of the most recent message with role "user".

    Newlines in the message text are replaced by spaces and surrounding
    whitespace is stripped. Text longer than ``max_len`` characters is cut to
    ``max_len`` and suffixed with an ellipsis. Returns "" when no user message
    is present.
    """
    latest = next((m for m in reversed(messages) if m.get("role") == "user"), None)
    if latest is None:
        return ""

    flattened = _content_to_str(latest.get("content")).replace("\n", " ").strip()
    if len(flattened) > max_len:
        return flattened[:max_len] + "…"
    return flattened
|
|
|
|
async def require_api_key(request: Request, db: Session = Depends(get_db)):
    """Authenticate the request by API key and record the key on ``request.state``.

    Accepted ``Authorization`` header forms:
      * ``Bearer <key>`` - the scheme is matched case-insensitively, as HTTP
        authentication schemes are case-insensitive; whitespace around the
        key is ignored.
      * the bare key itself, when the header value starts with ``sk-``.

    On success ``request.state.api_key_id`` and ``request.state.api_key_name``
    are set from the record returned by ``crud.verify_api_key``.

    Raises:
        HTTPException: 401 when the header is missing, has an unrecognized
            form, carries an empty key, or the key is not accepted by
            ``crud.verify_api_key``.
    """
    auth_header = request.headers.get("Authorization", "")
    scheme, _, credentials = auth_header.partition(" ")

    if scheme.lower() == "bearer":
        api_key = credentials.strip()
    elif auth_header.startswith("sk-"):
        api_key = auth_header
    else:
        api_key = ""

    if not api_key:
        # Covers a missing header, an unknown scheme and "Bearer" with no
        # token; there is nothing to look up in the database.
        raise HTTPException(status_code=401, detail="Invalid or missing API key")

    db_key = crud.verify_api_key(db, api_key)
    if not db_key:
        raise HTTPException(status_code=401, detail="Invalid API key")

    request.state.api_key_id = db_key.id
    request.state.api_key_name = db_key.name
|
|
|
|
# The API-key check is declared as an app-wide dependency, so it runs for
# every route registered on this application.
app = FastAPI(
    title="Ollama Proxy",
    dependencies=[Depends(require_api_key)],
)
|
|
|
|
@app.on_event("startup")
def apply_env_settings():
    """Copy environment-provided configuration into the settings store at startup.

    Only ``OLLAMA_URL`` is handled: when the variable is set to a non-empty
    value it is saved under the ``"ollama_url"`` setting. The session is
    committed afterwards and closed in every case.
    """
    configured_url = os.getenv("OLLAMA_URL")
    session = next(get_db())
    try:
        if configured_url:
            crud.set_setting(session, "ollama_url", configured_url)
        session.commit()
    finally:
        session.close()
|
|
|
|
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Catch-all handler registered for ``Exception``.

    Writes the request method, path and exception (with traceback) to the
    error log, then answers with a generic 500 JSON error body so no internal
    detail is sent to the client.
    """
    error_log.error(
        "Unhandled exception | %s %s | %s: %s",
        request.method,
        request.url.path,
        type(exc).__name__,
        exc,
        exc_info=exc,
    )
    payload = {"error": {"message": "Internal server error", "type": "server_error"}}
    return JSONResponse(status_code=500, content=payload)
|
|
|
|
async def proxy_request(url: str, method: str = "GET", json_data: dict = None):
    """Send a single HTTP request to ``url`` and return the httpx response.

    A new ``httpx.AsyncClient`` with a 300-second timeout is opened for the
    call and closed when the function returns. ``json_data`` is passed as the
    request's JSON body.
    """
    client = httpx.AsyncClient(timeout=300.0)
    async with client:
        return await client.request(method=method, url=url, json=json_data)
|
|
|
|
@app.post("/api/generate")
async def generate(request: Request, db: Session = Depends(get_db)):
    """Proxy an Ollama ``/api/generate`` call with quota accounting and logging.

    Flow:
      1. Read the JSON body; if the ``force_model`` setting is non-empty it
         replaces the body's ``model``.
      2. Estimate prompt tokens with ``crud.count_tokens`` and charge them,
         plus one request, against the caller's quota.
      3. Forward the body upstream and relay the JSON reply and status code.

    The upstream call is always made with ``"stream": false``: the reply is
    buffered and decoded as one JSON document, which a newline-delimited
    stream of objects would break.

    Raises:
        HTTPException: 429 when ``crud.check_and_increment_quota`` refuses.
        Exception: any failure while calling upstream or decoding its reply
            is written to the error log and re-raised.
    """
    ollama_url = crud.get_setting(db, "ollama_url", os.getenv("OLLAMA_URL", "http://localhost:11434"))
    body = await request.json()

    force_model = crud.get_setting(db, "force_model") or None
    if force_model:
        body = {**body, "model": force_model}

    # Ollama streams newline-delimited JSON unless told otherwise, and
    # response.json() below can only decode a single object.
    body = {**body, "stream": False}

    # "prompt" may be absent, JSON null, or a non-string value; only the local
    # copy used for counting/logging is normalized, the forwarded body is not.
    prompt = body.get("prompt") or ""
    if not isinstance(prompt, str):
        prompt = str(prompt)
    prompt_tokens = crud.count_tokens(prompt)

    if not crud.check_and_increment_quota(db, request.state.api_key_id, tokens=prompt_tokens, requests=1):
        raise HTTPException(status_code=429, detail="Quota exceeded")

    prompt_preview = prompt.replace("\n", " ").strip()[:120]
    usage_log.info('%s | /api/generate | %s | ~%d tokens | "%s"',
                   request.state.api_key_name, body.get("model", "?"), prompt_tokens, prompt_preview)
    try:
        response = await proxy_request(f"{ollama_url}/api/generate", method="POST", json_data=body)
        resp_json = response.json()
        usage_log.info('%s | /api/generate | %s | actual ↑%d ↓%d tokens',
                       request.state.api_key_name, body.get("model", "?"),
                       resp_json.get("prompt_eval_count", 0), resp_json.get("eval_count", 0))
        return JSONResponse(content=resp_json, status_code=response.status_code)
    except Exception as exc:
        error_log.error("Proxy error | %s | /api/generate | %s | %s: %s",
                        request.state.api_key_name, body.get("model", "?"), type(exc).__name__, exc, exc_info=exc)
        raise
|
|
|
|
@app.post("/api/chat")
async def chat(request: Request, db: Session = Depends(get_db)):
    """Proxy an Ollama ``/api/chat`` call with quota accounting and logging.

    Flow:
      1. Read the JSON body; if the ``force_model`` setting is non-empty it
         replaces the body's ``model``.
      2. Estimate prompt tokens over all messages with ``crud.count_tokens``
         and charge them, plus one request, against the caller's quota.
      3. Forward the body upstream and relay the JSON reply and status code.

    The upstream call is always made with ``"stream": false``: the reply is
    buffered and decoded as one JSON document, which a newline-delimited
    stream of objects would break.

    Raises:
        HTTPException: 429 when ``crud.check_and_increment_quota`` refuses.
        Exception: any failure while calling upstream or decoding its reply
            is written to the error log and re-raised.
    """
    ollama_url = crud.get_setting(db, "ollama_url", os.getenv("OLLAMA_URL", "http://localhost:11434"))
    body = await request.json()

    force_model = crud.get_setting(db, "force_model") or None
    if force_model:
        body = {**body, "model": force_model}

    # Ollama streams newline-delimited JSON unless told otherwise, and
    # response.json() below can only decode a single object.
    body = {**body, "stream": False}

    # "messages" may be absent or JSON null; both count as no messages.
    messages = body.get("messages") or []
    prompt_tokens = sum(crud.count_tokens(_content_to_str(msg.get("content"))) for msg in messages)

    if not crud.check_and_increment_quota(db, request.state.api_key_id, tokens=prompt_tokens, requests=1):
        raise HTTPException(status_code=429, detail="Quota exceeded")

    usage_log.info('%s | /api/chat | %s | ~%d tokens | "%s"',
                   request.state.api_key_name, body.get("model", "?"), prompt_tokens, _last_user_msg(messages))
    try:
        response = await proxy_request(f"{ollama_url}/api/chat", method="POST", json_data=body)
        resp_json = response.json()
        usage_log.info('%s | /api/chat | %s | actual ↑%d ↓%d tokens',
                       request.state.api_key_name, body.get("model", "?"),
                       resp_json.get("prompt_eval_count", 0), resp_json.get("eval_count", 0))
        return JSONResponse(content=resp_json, status_code=response.status_code)
    except Exception as exc:
        error_log.error("Proxy error | %s | /api/chat | %s | %s: %s",
                        request.state.api_key_name, body.get("model", "?"), type(exc).__name__, exc, exc_info=exc)
        raise
|
|
|
|
@app.get("/api/tags")
async def list_models(db: Session = Depends(get_db)):
    """Relay the JSON body and status code of the upstream ``/api/tags`` call.

    The upstream base URL is the ``ollama_url`` setting, with the
    ``OLLAMA_URL`` environment variable (or ``http://localhost:11434``)
    passed as the default.
    """
    fallback_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
    base_url = crud.get_setting(db, "ollama_url", fallback_url)
    upstream = await proxy_request(f"{base_url}/api/tags", method="GET")
    return JSONResponse(content=upstream.json(), status_code=upstream.status_code)
|
|
|
|
@app.get("/api/version")
@app.get("/api/versions")
async def versions(db: Session = Depends(get_db)):
    """Relay the Ollama server version.

    Ollama serves this under ``/api/version`` (singular), so that is the
    upstream path requested. The proxy keeps its existing ``/api/versions``
    route and also answers on the upstream-compatible ``/api/version`` path;
    both return the upstream JSON body and status code.
    """
    ollama_url = crud.get_setting(db, "ollama_url", os.getenv("OLLAMA_URL", "http://localhost:11434"))
    response = await proxy_request(f"{ollama_url}/api/version", method="GET")
    return JSONResponse(content=response.json(), status_code=response.status_code)
|
|
|
|
@app.get("/v1/models")
async def list_openai_models(db: Session = Depends(get_db)):
    """Relay the JSON body and status code of the upstream ``/v1/models`` call."""
    base_url = crud.get_setting(
        db,
        "ollama_url",
        os.getenv("OLLAMA_URL", "http://localhost:11434"),
    )
    upstream = await proxy_request(f"{base_url}/v1/models", method="GET")
    status = upstream.status_code
    return JSONResponse(content=upstream.json(), status_code=status)
|
|
|
|
@app.post("/v1/chat/completions")
async def openai_chat_completions(request: Request, db: Session = Depends(get_db)):
    """Proxy an OpenAI-style chat completion to the upstream server.

    The ``force_model`` setting, when non-empty, replaces the body's
    ``model``. Prompt tokens are estimated over all messages and charged,
    together with one request, against the caller's quota (429 on refusal).

    When the body's ``stream`` value is truthy, the upstream bytes are relayed
    as they arrive as ``text/event-stream``. Otherwise the upstream JSON reply
    and status code are returned and the reported token usage is logged.
    Upstream failures are written to the error log and re-raised.
    """
    base_url = crud.get_setting(db, "ollama_url", os.getenv("OLLAMA_URL", "http://localhost:11434"))

    body = await request.json()
    forced = crud.get_setting(db, "force_model") or None
    if forced:
        body = {**body, "model": forced}

    messages = body.get("messages", [])
    prompt_tokens = 0
    for message in messages:
        prompt_tokens += crud.count_tokens(_content_to_str(message.get("content")))

    quota_ok = crud.check_and_increment_quota(
        db, request.state.api_key_id, tokens=prompt_tokens, requests=1
    )
    if not quota_ok:
        raise HTTPException(status_code=429, detail="Quota exceeded")

    model_name = body.get("model", "?")
    key_name = request.state.api_key_name

    usage_log.info(
        '%s | /v1/chat/completions | %s | ~%d tokens | "%s"',
        key_name, model_name, prompt_tokens, _last_user_msg(messages),
    )

    target = f"{base_url}/v1/chat/completions"

    if not body.get("stream"):
        # Buffered path: one upstream call, one JSON document back.
        try:
            upstream = await proxy_request(target, method="POST", json_data=body)
            reply = upstream.json()
            reported = reply.get("usage", {})
            usage_log.info(
                '%s | /v1/chat/completions | %s | actual ↑%d ↓%d tokens',
                key_name, model_name,
                reported.get("prompt_tokens", 0), reported.get("completion_tokens", 0),
            )
            return JSONResponse(content=reply, status_code=upstream.status_code)
        except Exception as exc:
            error_log.error(
                "Proxy error | %s | /v1/chat/completions | %s | %s: %s",
                key_name, model_name, type(exc).__name__, exc, exc_info=exc,
            )
            raise

    async def relay_chunks():
        """Yield upstream response bytes as they arrive; log and re-raise on failure."""
        try:
            async with httpx.AsyncClient(timeout=300.0) as client, \
                    client.stream("POST", target, json=body) as upstream:
                async for chunk in upstream.aiter_bytes():
                    yield chunk
        except Exception as exc:
            error_log.error(
                "Stream error | %s | /v1/chat/completions | %s | %s: %s",
                key_name, model_name, type(exc).__name__, exc, exc_info=exc,
            )
            raise

    return StreamingResponse(
        relay_chunks(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|