diff --git a/app/main.py b/app/main.py index 4211d3c..9d56335 100644 --- a/app/main.py +++ b/app/main.py @@ -1,9 +1,12 @@ -from fastapi import FastAPI, Request, WebSocket -from fastapi.responses import HTMLResponse +from fastapi import Depends, FastAPI, Request, WebSocket +from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from app.core.config import get_settings +from app.core.database import get_db +from app.modules.auth.dependencies import RequiresLoginException, get_current_user +from app.modules.auth.router import router as auth_router settings = get_settings() @@ -15,6 +18,14 @@ app = FastAPI( app.mount("/static", StaticFiles(directory="app/static"), name="static") templates = Jinja2Templates(directory="app/templates") +app.include_router(auth_router) + + +@app.exception_handler(RequiresLoginException) +async def requires_login_handler(request: Request, exc: RequiresLoginException): + return RedirectResponse(url="/auth/login", status_code=307) + + MODULES = [ { "icon": "📡", @@ -62,7 +73,7 @@ def _db_mode() -> str: @app.get("/", response_class=HTMLResponse) -async def root(request: Request): +async def root(request: Request, current_user=Depends(get_current_user)): return templates.TemplateResponse( request, "index.html", @@ -71,6 +82,7 @@ async def root(request: Request): "modules": MODULES, "db_mode": _db_mode(), "app_version": "0.1.0", + "current_user": current_user, }, ) @@ -79,4 +91,4 @@ async def root(request: Request): async def websocket_hello(websocket: WebSocket, name: str): await websocket.accept() await websocket.send_text(f"Hello, {name}!") - await websocket.close() \ No newline at end of file + await websocket.close() diff --git a/app/modules/auth/router.py b/app/modules/auth/router.py index 75c7df3..6a922c0 100644 --- a/app/modules/auth/router.py +++ b/app/modules/auth/router.py @@ -37,7 +37,7 @@ async def login( request, "auth/login.html", {"nav_items": _NAV, "app_version": "0.1.0", "error": "Ungültige Zugangsdaten."}, - status_code=401, + status_code=200, ) token = create_access_token(username=user.username, is_admin=user.is_admin) response = RedirectResponse(url="/", status_code=303) diff --git a/scripts/create_admin.py b/scripts/create_admin.py new file mode 100644 index 0000000..f44c59b --- /dev/null +++ b/scripts/create_admin.py @@ -0,0 +1,38 @@ +""" +Create or reset the admin user. + +Usage: + .venv/bin/python scripts/create_admin.py + +Example: + .venv/bin/python scripts/create_admin.py admin geheim123 +""" +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from app.core.database import SessionLocal, engine, Base +from app.modules.auth.models import User # noqa: F401 — registers table +from app.modules.auth.service import get_user, hash_password + + +def create_admin(username: str, password: str) -> None: + Base.metadata.create_all(bind=engine) + with SessionLocal() as db: + user = get_user(db, username) + if user is None: + user = User(username=username, full_name="Administrator") + db.add(user) + user.pw_hash = hash_password(password) + user.is_admin = True + user.is_active = True + db.commit() + print(f"Admin user '{username}' created/updated.") + + +if __name__ == "__main__": + if len(sys.argv) != 3: + print(__doc__) + sys.exit(1) + create_admin(sys.argv[1], sys.argv[2]) diff --git a/tests/test_landing.py b/tests/test_landing.py index 5318684..d16ba05 100644 --- a/tests/test_landing.py +++ b/tests/test_landing.py @@ -1,31 +1,74 @@ +import pytest from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import StaticPool + +from app.core.auth import create_access_token +from app.core.database import Base, get_db from app.main import app - -client = TestClient(app) +from app.modules.auth.models import User +from app.modules.auth.service import hash_password -def test_landing_returns_html(): +@pytest.fixture(autouse=True) +def override_db(): + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + Base.metadata.create_all(bind=engine) + Session = sessionmaker(bind=engine) + session = Session() + app.dependency_overrides[get_db] = lambda: session + yield session + app.dependency_overrides.clear() + session.close() + Base.metadata.drop_all(bind=engine) + + +@pytest.fixture +def client(): + return TestClient(app, follow_redirects=False) + + +@pytest.fixture +def auth_cookies(override_db): + user = User(username="testuser", full_name="Test User", pw_hash=hash_password("pw")) + override_db.add(user) + override_db.commit() + token = create_access_token(username="testuser", is_admin=False) + return {"access_token": token} + + +def test_landing_without_auth_redirects(client): response = client.get("/") + assert response.status_code == 307 + assert "/auth/login" in response.headers["location"] + + +def test_landing_returns_html(client, auth_cookies): + response = client.get("/", cookies=auth_cookies) assert response.status_code == 200 assert "text/html" in response.headers["content-type"] -def test_landing_contains_title(): - response = client.get("/") +def test_landing_contains_title(client, auth_cookies): + response = client.get("/", cookies=auth_cookies) assert "University Process Hub" in response.text -def test_landing_contains_rss_module(): - response = client.get("/") +def test_landing_contains_rss_module(client, auth_cookies): + response = client.get("/", cookies=auth_cookies) assert "RSS-Feed Server" in response.text -def test_landing_navbar_links_present(): - response = client.get("/") +def test_landing_navbar_links_present(client, auth_cookies): + response = client.get("/", cookies=auth_cookies) assert "Übersicht" in response.text - assert "RSS-Feeds" in response.text -def test_landing_info_strip_shows_db_mode(): - response = client.get("/") - assert "SQLite" in response.text or "MariaDB" in response.text \ No newline at end of file +def test_landing_info_strip_shows_db_mode(client, auth_cookies): + response = client.get("/", cookies=auth_cookies) + assert "SQLite" in response.text or "MariaDB" in response.text