Embed APP_VERSION build arg in Docker image (default: dev). build_push.sh passes the git tag as build arg. Proxy exposes GET /version, admin UI shows it as read-only field in settings.
226 lines
10 KiB
Python
226 lines
10 KiB
Python
import logging
|
|
import os
|
|
from logging.handlers import RotatingFileHandler
|
|
from pathlib import Path
|
|
|
|
from fastapi import FastAPI, HTTPException, Depends, Request
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from sqlalchemy.orm import Session
|
|
from database import get_db
|
|
import crud
|
|
import httpx
|
|
|
|
# Directory for all log files: only the parent directory of the LOG_FILE path
# (default "logs/usage.log") is used; the file names below are fixed.
_log_dir = Path(os.getenv("LOG_FILE", "logs/usage.log")).parent
_log_dir.mkdir(parents=True, exist_ok=True)

# Shared line format: "<timestamp> | <message>".
_fmt = logging.Formatter("%(asctime)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")


def _rotating_handler(filename: str, max_bytes: int, backups: int) -> RotatingFileHandler:
    """Build a UTF-8 rotating file handler inside ``_log_dir`` using ``_fmt``."""
    handler = RotatingFileHandler(
        str(_log_dir / filename),
        maxBytes=max_bytes,
        backupCount=backups,
        encoding="utf-8",
    )
    handler.setFormatter(_fmt)
    return handler


# Rotating usage log (8 KB per file, 3 backups)
_usage_handler = _rotating_handler("usage.log", 8192, 3)
usage_log = logging.getLogger("proxy.usage")
usage_log.setLevel(logging.INFO)
usage_log.addHandler(_usage_handler)
# Records stay in this file only; they are not handed to ancestor loggers.
usage_log.propagate = False

# Rotating error log (64 KB per file, 5 backups)
_error_handler = _rotating_handler("error.log", 65536, 5)
error_log = logging.getLogger("proxy.error")
error_log.setLevel(logging.ERROR)
error_log.addHandler(_error_handler)
error_log.propagate = False
|
|
|
|
def _content_to_str(content) -> str:
    """Normalize OpenAI-style message content to a plain string.

    Args:
        content: A string, a list of content parts, or a missing/empty value.

    Returns:
        - ``""`` for falsy content (``None``, ``""``, ``[]``).
        - The string itself when ``content`` is already a ``str``.
        - For a list: the parts joined with single spaces. Dict parts
          contribute their ``"text"`` value (empty when the key is absent or
          ``None``); any other part is converted with ``str()``.
        - ``str(content)`` for anything else, so the result is always a ``str``.
    """
    if not content:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, dict):
                text = part.get("text")
                # A part without a text field, or with an explicit null, must
                # not break the join below (str.join rejects non-strings).
                pieces.append("" if text is None else str(text))
            else:
                pieces.append(str(part))
        return " ".join(pieces)
    # Unexpected scalar/dict content: coerce rather than leak a non-string.
    return str(content)
|
|
|
|
|
|
def _last_user_msg(messages: list, max_len: int = 120) -> str:
    """Return a single-line preview of the most recent user message.

    The list is searched from the end for the first entry whose ``"role"`` is
    ``"user"``. Its content is flattened with ``_content_to_str``, newlines
    become spaces, and surrounding whitespace is stripped. Text longer than
    ``max_len`` characters is cut and suffixed with an ellipsis.

    Returns ``""`` when no user message is present.
    """
    latest = next((m for m in reversed(messages) if m.get("role") == "user"), None)
    if latest is None:
        return ""

    flat = _content_to_str(latest.get("content")).replace("\n", " ").strip()
    if len(flat) <= max_len:
        return flat
    return flat[:max_len] + "…"
|
|
|
|
async def require_api_key(request: Request, db: Session = Depends(get_db)):
    """Authenticate a request by the API key in its ``Authorization`` header.

    Accepted header forms:
      * ``Bearer <key>`` — the scheme name is matched case-insensitively and
        whitespace around the key is ignored.
      * a bare key beginning with ``sk-``.

    On success the matching key's ``id`` and ``name`` are stored on
    ``request.state.api_key_id`` / ``request.state.api_key_name`` for use by
    the route handlers.

    Raises:
        HTTPException: 401 when the header is missing, malformed, carries an
            empty bearer token, or ``crud.verify_api_key`` rejects the key.
    """
    auth_header = request.headers.get("Authorization", "").strip()
    scheme, _, credentials = auth_header.partition(" ")
    credentials = credentials.strip()

    if scheme.lower() == "bearer" and credentials:
        # HTTP auth scheme names are case-insensitive ("bearer" == "Bearer").
        api_key = credentials
    elif auth_header.startswith("sk-"):
        # Bare key without a scheme.
        api_key = auth_header
    else:
        raise HTTPException(status_code=401, detail="Invalid or missing API key")

    db_key = crud.verify_api_key(db, api_key)
    if not db_key:
        raise HTTPException(status_code=401, detail="Invalid API key")

    request.state.api_key_id = db_key.id
    request.state.api_key_name = db_key.name
|
|
|
|
# Application object. `require_api_key` is registered as an app-level
# dependency, so it runs for the routes declared on `app`.
app = FastAPI(
    title="Ollama Proxy",
    dependencies=[Depends(require_api_key)],
)
|
|
|
|
@app.on_event("startup")
def apply_env_settings():
    """Copy environment-provided settings into the database at startup.

    When ``OLLAMA_URL`` is set to a non-empty value it is stored under the
    ``"ollama_url"`` setting and committed. The session obtained from
    ``get_db()`` is closed in every case.
    """
    session = next(get_db())
    try:
        env_url = os.getenv("OLLAMA_URL")
        if env_url:
            crud.set_setting(session, "ollama_url", env_url)
            session.commit()
    finally:
        session.close()
|
|
|
|
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Log an otherwise unhandled exception and answer with a generic 500.

    The error log entry records the request method and path, the exception
    type and message, and the traceback. The client only receives a fixed
    "Internal server error" payload.
    """
    error_log.error(
        "Unhandled exception | %s %s | %s: %s",
        request.method,
        request.url.path,
        type(exc).__name__,
        exc,
        exc_info=exc,
    )
    payload = {"error": {"message": "Internal server error", "type": "server_error"}}
    return JSONResponse(content=payload, status_code=500)
|
|
|
|
async def proxy_request(url: str, method: str = "GET", json_data: dict = None):
    """Send one HTTP request with a short-lived httpx client and return the response.

    Args:
        url: Absolute URL to call.
        method: HTTP method name.
        json_data: Optional JSON body passed to httpx as ``json=``.

    Returns:
        The ``httpx`` response object. The client uses a 300-second timeout
        and is closed before this function returns.
    """
    async with httpx.AsyncClient(timeout=300.0) as upstream:
        return await upstream.request(method=method, url=url, json=json_data)
|
|
|
|
@app.post("/api/generate")
async def generate(request: Request, db: Session = Depends(get_db)):
    """Proxy ``POST /api/generate`` to the configured Ollama server.

    Steps:
      1. Read the JSON body; if the ``force_model`` setting is non-empty it
         replaces the requested model.
      2. Estimate prompt tokens with ``crud.count_tokens`` and charge them
         (plus one request) against the caller's quota.
      3. Log the request to the usage log with a 120-character prompt preview.
      4. Forward the request. Per the Ollama API, responses are streamed as
         newline-delimited JSON unless the body contains ``"stream": false``;
         streamed responses are relayed chunk by chunk, non-streamed ones are
         returned as a single JSON document with the upstream status code and
         their actual token counts are logged.

    Raises:
        HTTPException: 422 when no model is given, 429 when the quota check fails.
        Exception: upstream/parse errors are written to the error log and re-raised.
    """
    ollama_url = crud.get_setting(db, "ollama_url", os.getenv("OLLAMA_URL", "http://localhost:11434"))
    body = await request.json()
    force_model = crud.get_setting(db, "force_model") or None
    if force_model:
        body = {**body, "model": force_model}
    if not body.get("model"):
        raise HTTPException(status_code=422, detail="Field 'model' is required")

    # A missing, null or non-string prompt must not crash token counting or
    # the log preview below.
    prompt = body.get("prompt") or ""
    if not isinstance(prompt, str):
        prompt = str(prompt)
    prompt_tokens = crud.count_tokens(prompt)

    if not crud.check_and_increment_quota(db, request.state.api_key_id, tokens=prompt_tokens, requests=1):
        raise HTTPException(status_code=429, detail="Quota exceeded")

    model_name = body["model"]
    prompt_preview = prompt.replace("\n", " ").strip()[:120]
    usage_log.info('%s | /api/generate | %s | ~%d tokens | "%s"',
                   request.state.api_key_name, model_name, prompt_tokens, prompt_preview)

    target = f"{ollama_url}/api/generate"

    # Streaming is the upstream default: only an explicit `false` disables it.
    # Parsing a multi-line NDJSON stream with response.json() would fail, so
    # relay the bytes unchanged instead.
    if body.get("stream") is not False:
        async def relay_stream():
            try:
                async with httpx.AsyncClient(timeout=300.0) as client:
                    async with client.stream("POST", target, json=body) as resp:
                        async for chunk in resp.aiter_bytes():
                            yield chunk
            except Exception as exc:
                error_log.error("Stream error | %s | /api/generate | %s | %s: %s",
                                request.state.api_key_name, model_name, type(exc).__name__, exc, exc_info=exc)
                raise

        return StreamingResponse(relay_stream(), media_type="application/x-ndjson")

    try:
        response = await proxy_request(target, method="POST", json_data=body)
        resp_json = response.json()
        usage_log.info('%s | /api/generate | %s | actual ↑%d ↓%d tokens',
                       request.state.api_key_name, model_name,
                       resp_json.get("prompt_eval_count", 0), resp_json.get("eval_count", 0))
        return JSONResponse(content=resp_json, status_code=response.status_code)
    except Exception as exc:
        error_log.error("Proxy error | %s | /api/generate | %s | %s: %s",
                        request.state.api_key_name, model_name, type(exc).__name__, exc, exc_info=exc)
        raise
|
|
|
|
@app.post("/api/chat")
async def chat(request: Request, db: Session = Depends(get_db)):
    """Proxy ``POST /api/chat`` to the configured Ollama server.

    Steps:
      1. Read the JSON body; if the ``force_model`` setting is non-empty it
         replaces the requested model.
      2. Estimate prompt tokens by summing ``crud.count_tokens`` over every
         message's content and charge them (plus one request) against the
         caller's quota.
      3. Log the request to the usage log with a preview of the last user message.
      4. Forward the request. Per the Ollama API, responses are streamed as
         newline-delimited JSON unless the body contains ``"stream": false``;
         streamed responses are relayed chunk by chunk, non-streamed ones are
         returned as a single JSON document with the upstream status code and
         their actual token counts are logged.

    Raises:
        HTTPException: 422 when no model is given, 429 when the quota check fails.
        Exception: upstream/parse errors are written to the error log and re-raised.
    """
    ollama_url = crud.get_setting(db, "ollama_url", os.getenv("OLLAMA_URL", "http://localhost:11434"))
    body = await request.json()
    force_model = crud.get_setting(db, "force_model") or None
    if force_model:
        body = {**body, "model": force_model}
    if not body.get("model"):
        raise HTTPException(status_code=422, detail="Field 'model' is required")

    # `"messages": null` must behave like an absent list instead of raising.
    messages = body.get("messages") or []
    prompt_tokens = sum(crud.count_tokens(_content_to_str(msg.get("content"))) for msg in messages)

    if not crud.check_and_increment_quota(db, request.state.api_key_id, tokens=prompt_tokens, requests=1):
        raise HTTPException(status_code=429, detail="Quota exceeded")

    model_name = body["model"]
    usage_log.info('%s | /api/chat | %s | ~%d tokens | "%s"',
                   request.state.api_key_name, model_name, prompt_tokens, _last_user_msg(messages))

    target = f"{ollama_url}/api/chat"

    # Streaming is the upstream default: only an explicit `false` disables it.
    # Parsing a multi-line NDJSON stream with response.json() would fail, so
    # relay the bytes unchanged instead.
    if body.get("stream") is not False:
        async def relay_stream():
            try:
                async with httpx.AsyncClient(timeout=300.0) as client:
                    async with client.stream("POST", target, json=body) as resp:
                        async for chunk in resp.aiter_bytes():
                            yield chunk
            except Exception as exc:
                error_log.error("Stream error | %s | /api/chat | %s | %s: %s",
                                request.state.api_key_name, model_name, type(exc).__name__, exc, exc_info=exc)
                raise

        return StreamingResponse(relay_stream(), media_type="application/x-ndjson")

    try:
        response = await proxy_request(target, method="POST", json_data=body)
        resp_json = response.json()
        usage_log.info('%s | /api/chat | %s | actual ↑%d ↓%d tokens',
                       request.state.api_key_name, model_name,
                       resp_json.get("prompt_eval_count", 0), resp_json.get("eval_count", 0))
        return JSONResponse(content=resp_json, status_code=response.status_code)
    except Exception as exc:
        error_log.error("Proxy error | %s | /api/chat | %s | %s: %s",
                        request.state.api_key_name, model_name, type(exc).__name__, exc, exc_info=exc)
        raise
|
|
|
|
@app.get("/api/tags")
async def list_models(db: Session = Depends(get_db)):
    """Relay ``GET /api/tags`` from the configured Ollama server.

    The base URL comes from the ``ollama_url`` setting, falling back to the
    ``OLLAMA_URL`` environment variable and then ``http://localhost:11434``.
    The upstream JSON body and status code are returned unchanged.
    """
    env_default = os.getenv("OLLAMA_URL", "http://localhost:11434")
    ollama_url = crud.get_setting(db, "ollama_url", env_default)
    upstream = await proxy_request(f"{ollama_url}/api/tags", method="GET")
    payload = upstream.json()
    return JSONResponse(content=payload, status_code=upstream.status_code)
|
|
|
|
@app.get("/version")
async def version():
    """Report the proxy's own version.

    The value is read from the ``APP_VERSION`` environment variable on each
    call; ``"dev"`` is returned when the variable is not set.
    """
    app_version = os.environ.get("APP_VERSION", "dev")
    return {"version": app_version}
|
|
|
|
@app.get("/api/ps")
async def running_models(db: Session = Depends(get_db)):
    """Relay ``GET /api/ps`` from the configured Ollama server.

    The base URL comes from the ``ollama_url`` setting, falling back to the
    ``OLLAMA_URL`` environment variable and then ``http://localhost:11434``.
    The upstream JSON body and status code are returned unchanged.
    """
    base_url = crud.get_setting(
        db,
        "ollama_url",
        os.getenv("OLLAMA_URL", "http://localhost:11434"),
    )
    upstream = await proxy_request(f"{base_url}/api/ps", method="GET")
    return JSONResponse(
        content=upstream.json(),
        status_code=upstream.status_code,
    )
|
|
|
|
@app.get("/api/version")
@app.get("/api/versions")
async def versions(db: Session = Depends(get_db)):
    """Relay the Ollama server's version information.

    Ollama exposes its version at ``GET /api/version`` (singular), so that is
    the upstream path used. The original ``/api/versions`` route is kept for
    existing callers and ``/api/version`` is registered as an alias so that
    standard Ollama clients can reach it through the proxy.

    The base URL comes from the ``ollama_url`` setting, falling back to the
    ``OLLAMA_URL`` environment variable and then ``http://localhost:11434``.
    The upstream JSON body and status code are returned unchanged.
    """
    ollama_url = crud.get_setting(db, "ollama_url", os.getenv("OLLAMA_URL", "http://localhost:11434"))
    response = await proxy_request(f"{ollama_url}/api/version", method="GET")
    return JSONResponse(content=response.json(), status_code=response.status_code)
|
|
|
|
@app.get("/v1/models")
async def list_openai_models(db: Session = Depends(get_db)):
    """Relay ``GET /v1/models`` from the configured Ollama server.

    The base URL comes from the ``ollama_url`` setting, falling back to the
    ``OLLAMA_URL`` environment variable and then ``http://localhost:11434``.
    The upstream JSON body and status code are returned unchanged.
    """
    fallback = os.getenv("OLLAMA_URL", "http://localhost:11434")
    ollama_url = crud.get_setting(db, "ollama_url", fallback)
    target = f"{ollama_url}/v1/models"
    upstream = await proxy_request(target, method="GET")
    return JSONResponse(content=upstream.json(), status_code=upstream.status_code)
|
|
|
|
@app.post("/v1/chat/completions")
async def openai_chat_completions(request: Request, db: Session = Depends(get_db)):
    """Proxy the OpenAI-compatible ``POST /v1/chat/completions`` endpoint.

    Flow:
      1. Read the JSON body; a non-empty ``force_model`` setting overrides the
         requested model. A missing model yields 422.
      2. Estimate prompt tokens over all message contents and charge them
         (plus one request) against the caller's quota; 429 when refused.
      3. Write a usage log line with a preview of the last user message.
      4. Without a truthy ``"stream"`` field: forward the request, log the
         token usage reported by upstream and return its JSON and status code.
         With ``"stream"`` set: relay the upstream bytes as
         ``text/event-stream``.

    Errors while talking to upstream are written to the error log and re-raised.
    """
    env_default = os.getenv("OLLAMA_URL", "http://localhost:11434")
    ollama_url = crud.get_setting(db, "ollama_url", env_default)

    payload = await request.json()
    forced = crud.get_setting(db, "force_model") or None
    if forced:
        payload = {**payload, "model": forced}
    if not payload.get("model"):
        raise HTTPException(status_code=422, detail="Field 'model' is required")

    messages = payload.get("messages", [])
    prompt_tokens = 0
    for message in messages:
        prompt_tokens += crud.count_tokens(_content_to_str(message.get("content")))

    quota_ok = crud.check_and_increment_quota(
        db, request.state.api_key_id, tokens=prompt_tokens, requests=1
    )
    if not quota_ok:
        raise HTTPException(status_code=429, detail="Quota exceeded")

    model_name = payload["model"]
    usage_log.info(
        '%s | /v1/chat/completions | %s | ~%d tokens | "%s"',
        request.state.api_key_name,
        model_name,
        prompt_tokens,
        _last_user_msg(messages),
    )

    target = f"{ollama_url}/v1/chat/completions"

    if not payload.get("stream"):
        # Plain request/response: the whole upstream body is parsed as JSON.
        try:
            upstream = await proxy_request(target, method="POST", json_data=payload)
            result = upstream.json()
            reported = result.get("usage", {})
            usage_log.info(
                '%s | /v1/chat/completions | %s | actual ↑%d ↓%d tokens',
                request.state.api_key_name,
                model_name,
                reported.get("prompt_tokens", 0),
                reported.get("completion_tokens", 0),
            )
            return JSONResponse(content=result, status_code=upstream.status_code)
        except Exception as exc:
            error_log.error(
                "Proxy error | %s | /v1/chat/completions | %s | %s: %s",
                request.state.api_key_name,
                model_name,
                type(exc).__name__,
                exc,
                exc_info=exc,
            )
            raise

    async def relay_chunks():
        """Yield the upstream response bytes as they arrive."""
        try:
            async with httpx.AsyncClient(timeout=300.0) as client:
                async with client.stream("POST", target, json=payload) as upstream_stream:
                    async for piece in upstream_stream.aiter_bytes():
                        yield piece
        except Exception as exc:
            error_log.error(
                "Stream error | %s | /v1/chat/completions | %s | %s: %s",
                request.state.api_key_name,
                model_name,
                type(exc).__name__,
                exc,
                exc_info=exc,
            )
            raise

    # NOTE(review): the upstream status code is not forwarded on this path;
    # the streaming response is created without an explicit status.
    return StreamingResponse(
        relay_chunks(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
|