feat: add LDAP functions (authenticate, sync, parse, upsert)
This commit is contained in:
parent
c9d8273680
commit
c92351786a
181
app/modules/auth/ldap.py
Normal file
181
app/modules/auth/ldap.py
Normal file
@ -0,0 +1,181 @@
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from ldap3 import ALL, ALL_ATTRIBUTES, NTLM, Connection, Server
|
||||
from ldap3.core.exceptions import LDAPSocketOpenError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.modules.auth.models import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def ldap_authenticate(
|
||||
username: str,
|
||||
password: str,
|
||||
ldap_server: str,
|
||||
ldap_domain: str,
|
||||
) -> Optional[dict]:
|
||||
"""NTLM bind + fetch user attributes. Returns attrs dict on success, None on failure."""
|
||||
server = Server(ldap_server, get_info=ALL, connect_timeout=8)
|
||||
conn = Connection(
|
||||
server,
|
||||
user=f"{ldap_domain}\\{username}",
|
||||
password=password,
|
||||
authentication=NTLM,
|
||||
)
|
||||
try:
|
||||
conn.bind()
|
||||
except LDAPSocketOpenError:
|
||||
logger.warning("LDAP server %s not reachable", ldap_server)
|
||||
return None
|
||||
except Exception:
|
||||
logger.warning("LDAP bind failed for user %s", username, exc_info=True)
|
||||
return None
|
||||
|
||||
if not conn.extend.standard.who_am_i():
|
||||
conn.unbind()
|
||||
return None
|
||||
|
||||
conn.search(
|
||||
search_base=f"DC={ldap_domain},DC=fh-nuernberg,DC=de",
|
||||
search_filter=f"(&(objectclass=user)(CN={username}))",
|
||||
attributes=ALL_ATTRIBUTES,
|
||||
)
|
||||
|
||||
if not conn.entries:
|
||||
conn.unbind()
|
||||
logger.warning("LDAP: user %s authenticated but no entry found", username)
|
||||
return None
|
||||
|
||||
attrs = _parse_entry(conn.entries[0])
|
||||
conn.unbind()
|
||||
return attrs
|
||||
|
||||
|
||||
def _parse_entry(entry) -> dict:
|
||||
def val(field: str, default: str = "N.N.") -> str:
|
||||
try:
|
||||
v = entry[field].value
|
||||
return str(v)[:64] if v else default
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
def expire_date(field: str) -> Optional[datetime]:
|
||||
try:
|
||||
v = entry[field].value
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, datetime):
|
||||
dt = v
|
||||
else:
|
||||
dt = datetime(v.year, v.month, v.day, tzinfo=timezone.utc)
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return None if dt.year >= 2099 else dt
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return {
|
||||
"username": val("sAMAccountName"),
|
||||
"full_name": (f"{val('givenName', '')} {val('sn', '')}".strip() or "N.N."),
|
||||
"department": val("department"),
|
||||
"role": val("description"),
|
||||
"account_expires": expire_date("accountExpires"),
|
||||
}
|
||||
|
||||
|
||||
def _has_to_update(db: Session, min_interval_hours: int) -> bool:
|
||||
"""True if a background sync is due."""
|
||||
deadline = datetime.now(timezone.utc) - timedelta(hours=min_interval_hours)
|
||||
latest = (
|
||||
db.query(User)
|
||||
.filter(User.pw_hash.is_(None))
|
||||
.order_by(User.updated_at.desc())
|
||||
.first()
|
||||
)
|
||||
if latest is None:
|
||||
return True
|
||||
last = latest.updated_at
|
||||
if last.tzinfo is None:
|
||||
last = last.replace(tzinfo=timezone.utc)
|
||||
return last < deadline
|
||||
|
||||
|
||||
def _upsert_from_attrs(db: Session, attrs: dict) -> None:
|
||||
"""Create or update a User from AD attribute dict (no commit)."""
|
||||
user = db.query(User).filter(User.username == attrs["username"]).first()
|
||||
if user is None:
|
||||
user = User(username=attrs["username"])
|
||||
db.add(user)
|
||||
user.full_name = attrs["full_name"]
|
||||
user.department = attrs["department"]
|
||||
user.role = attrs["role"]
|
||||
user.account_expires = attrs["account_expires"]
|
||||
|
||||
|
||||
def sync_all_users(
|
||||
username: str,
|
||||
password: str,
|
||||
ldap_server: str,
|
||||
ldap_domain: str,
|
||||
ldap_search_base: str,
|
||||
min_interval_hours: int,
|
||||
letter_delay_seconds: float,
|
||||
) -> None:
|
||||
"""Login-triggered background sync of all AD users. Creates its own DB session."""
|
||||
from app.core.database import SessionLocal
|
||||
|
||||
with SessionLocal() as db:
|
||||
if not _has_to_update(db, min_interval_hours):
|
||||
logger.info("LDAP sync skipped: last sync is recent")
|
||||
return
|
||||
|
||||
server = Server(ldap_server, get_info=ALL, connect_timeout=8)
|
||||
conn = Connection(
|
||||
server,
|
||||
user=f"{ldap_domain}\\{username}",
|
||||
password=password,
|
||||
authentication=NTLM,
|
||||
)
|
||||
try:
|
||||
conn.bind()
|
||||
except Exception:
|
||||
logger.warning("LDAP sync: bind failed", exc_info=True)
|
||||
return
|
||||
|
||||
found_usernames: set[str] = set()
|
||||
chars = "abcdefghijklmnopqrstuvwxyz_"
|
||||
|
||||
for char in chars:
|
||||
try:
|
||||
conn.search(
|
||||
search_base=ldap_search_base,
|
||||
search_filter=f"(&(objectclass=user)(CN={char}*))",
|
||||
attributes=ALL_ATTRIBUTES,
|
||||
)
|
||||
for entry in conn.entries:
|
||||
try:
|
||||
attrs = _parse_entry(entry)
|
||||
found_usernames.add(attrs["username"])
|
||||
_upsert_from_attrs(db, attrs)
|
||||
except Exception:
|
||||
logger.warning("LDAP sync: failed to parse entry", exc_info=True)
|
||||
db.commit()
|
||||
except Exception:
|
||||
logger.warning("LDAP sync: search failed for prefix '%s'", char, exc_info=True)
|
||||
|
||||
if letter_delay_seconds > 0:
|
||||
time.sleep(letter_delay_seconds)
|
||||
|
||||
# Deactivate LDAP users no longer in AD
|
||||
ldap_users = db.query(User).filter(User.pw_hash.is_(None)).all()
|
||||
for user_obj in ldap_users:
|
||||
if user_obj.username not in found_usernames:
|
||||
logger.info("LDAP sync: deactivating %s (not in AD)", user_obj.username)
|
||||
user_obj.is_active = False
|
||||
db.commit()
|
||||
conn.unbind()
|
||||
logger.info("LDAP sync complete — %d users found", len(found_usernames))
|
||||
209
tests/test_ldap.py
Normal file
209
tests/test_ldap.py
Normal file
@ -0,0 +1,209 @@
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.core.database import Base
|
||||
from app.modules.auth.ldap import (
|
||||
_has_to_update,
|
||||
_parse_entry,
|
||||
_upsert_from_attrs,
|
||||
ldap_authenticate,
|
||||
)
|
||||
from app.modules.auth.models import User
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
def _make_entry(**kwargs):
|
||||
defaults = {
|
||||
"sAMAccountName": "hofmannol",
|
||||
"givenName": "Oliver",
|
||||
"sn": "Hofmann",
|
||||
"department": "EFI",
|
||||
"description": "PF",
|
||||
"accountExpires": None,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
|
||||
class _Attr:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
class _Entry:
|
||||
def __getitem__(self, key):
|
||||
return _Attr(defaults.get(key))
|
||||
|
||||
return _Entry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db():
|
||||
engine = create_engine(
|
||||
"sqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = Session()
|
||||
yield session
|
||||
session.close()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
|
||||
# --- _parse_entry ---
|
||||
|
||||
def test_parse_entry_basic():
|
||||
result = _parse_entry(_make_entry())
|
||||
assert result["username"] == "hofmannol"
|
||||
assert result["full_name"] == "Oliver Hofmann"
|
||||
assert result["department"] == "EFI"
|
||||
assert result["role"] == "PF"
|
||||
assert result["account_expires"] is None
|
||||
|
||||
|
||||
def test_parse_entry_truncates_long_fields():
|
||||
result = _parse_entry(_make_entry(department="A" * 100))
|
||||
assert len(result["department"]) == 64
|
||||
|
||||
|
||||
def test_parse_entry_uses_nn_for_none_fields():
|
||||
result = _parse_entry(_make_entry(department=None, description=None))
|
||||
assert result["department"] == "N.N."
|
||||
assert result["role"] == "N.N."
|
||||
|
||||
|
||||
def test_parse_entry_handles_none_account_expires():
|
||||
result = _parse_entry(_make_entry(accountExpires=None))
|
||||
assert result["account_expires"] is None
|
||||
|
||||
|
||||
# --- _has_to_update ---
|
||||
|
||||
def test_has_to_update_true_when_no_ldap_users(db):
|
||||
assert _has_to_update(db, 12) is True
|
||||
|
||||
|
||||
def test_has_to_update_false_when_recent_ldap_user(db):
|
||||
user = User(
|
||||
username="u", full_name="U", pw_hash=None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
assert _has_to_update(db, 12) is False
|
||||
|
||||
|
||||
def test_has_to_update_true_when_old_ldap_user(db):
|
||||
old_time = datetime.now(timezone.utc) - timedelta(hours=25)
|
||||
user = User(username="u", full_name="U", pw_hash=None, updated_at=old_time)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
assert _has_to_update(db, 12) is True
|
||||
|
||||
|
||||
def test_has_to_update_ignores_local_users(db):
|
||||
# Local user (pw_hash set) should NOT count for sync timing
|
||||
user = User(
|
||||
username="admin", full_name="Admin", pw_hash="hash",
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
assert _has_to_update(db, 12) is True
|
||||
|
||||
|
||||
# --- _upsert_from_attrs ---
|
||||
|
||||
def test_upsert_creates_new_user(db):
|
||||
attrs = {
|
||||
"username": "newuser",
|
||||
"full_name": "New User",
|
||||
"department": "EFI",
|
||||
"role": "ST",
|
||||
"account_expires": None,
|
||||
}
|
||||
_upsert_from_attrs(db, attrs)
|
||||
db.commit()
|
||||
user = db.query(User).filter(User.username == "newuser").first()
|
||||
assert user is not None
|
||||
assert user.full_name == "New User"
|
||||
assert user.pw_hash is None
|
||||
|
||||
|
||||
def test_upsert_updates_existing_user(db):
|
||||
user = User(username="existing", full_name="Old Name", pw_hash=None)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
attrs = {
|
||||
"username": "existing",
|
||||
"full_name": "New Name",
|
||||
"department": "EFI",
|
||||
"role": "PF",
|
||||
"account_expires": None,
|
||||
}
|
||||
_upsert_from_attrs(db, attrs)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
assert user.full_name == "New Name"
|
||||
assert user.role == "PF"
|
||||
|
||||
|
||||
# --- ldap_authenticate ---
|
||||
|
||||
@patch("app.modules.auth.ldap.Server")
|
||||
@patch("app.modules.auth.ldap.Connection")
|
||||
def test_ldap_authenticate_success(mock_conn_cls, mock_server_cls):
|
||||
mock_conn = MagicMock()
|
||||
mock_conn_cls.return_value = mock_conn
|
||||
mock_conn.extend.standard.who_am_i.return_value = "u:ADS1\\hofmannol"
|
||||
mock_conn.entries = [_make_entry()]
|
||||
|
||||
result = ldap_authenticate("hofmannol", "password", "server.example", "ADS1")
|
||||
|
||||
assert result is not None
|
||||
assert result["username"] == "hofmannol"
|
||||
assert result["full_name"] == "Oliver Hofmann"
|
||||
mock_conn.bind.assert_called_once()
|
||||
mock_conn.unbind.assert_called_once()
|
||||
|
||||
|
||||
@patch("app.modules.auth.ldap.Server")
|
||||
@patch("app.modules.auth.ldap.Connection")
|
||||
def test_ldap_authenticate_returns_none_on_socket_error(mock_conn_cls, mock_server_cls):
|
||||
from ldap3.core.exceptions import LDAPSocketOpenError
|
||||
mock_conn = MagicMock()
|
||||
mock_conn_cls.return_value = mock_conn
|
||||
mock_conn.bind.side_effect = LDAPSocketOpenError("unreachable")
|
||||
|
||||
result = ldap_authenticate("hofmannol", "pw", "server.example", "ADS1")
|
||||
assert result is None
|
||||
|
||||
|
||||
@patch("app.modules.auth.ldap.Server")
|
||||
@patch("app.modules.auth.ldap.Connection")
|
||||
def test_ldap_authenticate_returns_none_when_who_am_i_empty(mock_conn_cls, mock_server_cls):
|
||||
mock_conn = MagicMock()
|
||||
mock_conn_cls.return_value = mock_conn
|
||||
mock_conn.extend.standard.who_am_i.return_value = None
|
||||
|
||||
result = ldap_authenticate("hofmannol", "wrong", "server.example", "ADS1")
|
||||
assert result is None
|
||||
mock_conn.unbind.assert_called_once()
|
||||
|
||||
|
||||
@patch("app.modules.auth.ldap.Server")
|
||||
@patch("app.modules.auth.ldap.Connection")
|
||||
def test_ldap_authenticate_returns_none_when_no_entries(mock_conn_cls, mock_server_cls):
|
||||
mock_conn = MagicMock()
|
||||
mock_conn_cls.return_value = mock_conn
|
||||
mock_conn.extend.standard.who_am_i.return_value = "u:ADS1\\hofmannol"
|
||||
mock_conn.entries = []
|
||||
|
||||
result = ldap_authenticate("hofmannol", "pw", "server.example", "ADS1")
|
||||
assert result is None
|
||||
Loading…
x
Reference in New Issue
Block a user