123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- import logging
-
- from django.contrib.sessions.backends.base import (
- CreateError, SessionBase, UpdateError,
- )
- from django.core.exceptions import SuspiciousOperation
- from django.db import DatabaseError, IntegrityError, router, transaction
- from django.utils import timezone
- from django.utils.functional import cached_property
-
-
class SessionStore(SessionBase):
    """
    Session backend that keeps each session as one row of a database model.

    A row carries the session key, the encoded session payload and the
    expiry timestamp (see ``create_model_instance``). The model class itself
    comes from ``get_model_class()``.
    """

    def __init__(self, session_key=None):
        # Nothing backend-specific to set up; the key is handed straight to
        # SessionBase.
        super().__init__(session_key)

    @classmethod
    def get_model_class(cls):
        """Return the model class that session rows are stored in.

        The import happens at call time. This avoids a circular import and
        keeps SessionStore importable when django.contrib.sessions is not in
        INSTALLED_APPS.
        """
        from django.contrib.sessions.models import Session
        return Session

    @cached_property
    def model(self):
        """Session model class for this store, resolved once per instance."""
        return self.get_model_class()

    def _get_session_from_db(self):
        """Look up the unexpired row matching the current session key.

        Returns the model instance on success. On a miss, it resets the
        session key to None and returns None. A miss is either no row with
        ``expire_date`` later than now, or a SuspiciousOperation raised by
        the lookup. A SuspiciousOperation is also logged as a warning to the
        ``django.security.<ExceptionClassName>`` logger.
        """
        try:
            return self.model.objects.get(
                session_key=self.session_key,
                expire_date__gt=timezone.now(),
            )
        except SuspiciousOperation as exc:
            security_logger = logging.getLogger('django.security.%s' % exc.__class__.__name__)
            security_logger.warning(str(exc))
        except self.model.DoesNotExist:
            pass
        # Both failure paths end up here: drop the key that led nowhere.
        self._session_key = None
        return None

    def load(self):
        """Return the decoded session data, or ``{}`` when no row was found."""
        row = self._get_session_from_db()
        if not row:
            return {}
        return self.decode(row.session_data)

    def exists(self, session_key):
        """Return whether a row with ``session_key`` exists.

        Expiry is not checked here: expired rows count as existing.
        """
        matching_rows = self.model.objects.filter(session_key=session_key)
        return matching_rows.exists()

    def create(self):
        """Save the session under a freshly generated key and mark it modified.

        Keys are drawn repeatedly until ``save(must_create=True)`` succeeds.
        A CreateError means the key was not unique, so a new one is tried.
        """
        while True:
            self._session_key = self._get_new_session_key()
            try:
                # Write immediately so the key is claimed by a unique row.
                self.save(must_create=True)
            except CreateError:
                # Key collided with an existing row; draw another one.
                continue
            else:
                break
        self.modified = True

    def create_model_instance(self, data):
        """
        Build (but do not save) a session model instance holding ``data``.

        The instance gets the current session key (one is generated if
        missing), the encoded form of ``data`` and the session's expiry date.
        It is meant for writing the session state to the database.
        """
        model_cls = self.model
        return model_cls(
            session_key=self._get_or_create_session_key(),
            session_data=self.encode(data),
            expire_date=self.get_expiry_date(),
        )

    def save(self, must_create=False):
        """
        Write the current session data to the database.

        If the store has no session key yet, this delegates to ``create()``.

        With ``must_create=True`` the row is force-inserted, and an
        IntegrityError becomes CreateError. With ``must_create=False`` the row
        is force-updated, and a non-integrity DatabaseError becomes
        UpdateError. Any other database error is re-raised unchanged.
        """
        if self.session_key is None:
            return self.create()
        session_dict = self._get_session(no_load=must_create)
        instance = self.create_model_instance(session_dict)
        db_alias = router.db_for_write(self.model, instance=instance)
        try:
            with transaction.atomic(using=db_alias):
                instance.save(
                    force_insert=must_create,
                    force_update=not must_create,
                    using=db_alias,
                )
        except IntegrityError:
            if not must_create:
                raise
            # Insert rejected: presumably the key already has a row.
            raise CreateError
        except DatabaseError:
            if must_create:
                raise
            # NOTE(review): presumably the forced update matched no row;
            # confirm against Model.save(force_update=True) behavior.
            raise UpdateError

    def delete(self, session_key=None):
        """Delete the row for ``session_key``, defaulting to this session's key.

        Silently does nothing when no key is available or no row matches.
        """
        if session_key is None:
            session_key = self.session_key
            if session_key is None:
                return
        try:
            row = self.model.objects.get(session_key=session_key)
            row.delete()
        except self.model.DoesNotExist:
            return

    @classmethod
    def clear_expired(cls):
        """Remove every session row whose expiry date is earlier than now."""
        expired_rows = cls.get_model_class().objects.filter(expire_date__lt=timezone.now())
        expired_rows.delete()
|