|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308 |
- from datetime import datetime
-
- from django.conf import settings
- from django.db.models.expressions import Func
- from django.db.models.fields import (
- DateField, DateTimeField, DurationField, Field, IntegerField, TimeField,
- )
- from django.db.models.lookups import (
- Transform, YearExact, YearGt, YearGte, YearLt, YearLte,
- )
- from django.utils import timezone
-
-
class TimezoneMixin:
    """Mixin providing the time zone name to use when building SQL."""
    tzinfo = None

    def get_tzname(self):
        """
        Return the name of the time zone to convert to, or None.

        None is returned when ``settings.USE_TZ`` is off. Otherwise an
        explicitly set ``tzinfo`` wins over the current time zone.
        """
        # Timezone conversions must happen to the input datetime *before*
        # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
        # database as 2016-01-01 01:00:00 +00:00. Any results should be
        # based on the input datetime not the stored datetime.
        if not settings.USE_TZ:
            return None
        if self.tzinfo is not None:
            return timezone._get_timezone_name(self.tzinfo)
        return timezone.get_current_timezone_name()
-
-
class Extract(TimezoneMixin, Transform):
    """
    Transform that extracts one component (``lookup_name``) of a date, time,
    datetime, or duration expression as an integer.

    Subclasses normally fix ``lookup_name`` as a class attribute; when used
    directly, ``lookup_name`` must be passed to the constructor.
    """
    lookup_name = None
    output_field = IntegerField()

    def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
        """
        Store the component to extract and the optional time zone.

        A ``lookup_name`` defined on the class takes precedence over the
        argument. Raise ValueError if neither provides one.
        """
        if self.lookup_name is None:
            self.lookup_name = lookup_name
        if self.lookup_name is None:
            raise ValueError('lookup_name must be provided')
        self.tzinfo = tzinfo
        super().__init__(expression, **extra)

    def as_sql(self, compiler, connection):
        """
        Return ``(sql, params)`` for the extraction, delegating to the
        backend operation that matches the type of the input expression.

        Raise ValueError for a DurationField on backends without native
        duration support.
        """
        sql, params = compiler.compile(self.lhs)
        lhs_output_field = self.lhs.output_field
        # DateTimeField must be tested before DateField because it is a
        # subclass of DateField.
        if isinstance(lhs_output_field, DateTimeField):
            tzname = self.get_tzname()
            sql = connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname)
        elif isinstance(lhs_output_field, DateField):
            sql = connection.ops.date_extract_sql(self.lookup_name, sql)
        elif isinstance(lhs_output_field, TimeField):
            sql = connection.ops.time_extract_sql(self.lookup_name, sql)
        elif isinstance(lhs_output_field, DurationField):
            if not connection.features.has_native_duration_field:
                raise ValueError('Extract requires native DurationField database support.')
            sql = connection.ops.time_extract_sql(self.lookup_name, sql)
        else:
            # resolve_expression has already validated the output_field so
            # this should never be hit. Raise explicitly instead of using an
            # assert statement so the check survives ``python -O`` rather
            # than silently returning the unmodified SQL.
            raise AssertionError("Tried to Extract from an invalid type.")
        return sql, params

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        """
        Resolve the expression and validate the type of its input field.

        Raise ValueError if the input is not a date, datetime, time, or
        duration field, or if a time component is requested from a plain
        DateField.
        """
        copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
        field = copy.lhs.output_field
        if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
            raise ValueError(
                'Extract input expression must be DateField, DateTimeField, '
                'TimeField, or DurationField.'
            )
        # Passing dates to functions expecting datetimes is most likely a mistake.
        # The exact type is compared on purpose: isinstance() would also
        # match DateTimeField, which does carry time components.
        if type(field) is DateField and copy.lookup_name in ('hour', 'minute', 'second'):
            raise ValueError(
                "Cannot extract time component '%s' from DateField '%s'. " % (copy.lookup_name, field.name)
            )
        return copy
-
-
class ExtractYear(Extract):
    """Extract the year component."""
    lookup_name = 'year'
-
-
class ExtractIsoYear(Extract):
    """Extract the week-numbering year as defined by ISO-8601."""
    lookup_name = 'iso_year'
-
-
class ExtractMonth(Extract):
    """Extract the month component."""
    lookup_name = 'month'
-
-
class ExtractDay(Extract):
    """Extract the day component."""
    lookup_name = 'day'
-
-
class ExtractWeek(Extract):
    """
    Extract the ISO-8601 week number (1-52 or 53). Under ISO-8601, weeks
    start on Monday.
    """
    lookup_name = 'week'
-
-
class ExtractWeekDay(Extract):
    """
    Extract the day of the week, numbered Sunday=1 through Saturday=7.

    The Python equivalent is: (mydatetime.isoweekday() % 7) + 1
    """
    lookup_name = 'week_day'
-
-
class ExtractQuarter(Extract):
    """Extract the quarter component."""
    lookup_name = 'quarter'
-
-
class ExtractHour(Extract):
    """Extract the hour component."""
    lookup_name = 'hour'
-
-
class ExtractMinute(Extract):
    """Extract the minute component."""
    lookup_name = 'minute'
-
-
class ExtractSecond(Extract):
    """Extract the second component."""
    lookup_name = 'second'
-
-
# Register the date-part transforms on DateField.
for _extract_cls in (
    ExtractYear, ExtractMonth, ExtractDay, ExtractWeekDay, ExtractWeek,
    ExtractIsoYear, ExtractQuarter,
):
    DateField.register_lookup(_extract_cls)

# Register the time-part transforms on TimeField, then on DateTimeField.
for _field_cls in (TimeField, DateTimeField):
    for _extract_cls in (ExtractHour, ExtractMinute, ExtractSecond):
        _field_cls.register_lookup(_extract_cls)

# Register the year comparison lookups on both year transforms.
for _year_transform in (ExtractYear, ExtractIsoYear):
    for _year_lookup in (YearExact, YearGt, YearGte, YearLt, YearLte):
        _year_transform.register_lookup(_year_lookup)

# Do not leave the loop variables behind as module attributes.
del _extract_cls, _field_cls, _year_transform, _year_lookup
-
-
class Now(Func):
    """Database expression rendered as ``CURRENT_TIMESTAMP``."""
    template = 'CURRENT_TIMESTAMP'
    output_field = DateTimeField()

    def as_postgresql(self, compiler, connection, **extra_context):
        """Compile for PostgreSQL using ``STATEMENT_TIMESTAMP()``."""
        # On PostgreSQL, CURRENT_TIMESTAMP is the time at which the
        # transaction started. STATEMENT_TIMESTAMP() is used instead to be
        # cross-compatible with other databases.
        statement_template = 'STATEMENT_TIMESTAMP()'
        return self.as_sql(
            compiler, connection, template=statement_template, **extra_context
        )
-
-
class TruncBase(TimezoneMixin, Transform):
    """
    Base transform that truncates a date, time, or datetime expression to
    the precision named by ``kind``.

    Subclasses set ``kind`` (and optionally ``output_field``) as class
    attributes.
    """
    kind = None
    tzinfo = None

    def __init__(self, expression, output_field=None, tzinfo=None, **extra):
        """Store the optional time zone and forward the rest to Transform."""
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(self, compiler, connection):
        """
        Return ``(sql, params)`` for the truncation, delegating to the
        backend operation that matches ``self.output_field``.

        Raise ValueError if the output field is not a date, time, or
        datetime field.
        """
        inner_sql, inner_params = compiler.compile(self.lhs)
        # DateTimeField must be tested before DateField because it is a
        # subclass of DateField.
        if isinstance(self.output_field, DateTimeField):
            tzname = self.get_tzname()
            sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
        elif isinstance(self.output_field, DateField):
            sql = connection.ops.date_trunc_sql(self.kind, inner_sql)
        elif isinstance(self.output_field, TimeField):
            sql = connection.ops.time_trunc_sql(self.kind, inner_sql)
        else:
            raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.')
        return sql, inner_params

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        """
        Resolve the expression and validate the input/output field types.

        Raise AssertionError if the input is not a date, time, or datetime
        field. Raise ValueError if the output field has an unsupported type
        or if the requested truncation does not make sense for the input
        field (e.g. truncating a DateField to an hour).
        """
        copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        # Raised explicitly rather than with an assert statement so that the
        # validation is not stripped when running under ``python -O``.
        if not isinstance(field, (DateField, TimeField)):
            raise AssertionError(
                "%r isn't a DateField, TimeField, or DateTimeField." % field.name
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
            raise ValueError('output_field must be either DateField, TimeField, or DateTimeField')
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        class_output_field = self.__class__.output_field if isinstance(self.__class__.output_field, Field) else None
        output_field = class_output_field or copy.output_field
        has_explicit_output_field = class_output_field or field.__class__ is not copy.output_field.__class__
        # The exact type is compared on purpose: isinstance() would also
        # match DateTimeField.
        if type(field) is DateField and (
                isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')):
            raise ValueError("Cannot truncate DateField '%s' to %s. " % (
                field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
            ))
        elif isinstance(field, TimeField) and (
                isinstance(output_field, DateTimeField) or
                copy.kind in ('year', 'quarter', 'month', 'week', 'day', 'date')):
            raise ValueError("Cannot truncate TimeField '%s' to %s. " % (
                field.name, output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
            ))
        return copy

    def convert_value(self, value, expression, connection):
        """
        Convert a value returned by the database to match the output field.

        For a DateTimeField output with USE_TZ enabled, the value is made
        aware using ``self.tzinfo``. A None value raises ValueError when the
        backend reports no time zone database. For other output fields, a
        datetime value is reduced to its date or time part.
        """
        if isinstance(self.output_field, DateTimeField):
            if settings.USE_TZ:
                if value is not None:
                    # Drop whatever tzinfo the backend attached before
                    # making the value aware again.
                    value = value.replace(tzinfo=None)
                    value = timezone.make_aware(value, self.tzinfo)
                elif not connection.features.has_zoneinfo_database:
                    raise ValueError(
                        'Database returned an invalid datetime value. Are time '
                        'zone definitions for your database installed?'
                    )
        elif isinstance(value, datetime):
            # value cannot be None here, so no None check is needed.
            if isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value
-
-
class Trunc(TruncBase):
    """Truncation whose ``kind`` is supplied per instance."""

    def __init__(self, expression, kind, output_field=None, tzinfo=None, **extra):
        """Record ``kind`` and pass the remaining arguments to TruncBase."""
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            **extra
        )
-
-
class TruncYear(TruncBase):
    """Truncate to year precision."""
    kind = 'year'
-
-
class TruncQuarter(TruncBase):
    """Truncate to quarter precision."""
    kind = 'quarter'
-
-
class TruncMonth(TruncBase):
    """Truncate to month precision."""
    kind = 'month'
-
-
class TruncWeek(TruncBase):
    """Truncate to week precision: midnight on the Monday of the week."""
    kind = 'week'
-
-
class TruncDay(TruncBase):
    """Truncate to day precision."""
    kind = 'day'
-
-
class TruncDate(TruncBase):
    """Cast a datetime expression to a date, honouring ``tzinfo``."""
    kind = 'date'
    lookup_name = 'date'
    output_field = DateField()

    def as_sql(self, compiler, connection):
        """Return ``(sql, params)`` casting the input expression to a date."""
        # Cast to date rather than truncate to date.
        lhs, lhs_params = compiler.compile(self.lhs)
        # Use get_tzname() so that a tzinfo passed to the constructor is
        # respected. With no tzinfo it yields the current time zone name
        # when USE_TZ is on and None otherwise, as before.
        tzname = self.get_tzname()
        sql = connection.ops.datetime_cast_date_sql(lhs, tzname)
        return sql, lhs_params
-
-
class TruncTime(TruncBase):
    """Cast a datetime expression to a time, honouring ``tzinfo``."""
    kind = 'time'
    lookup_name = 'time'
    output_field = TimeField()

    def as_sql(self, compiler, connection):
        """Return ``(sql, params)`` casting the input expression to a time."""
        # Cast to time rather than truncate to time.
        lhs, lhs_params = compiler.compile(self.lhs)
        # Use get_tzname() so that a tzinfo passed to the constructor is
        # respected. With no tzinfo it yields the current time zone name
        # when USE_TZ is on and None otherwise, as before.
        tzname = self.get_tzname()
        sql = connection.ops.datetime_cast_time_sql(lhs, tzname)
        return sql, lhs_params
-
-
class TruncHour(TruncBase):
    """Truncate to hour precision."""
    kind = 'hour'
-
-
class TruncMinute(TruncBase):
    """Truncate to minute precision."""
    kind = 'minute'
-
-
class TruncSecond(TruncBase):
    """Truncate to second precision."""
    kind = 'second'
-
-
# Register the date and time casts (lookup_name 'date' and 'time') on
# DateTimeField.
for _cast_transform in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_cast_transform)
# Do not leave the loop variable behind as a module attribute.
del _cast_transform
|