210 lines
5.9 KiB
Python
210 lines
5.9 KiB
Python
import time
|
|
from datetime import datetime, timedelta, timezone
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.core.database import Base
|
|
from app.modules.auth.ldap import (
|
|
_has_to_update,
|
|
_parse_entry,
|
|
_upsert_from_attrs,
|
|
ldap_authenticate,
|
|
)
|
|
from app.modules.auth.models import User
|
|
|
|
|
|
# --- Helpers ---
|
|
|
|
def _make_entry(**kwargs):
|
|
defaults = {
|
|
"sAMAccountName": "hofmannol",
|
|
"givenName": "Oliver",
|
|
"sn": "Hofmann",
|
|
"department": "EFI",
|
|
"description": "PF",
|
|
"accountExpires": None,
|
|
}
|
|
defaults.update(kwargs)
|
|
|
|
class _Attr:
|
|
def __init__(self, value):
|
|
self.value = value
|
|
|
|
class _Entry:
|
|
def __getitem__(self, key):
|
|
return _Attr(defaults.get(key))
|
|
|
|
return _Entry()
|
|
|
|
|
|
@pytest.fixture
def db():
    """Yield a SQLAlchemy session bound to a fresh in-memory SQLite database.

    All tables registered on ``Base.metadata`` are created before the test
    and dropped afterwards. Teardown also disposes the engine so the pooled
    connection is released instead of lingering after each test.
    """
    # StaticPool + check_same_thread=False: one shared connection, so the
    # in-memory database stays the same for every session using the engine.
    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    Base.metadata.create_all(bind=engine)
    session = sessionmaker(bind=engine)()
    try:
        yield session
    finally:
        try:
            session.close()
            Base.metadata.drop_all(bind=engine)
        finally:
            # Release the pooled connection even if closing/dropping failed.
            engine.dispose()
|
|
|
|
|
|
# --- _parse_entry ---
|
|
|
|
def test_parse_entry_basic():
    """The default fake entry is parsed into the expected attribute values."""
    parsed = _parse_entry(_make_entry())

    expected = {
        "username": "hofmannol",
        "full_name": "Oliver Hofmann",
        "department": "EFI",
        "role": "PF",
    }
    for key, value in expected.items():
        assert parsed[key] == value
    assert parsed["account_expires"] is None
|
|
|
|
|
|
def test_parse_entry_truncates_long_fields():
    """A 100-character department comes back with length 64."""
    oversized = "A" * 100
    parsed = _parse_entry(_make_entry(department=oversized))
    assert len(parsed["department"]) == 64
|
|
|
|
|
|
def test_parse_entry_uses_nn_for_none_fields():
    """Missing department/description are reported as the "N.N." placeholder."""
    entry = _make_entry(department=None, description=None)
    parsed = _parse_entry(entry)
    for field in ("department", "role"):
        assert parsed[field] == "N.N."
|
|
|
|
|
|
def test_parse_entry_handles_none_account_expires():
    """An entry whose accountExpires is None parses to account_expires None."""
    entry = _make_entry(accountExpires=None)
    parsed = _parse_entry(entry)
    assert parsed["account_expires"] is None
|
|
|
|
|
|
# --- _has_to_update ---
|
|
|
|
def test_has_to_update_true_when_no_ldap_users(db):
    """With an empty user table an update is reported as required."""
    needs_update = _has_to_update(db, 12)
    assert needs_update is True
|
|
|
|
|
|
def test_has_to_update_false_when_recent_ldap_user(db):
    """A user without pw_hash that was updated just now suppresses the update."""
    now = datetime.now(timezone.utc)
    recent = User(username="u", full_name="U", pw_hash=None, updated_at=now)
    db.add(recent)
    db.commit()

    needs_update = _has_to_update(db, 12)
    assert needs_update is False
|
|
|
|
|
|
def test_has_to_update_true_when_old_ldap_user(db):
    """A user without pw_hash last updated 25 hours ago triggers an update."""
    stale_timestamp = datetime.now(timezone.utc) - timedelta(hours=25)
    db.add(
        User(
            username="u",
            full_name="U",
            pw_hash=None,
            updated_at=stale_timestamp,
        )
    )
    db.commit()

    needs_update = _has_to_update(db, 12)
    assert needs_update is True
|
|
|
|
|
|
def test_has_to_update_ignores_local_users(db):
    """A freshly updated user with a pw_hash does not suppress the update."""
    # Local user (pw_hash set) should NOT count for sync timing
    now = datetime.now(timezone.utc)
    local_admin = User(
        username="admin",
        full_name="Admin",
        pw_hash="hash",
        updated_at=now,
    )
    db.add(local_admin)
    db.commit()

    needs_update = _has_to_update(db, 12)
    assert needs_update is True
|
|
|
|
|
|
# --- _upsert_from_attrs ---
|
|
|
|
def test_upsert_creates_new_user(db):
    """Upserting an unknown username inserts a user row without a pw_hash."""
    _upsert_from_attrs(
        db,
        {
            "username": "newuser",
            "full_name": "New User",
            "department": "EFI",
            "role": "ST",
            "account_expires": None,
        },
    )
    db.commit()

    lookup = db.query(User).filter(User.username == "newuser")
    created = lookup.first()
    assert created is not None
    assert created.full_name == "New User"
    assert created.pw_hash is None
|
|
|
|
|
|
def test_upsert_updates_existing_user(db):
    """Upserting an existing username overwrites its full_name and role."""
    existing = User(username="existing", full_name="Old Name", pw_hash=None)
    db.add(existing)
    db.commit()

    _upsert_from_attrs(
        db,
        {
            "username": "existing",
            "full_name": "New Name",
            "department": "EFI",
            "role": "PF",
            "account_expires": None,
        },
    )
    db.commit()
    # Reload the row so the assertions see what was actually persisted.
    db.refresh(existing)

    assert existing.full_name == "New Name"
    assert existing.role == "PF"
|
|
|
|
|
|
# --- ldap_authenticate ---
|
|
|
|
@patch("app.modules.auth.ldap.Server")
@patch("app.modules.auth.ldap.Connection")
def test_ldap_authenticate_success(mock_conn_cls, mock_server_cls):
    """A bound connection with one entry yields the parsed user attributes."""
    connection = MagicMock(entries=[_make_entry()])
    connection.extend.standard.who_am_i.return_value = "u:ADS1\\hofmannol"
    mock_conn_cls.return_value = connection

    attrs = ldap_authenticate("hofmannol", "password", "server.example", "ADS1")

    assert attrs is not None
    assert attrs["username"] == "hofmannol"
    assert attrs["full_name"] == "Oliver Hofmann"
    # The connection must be bound exactly once and released afterwards.
    connection.bind.assert_called_once()
    connection.unbind.assert_called_once()
|
|
|
|
|
|
@patch("app.modules.auth.ldap.Server")
@patch("app.modules.auth.ldap.Connection")
def test_ldap_authenticate_returns_none_on_socket_error(mock_conn_cls, mock_server_cls):
    """An LDAPSocketOpenError raised by bind() results in None."""
    from ldap3.core.exceptions import LDAPSocketOpenError

    connection = MagicMock()
    connection.bind.side_effect = LDAPSocketOpenError("unreachable")
    mock_conn_cls.return_value = connection

    outcome = ldap_authenticate("hofmannol", "pw", "server.example", "ADS1")

    assert outcome is None
|
|
|
|
|
|
@patch("app.modules.auth.ldap.Server")
@patch("app.modules.auth.ldap.Connection")
def test_ldap_authenticate_returns_none_when_who_am_i_empty(mock_conn_cls, mock_server_cls):
    """When who_am_i() returns None the result is None and unbind is called."""
    connection = MagicMock()
    connection.extend.standard.who_am_i.return_value = None
    mock_conn_cls.return_value = connection

    outcome = ldap_authenticate("hofmannol", "wrong", "server.example", "ADS1")

    assert outcome is None
    connection.unbind.assert_called_once()
|
|
|
|
|
|
@patch("app.modules.auth.ldap.Server")
@patch("app.modules.auth.ldap.Connection")
def test_ldap_authenticate_returns_none_when_no_entries(mock_conn_cls, mock_server_cls):
    """An identified connection with an empty entry list results in None."""
    connection = MagicMock(entries=[])
    connection.extend.standard.who_am_i.return_value = "u:ADS1\\hofmannol"
    mock_conn_cls.return_value = connection

    outcome = ldap_authenticate("hofmannol", "pw", "server.example", "ADS1")

    assert outcome is None
|