diff --git a/app/core/database.py b/app/core/database.py new file mode 100644 index 0000000..b4b1204 --- /dev/null +++ b/app/core/database.py @@ -0,0 +1,22 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker + +from app.core.config import get_settings + +settings = get_settings() + +_connect_args = {"check_same_thread": False} if settings.DATABASE_URL.startswith("sqlite") else {} +engine = create_engine(settings.DATABASE_URL, connect_args=_connect_args) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +class Base(DeclarativeBase): + pass + + +def get_db() -> Session: + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 0000000..e8aaf00 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,16 @@ +from sqlalchemy.orm import Session +from app.core.database import Base, get_db + + +def test_get_db_yields_session(): + gen = get_db() + db = next(gen) + assert isinstance(db, Session) + try: + next(gen) + except StopIteration: + pass + + +def test_base_has_metadata(): + assert Base.metadata is not None