"""FastAPI reverse proxy that adds API-key authentication and quotas in front of Ollama.

Native Ollama routes (``/api/...``) are forwarded unchanged, while ``/v1/models`` and
``/v1/chat/completions`` translate between the OpenAI wire format and Ollama's.
The upstream base URL comes from the ``ollama_url`` setting (``crud.get_setting``),
falling back to the ``OLLAMA_URL`` environment variable.
"""
import json
import os
import re
import time
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Mapping, Optional

import httpx
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from sqlalchemy.orm import Session

import crud
from database import SessionLocal, get_db

app = FastAPI(title="Ollama Proxy")

DEFAULT_OLLAMA_URL = "http://localhost:11434"
DEFAULT_MODEL_NAME = "llama3"

# httpx defaults to a 5-second timeout, far shorter than a cold model load plus a
# long generation; OLLAMA_TIMEOUT (seconds) tunes the read/write/pool limits.
UPSTREAM_TIMEOUT = httpx.Timeout(float(os.getenv("OLLAMA_TIMEOUT", "300")), connect=10.0)

# Hop-by-hop headers (RFC 9110, section 7.6.1) describe a single connection and
# must never be copied between the client side and the upstream side.
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
    }
)

# httpx re-encodes the JSON body (so length/type/encoding are recomputed), Host
# names this proxy, and the caller's credentials/cookies authenticate against this
# proxy only -- none of them may be sent on to Ollama.
_STRIPPED_REQUEST_HEADERS = _HOP_BY_HOP_HEADERS | {
    "host",
    "content-length",
    "content-type",
    "accept-encoding",
    "authorization",
    "cookie",
}

# httpx hands back the already-decompressed body and Starlette recomputes the
# length; the ASGI server typically adds its own Date/Server headers.
_STRIPPED_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {
    "content-length",
    "content-encoding",
    "date",
    "server",
}

# Fractional seconds in an RFC 3339 timestamp (Ollama emits up to nanoseconds).
_FRACTION_RE = re.compile(r"\.(\d+)")


def _filter_headers(headers: Mapping[str, str], excluded: frozenset) -> Dict[str, str]:
    """Return a copy of *headers* without the (lower-case) names in *excluded*."""
    return {name: value for name, value in headers.items() if name.lower() not in excluded}


def _upstream_headers(request: Request) -> Dict[str, str]:
    """Return the client's request headers that are safe to forward to Ollama."""
    return _filter_headers(request.headers, _STRIPPED_REQUEST_HEADERS)


def _ollama_url(db: Session) -> str:
    """Return the configured Ollama base URL without a trailing slash."""
    url = crud.get_setting(db, "ollama_url", os.getenv("OLLAMA_URL", DEFAULT_OLLAMA_URL))
    # A trailing slash would otherwise produce "...//api/..." paths.
    return url.rstrip("/")


def _upstream_failure(exc: httpx.RequestError) -> HTTPException:
    """Translate an httpx transport error into the HTTP error this proxy returns."""
    if isinstance(exc, httpx.TimeoutException):
        return HTTPException(status_code=504, detail="Timed out waiting for the Ollama server")
    return HTTPException(status_code=502, detail="Could not reach the Ollama server")


async def proxy_request(
    url: str,
    method: str = "GET",
    json_data: Optional[dict] = None,
    headers: Optional[dict] = None,
    timeout: httpx.Timeout = UPSTREAM_TIMEOUT,
) -> httpx.Response:
    """Send one request to Ollama and return the fully buffered response.

    Args:
        url: Absolute upstream URL.
        method: HTTP method.
        json_data: Body to send as JSON; ``None`` sends no body.
        headers: Headers to send; pass them through ``_upstream_headers`` so
            proxy-only headers (credentials, Host, Content-Length) are not leaked.
        timeout: httpx timeout, defaulting to ``UPSTREAM_TIMEOUT``.

    Raises:
        HTTPException: 504 on an upstream timeout, 502 on any other transport error.
    """
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            return await client.request(method=method, url=url, json=json_data, headers=headers)
    except httpx.RequestError as exc:
        raise _upstream_failure(exc) from exc


async def _stream_upstream(url: str, json_data: dict, headers: Dict[str, str]) -> StreamingResponse:
    """POST *json_data* to Ollama and relay the response body incrementally.

    The httpx client must outlive this function, so it is closed by the relay
    generator once the body has been fully sent (or the transfer is aborted).
    """
    client = httpx.AsyncClient(timeout=UPSTREAM_TIMEOUT)
    try:
        upstream = await client.send(
            client.build_request("POST", url, json=json_data, headers=headers), stream=True
        )
    except httpx.RequestError as exc:
        await client.aclose()
        raise _upstream_failure(exc) from exc

    async def relay():
        try:
            async for chunk in upstream.aiter_bytes():
                yield chunk
        finally:
            await upstream.aclose()
            await client.aclose()

    return StreamingResponse(
        relay(),
        status_code=upstream.status_code,
        headers=_filter_headers(upstream.headers, _STRIPPED_RESPONSE_HEADERS),
    )


def _relay(upstream: httpx.Response) -> Response:
    """Return a buffered upstream response to the client byte-for-byte.

    Relaying raw bytes (instead of ``upstream.json()``) keeps non-JSON bodies such
    as NDJSON streams or plain-text error pages from crashing the handler.
    """
    return Response(
        content=upstream.content,
        status_code=upstream.status_code,
        headers=_filter_headers(upstream.headers, _STRIPPED_RESPONSE_HEADERS),
    )


def _parse_json(upstream: httpx.Response) -> Dict[str, Any]:
    """Decode an upstream JSON object, raising 502 if Ollama sent something else."""
    try:
        payload = upstream.json()
    except ValueError as exc:
        raise HTTPException(status_code=502, detail="Ollama returned an invalid JSON response") from exc
    if not isinstance(payload, dict):
        raise HTTPException(status_code=502, detail="Ollama returned an unexpected JSON response")
    return payload


async def _json_body(request: Request) -> Dict[str, Any]:
    """Parse the request body as a JSON object, raising 400 on malformed input."""
    try:
        body = await request.json()
    except ValueError as exc:  # json.JSONDecodeError and UnicodeDecodeError
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    return body


def _message_text(message: Any) -> str:
    """Return the text of one chat message for token counting.

    Handles plain string content, OpenAI-style lists of content parts, and
    ``null`` content (e.g. assistant tool-call messages).
    """
    if not isinstance(message, dict):
        return ""
    content = message.get("content")
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "\n".join(
            part["text"]
            for part in content
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        )
    return ""


def _count_message_tokens(messages: Any) -> int:
    """Sum ``crud.count_tokens`` over every message; raise 400 if *messages* is not a list."""
    if messages is None:
        return 0
    if not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="'messages' must be a list")
    return sum(crud.count_tokens(_message_text(message)) for message in messages)


def _charge_quota(db: Session, api_key_id: Any, tokens: int) -> None:
    """Charge one request plus *tokens* to the key's quota, raising 429 when exhausted."""
    if not crud.check_and_increment_quota(db, api_key_id, tokens=tokens, requests=1):
        raise HTTPException(status_code=429, detail="Quota exceeded")


async def _forward_native(request: Request, db: Session, path: str, body: Dict[str, Any]) -> Response:
    """Forward a native Ollama POST, streaming unless the client sent ``"stream": false``."""
    url = f"{_ollama_url(db)}{path}"
    headers = _upstream_headers(request)
    # Ollama streams NDJSON unless "stream" is explicitly false.
    if body.get("stream") is False:
        return _relay(await proxy_request(url, method="POST", json_data=body, headers=headers))
    return await _stream_upstream(url, body, headers)


def _to_unix_timestamp(value: Any) -> int:
    """Convert an RFC 3339 timestamp such as ``2024-05-01T12:00:00.123456789-07:00``
    to Unix seconds; return 0 when *value* is missing or unparsable.

    Fractional seconds are normalised to six digits and a trailing ``Z`` is
    rewritten as ``+00:00`` so ``datetime.fromisoformat`` accepts the value on
    Python versions older than 3.11 as well.
    """
    if not isinstance(value, str) or not value.strip():
        return 0
    text = _FRACTION_RE.sub(lambda m: "." + m.group(1)[:6].ljust(6, "0"), value.strip())
    if text.endswith("Z"):
        text = text[:-1] + "+00:00"
    try:
        parsed = datetime.fromisoformat(text)
    except ValueError:
        try:
            parsed = datetime.strptime(text[:10], "%Y-%m-%d")
        except ValueError:
            return 0
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return int(parsed.timestamp())


def _extract_api_key(auth_header: str) -> Optional[str]:
    """Return the API key from an ``Authorization`` header value, or ``None``.

    Accepts ``Bearer <key>`` (the scheme is case-insensitive per RFC 9110) or a
    bare ``sk-...`` key.
    """
    scheme, _, credentials = auth_header.strip().partition(" ")
    if scheme.lower() == "bearer":
        return credentials.strip() or None
    if auth_header.startswith("sk-"):
        return auth_header.strip()
    return None


@app.middleware("http")
async def authenticate_and_quota(request: Request, call_next):
    """Reject requests without a valid API key; store the key's id on ``request.state``.

    Quota is charged by the individual generation handlers, not here.
    """
    api_key = _extract_api_key(request.headers.get("Authorization", ""))
    if not api_key:
        return JSONResponse(status_code=401, content={"detail": "Invalid or missing API key"})

    # Middleware cannot use Depends, so it manages its own short-lived session.
    db = SessionLocal()
    try:
        db_key = crud.verify_api_key(db, api_key)
        if not db_key:
            return JSONResponse(status_code=401, content={"detail": "Invalid API key"})
        request.state.api_key_id = db_key.id
    finally:
        db.close()

    return await call_next(request)


@app.post("/api/generate")
async def generate(request: Request, db: Session = Depends(get_db)):
    """Forward an Ollama ``/api/generate`` call after charging the prompt's tokens."""
    body = await _json_body(request)
    prompt = body.get("prompt")
    _charge_quota(db, request.state.api_key_id, crud.count_tokens(prompt if isinstance(prompt, str) else ""))
    return await _forward_native(request, db, "/api/generate", body)


@app.post("/api/chat")
async def chat(request: Request, db: Session = Depends(get_db)):
    """Forward an Ollama ``/api/chat`` call after charging the messages' tokens."""
    body = await _json_body(request)
    _charge_quota(db, request.state.api_key_id, _count_message_tokens(body.get("messages")))
    return await _forward_native(request, db, "/api/chat", body)


@app.get("/api/tags")
async def list_models(request: Request, db: Session = Depends(get_db)):
    """Relay Ollama's list of locally available models."""
    upstream = await proxy_request(
        f"{_ollama_url(db)}/api/tags", method="GET", headers=_upstream_headers(request)
    )
    return _relay(upstream)


@app.get("/api/version")
@app.get("/api/versions")  # legacy path kept for existing clients
async def versions(request: Request, db: Session = Depends(get_db)):
    """Relay Ollama's server version (Ollama serves it at ``/api/version``)."""
    upstream = await proxy_request(
        f"{_ollama_url(db)}/api/version", method="GET", headers=_upstream_headers(request)
    )
    return _relay(upstream)


@app.get("/v1/models")
async def list_openai_models(request: Request, db: Session = Depends(get_db)):
    """List Ollama's local models in the OpenAI ``/v1/models`` format."""
    upstream = await proxy_request(
        f"{_ollama_url(db)}/api/tags", method="GET", headers=_upstream_headers(request)
    )
    if upstream.is_error:
        return _relay(upstream)

    models = _parse_json(upstream).get("models") or []
    data = [
        {
            "id": model["name"],
            "object": "model",
            # OpenAI expects Unix seconds; Ollama reports an RFC 3339 string.
            "created": _to_unix_timestamp(model.get("modified_at")),
            "owned_by": "ollama",
        }
        for model in models
        if isinstance(model, dict) and model.get("name")
    ]
    # Upstream headers describe a different body, so none are forwarded here.
    return JSONResponse(content={"object": "list", "data": data}, status_code=200)


def _openai_error(upstream: httpx.Response) -> JSONResponse:
    """Convert an Ollama error response into an OpenAI-style error payload."""
    message = None
    try:
        payload = upstream.json()
        if isinstance(payload, dict):
            message = payload.get("error")
    except ValueError:
        pass  # non-JSON error body; fall back to the raw text below
    return JSONResponse(
        status_code=upstream.status_code,
        content={
            "error": {
                "message": str(message or upstream.text or "Upstream error"),
                "type": "upstream_error",
                "code": None,
            }
        },
    )


def _sse_single_chunk(completion_id: str, created: int, model: Any, content: str, finish_reason: str) -> Response:
    """Encode an already-complete answer as an OpenAI chat-completion SSE stream.

    The upstream call is not streamed, so the whole answer arrives in one delta,
    followed by a finish chunk and the ``[DONE]`` sentinel.
    """
    base = {"id": completion_id, "object": "chat.completion.chunk", "created": created, "model": model}
    events = [
        {**base, "choices": [{"index": 0, "delta": {"role": "assistant", "content": content}, "finish_reason": None}]},
        {**base, "choices": [{"index": 0, "delta": {}, "finish_reason": finish_reason}]},
    ]
    payload = "".join(f"data: {json.dumps(event)}\n\n" for event in events) + "data: [DONE]\n\n"
    return Response(content=payload, media_type="text/event-stream")


@app.post("/v1/chat/completions")
async def openai_chat_completions(request: Request, db: Session = Depends(get_db)):
    """Serve an OpenAI chat completion by calling Ollama's ``/api/chat``.

    Prompt tokens are charged before the upstream call. The upstream request is
    always non-streaming; a client that asked for ``stream: true`` receives the
    finished answer as a single-chunk SSE stream.
    """
    body = await _json_body(request)
    default_model = crud.get_setting(db, "default_model", os.getenv("DEFAULT_MODEL", DEFAULT_MODEL_NAME))
    model = body.get("model") or default_model

    messages = body.get("messages")
    if messages is None:
        messages = []
    prompt_tokens = _count_message_tokens(messages)
    _charge_quota(db, request.state.api_key_id, prompt_tokens)

    # Always buffer upstream: the NDJSON stream cannot be parsed with .json().
    ollama_body = {"model": model, "messages": messages, "stream": False}
    upstream = await proxy_request(
        f"{_ollama_url(db)}/api/chat",
        method="POST",
        json_data=ollama_body,
        headers=_upstream_headers(request),
    )
    if upstream.is_error:
        return _openai_error(upstream)

    payload = _parse_json(upstream)
    message = payload.get("message")
    content = message.get("content") if isinstance(message, dict) else None
    if not isinstance(content, str):
        content = ""
    finish_reason = "length" if payload.get("done_reason") == "length" else "stop"

    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())
    if body.get("stream"):
        return _sse_single_chunk(completion_id, created, model, content, finish_reason)

    completion_tokens = crud.count_tokens(content)
    openai_response = {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": model,
        "choices": [
            {"index": 0, "message": {"role": "assistant", "content": content}, "finish_reason": finish_reason}
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
    return JSONResponse(content=openai_response, status_code=200)