|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333 |
- import datetime
- import decimal
- import uuid
- from functools import lru_cache
- from itertools import chain
-
- from django.conf import settings
- from django.core.exceptions import FieldError
- from django.db import utils
- from django.db.backends.base.operations import BaseDatabaseOperations
- from django.db.models import aggregates, fields
- from django.db.models.expressions import Col
- from django.utils import timezone
- from django.utils.dateparse import parse_date, parse_datetime, parse_time
- from django.utils.duration import duration_microseconds
- from django.utils.functional import cached_property
-
-
class DatabaseOperations(BaseDatabaseOperations):
    """SQLite implementation of the backend operations hooks."""

    cast_char_field_without_max_length = 'text'
    cast_data_types = {
        'DateField': 'TEXT',
        'DateTimeField': 'TEXT',
    }
    explain_prefix = 'EXPLAIN QUERY PLAN'

    def bulk_batch_size(self, fields, objs):
        """
        SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
        999 variables per query.

        If there's only a single field to insert, the limit is 500
        (SQLITE_MAX_COMPOUND_SELECT).
        """
        if len(fields) == 1:
            return 500
        elif len(fields) > 1:
            return self.connection.features.max_query_params // len(fields)
        else:
            # No fields at all: everything fits in a single batch.
            return len(objs)

    def check_expression_support(self, expression):
        """
        Raise NotSupportedError for expressions SQLite can't evaluate:
        Sum/Avg/Variance/StdDev over date/time fields, and DISTINCT
        aggregates that take more than one argument.
        """
        bad_fields = (fields.DateField, fields.DateTimeField, fields.TimeField)
        bad_aggregates = (aggregates.Sum, aggregates.Avg, aggregates.Variance, aggregates.StdDev)
        if isinstance(expression, bad_aggregates):
            for expr in expression.get_source_expressions():
                try:
                    output_field = expr.output_field
                except FieldError:
                    # Not every subexpression has an output_field which is fine
                    # to ignore.
                    pass
                else:
                    if isinstance(output_field, bad_fields):
                        raise utils.NotSupportedError(
                            'You cannot use Sum, Avg, StdDev, and Variance '
                            'aggregations on date/time fields in sqlite3 '
                            'since date/time is saved as text.'
                        )
        # Only the DISTINCT form is unsupported with several arguments; a
        # plain multi-argument aggregate must be let through.
        if (
            isinstance(expression, aggregates.Aggregate) and
            expression.distinct and
            len(expression.source_expressions) > 1
        ):
            raise utils.NotSupportedError(
                "SQLite doesn't support DISTINCT on aggregate functions "
                "accepting multiple arguments."
            )

    def date_extract_sql(self, lookup_type, field_name):
        """
        Support EXTRACT with a user-defined function django_date_extract()
        that's registered in connect(). Use single quotes because this is a
        string and could otherwise cause a collision with a field name.
        """
        return "django_date_extract('%s', %s)" % (lookup_type.lower(), field_name)

    def date_interval_sql(self, timedelta):
        """Return the timedelta as a string of its microsecond count."""
        return str(duration_microseconds(timedelta))

    def format_for_duration_arithmetic(self, sql):
        """Do nothing since formatting is handled in the custom function."""
        return sql

    def date_trunc_sql(self, lookup_type, field_name):
        """Return SQL calling django_date_trunc() with the lowercased lookup type."""
        return "django_date_trunc('%s', %s)" % (lookup_type.lower(), field_name)

    def time_trunc_sql(self, lookup_type, field_name):
        """Return SQL calling django_time_trunc() with the lowercased lookup type."""
        return "django_time_trunc('%s', %s)" % (lookup_type.lower(), field_name)

    def _convert_tzname_to_sql(self, tzname):
        """Return tzname as a quoted SQL literal, or NULL when USE_TZ is off."""
        return "'%s'" % tzname if settings.USE_TZ else 'NULL'

    def datetime_cast_date_sql(self, field_name, tzname):
        """Return SQL calling django_datetime_cast_date() for the given timezone."""
        return "django_datetime_cast_date(%s, %s)" % (
            field_name, self._convert_tzname_to_sql(tzname),
        )

    def datetime_cast_time_sql(self, field_name, tzname):
        """Return SQL calling django_datetime_cast_time() for the given timezone."""
        return "django_datetime_cast_time(%s, %s)" % (
            field_name, self._convert_tzname_to_sql(tzname),
        )

    def datetime_extract_sql(self, lookup_type, field_name, tzname):
        """Return SQL calling django_datetime_extract() for the given timezone."""
        return "django_datetime_extract('%s', %s, %s)" % (
            lookup_type.lower(), field_name, self._convert_tzname_to_sql(tzname),
        )

    def datetime_trunc_sql(self, lookup_type, field_name, tzname):
        """Return SQL calling django_datetime_trunc() for the given timezone."""
        return "django_datetime_trunc('%s', %s, %s)" % (
            lookup_type.lower(), field_name, self._convert_tzname_to_sql(tzname),
        )

    def time_extract_sql(self, lookup_type, field_name):
        """Return SQL calling django_time_extract() with the lowercased lookup type."""
        return "django_time_extract('%s', %s)" % (lookup_type.lower(), field_name)

    def pk_default_value(self):
        """Return the SQL keyword used as the default primary key value."""
        return "NULL"

    def _quote_params_for_last_executed_query(self, params):
        """
        Only for last_executed_query! Don't use this to execute SQL queries!

        Return a tuple holding the SQL-quoted form of each value in params,
        obtained by running them through SQLite's QUOTE() function.
        """
        # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
        # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
        # number of return values, default = 2000). Since Python's sqlite3
        # module doesn't expose the get_limit() C API, assume the default
        # limits are in effect and split the work in batches if needed.
        BATCH_SIZE = 999
        if len(params) > BATCH_SIZE:
            results = ()
            for index in range(0, len(params), BATCH_SIZE):
                chunk = params[index:index + BATCH_SIZE]
                results += self._quote_params_for_last_executed_query(chunk)
            return results

        sql = 'SELECT ' + ', '.join(['QUOTE(?)'] * len(params))
        # Bypass Django's wrappers and use the underlying sqlite3 connection
        # to avoid logging this query - it would trigger infinite recursion.
        cursor = self.connection.connection.cursor()
        # Native sqlite3 cursors cannot be used as context managers.
        try:
            return cursor.execute(sql, params).fetchone()
        finally:
            cursor.close()

    def last_executed_query(self, cursor, sql, params):
        """
        Return sql with params quoted and interpolated into it. params may be
        a list/tuple (positional) or a mapping (named); sql is returned
        unchanged when params is empty.
        """
        # Python substitutes parameters in Modules/_sqlite/cursor.c with:
        # pysqlite_statement_bind_parameters(self->statement, parameters, allow_8bit_chars);
        # Unfortunately there is no way to reach self->statement from Python,
        # so we quote and substitute parameters manually.
        if params:
            if isinstance(params, (list, tuple)):
                params = self._quote_params_for_last_executed_query(params)
            else:
                # Quote the mapping's values, then pair them back up with
                # their keys in the same iteration order.
                values = tuple(params.values())
                values = self._quote_params_for_last_executed_query(values)
                params = dict(zip(params, values))
            return sql % params
        # For consistency with SQLiteCursorWrapper.execute(), just return sql
        # when there are no parameters. See #13648 and #17158.
        else:
            return sql

    def quote_name(self, name):
        """Wrap name in double quotes unless it is already quoted."""
        if name.startswith('"') and name.endswith('"'):
            return name  # Quoting once is enough.
        return '"%s"' % name

    def no_limit_value(self):
        """Return the LIMIT value that stands for "no limit"."""
        return -1

    def __references_graph(self, table_name):
        """
        Return the names produced by the recursive query below: table_name
        itself plus every table in sqlite_master whose SQL matches a
        REFERENCES clause naming a table already collected.
        """
        query = """
        WITH tables AS (
            SELECT %s name
            UNION
            SELECT sqlite_master.name
            FROM sqlite_master
            JOIN tables ON (sql REGEXP %s || tables.name || %s)
        ) SELECT name FROM tables;
        """
        params = (
            table_name,
            r'(?i)\s+references\s+("|\')?',
            r'("|\')?\s*\(',
        )
        with self.connection.cursor() as cursor:
            results = cursor.execute(query, params)
            return [row[0] for row in results.fetchall()]

    @cached_property
    def _references_graph(self):
        """Return an LRU-cached wrapper around __references_graph()."""
        # 512 is large enough to fit the ~330 tables (as of this writing) in
        # Django's test suite.
        return lru_cache(maxsize=512)(self.__references_graph)

    def sql_flush(self, style, tables, sequences, allow_cascade=False):
        """
        Return a list of DELETE statements, one per table. With allow_cascade,
        the table set is first expanded via the references graph.
        """
        if tables and allow_cascade:
            # Simulate TRUNCATE CASCADE by recursively collecting the tables
            # referencing the tables to be flushed.
            tables = set(chain.from_iterable(self._references_graph(table) for table in tables))
        sql = ['%s %s %s;' % (
            style.SQL_KEYWORD('DELETE'),
            style.SQL_KEYWORD('FROM'),
            style.SQL_FIELD(self.quote_name(table))
        ) for table in tables]
        # Note: No requirement for reset of auto-incremented indices (cf. other
        # sql_flush() implementations). Just return SQL at this point
        return sql

    def adapt_datetimefield_value(self, value):
        """
        Return value as a string for storage. None and expressions pass
        through unchanged; aware datetimes are made naive in the connection
        timezone when USE_TZ is on and raise ValueError otherwise.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, 'resolve_expression'):
            return value

        # SQLite doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            if settings.USE_TZ:
                value = timezone.make_naive(value, self.connection.timezone)
            else:
                raise ValueError("SQLite backend does not support timezone-aware datetimes when USE_TZ is False.")

        return str(value)

    def adapt_timefield_value(self, value):
        """
        Return value as a string for storage. None and expressions pass
        through unchanged; aware times raise ValueError.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, 'resolve_expression'):
            return value

        # SQLite doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            raise ValueError("SQLite backend does not support timezone-aware times.")

        return str(value)

    def get_db_converters(self, expression):
        """
        Extend the base converters with the one matching the internal type of
        the expression's output field, if any.
        """
        converters = super().get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type == 'DateTimeField':
            converters.append(self.convert_datetimefield_value)
        elif internal_type == 'DateField':
            converters.append(self.convert_datefield_value)
        elif internal_type == 'TimeField':
            converters.append(self.convert_timefield_value)
        elif internal_type == 'DecimalField':
            converters.append(self.get_decimalfield_converter(expression))
        elif internal_type == 'UUIDField':
            converters.append(self.convert_uuidfield_value)
        elif internal_type in ('NullBooleanField', 'BooleanField'):
            converters.append(self.convert_booleanfield_value)
        return converters

    def convert_datetimefield_value(self, value, expression, connection):
        """
        Parse a non-datetime value into a datetime and, when USE_TZ is on,
        make a naive result aware in the connection timezone.
        """
        if value is not None:
            if not isinstance(value, datetime.datetime):
                value = parse_datetime(value)
            if settings.USE_TZ and not timezone.is_aware(value):
                value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_datefield_value(self, value, expression, connection):
        """Parse value into a date unless it is None or already a date."""
        if value is not None:
            if not isinstance(value, datetime.date):
                value = parse_date(value)
        return value

    def convert_timefield_value(self, value, expression, connection):
        """Parse value into a time unless it is None or already a time."""
        if value is not None:
            if not isinstance(value, datetime.time):
                value = parse_time(value)
        return value

    def get_decimalfield_converter(self, expression):
        """
        Return a converter turning float values into Decimals. For column
        references the result is also quantized to the field's decimal places.
        """
        # SQLite stores only 15 significant digits. Digits coming from
        # float inaccuracy must be removed.
        create_decimal = decimal.Context(prec=15).create_decimal_from_float
        if isinstance(expression, Col):
            quantize_value = decimal.Decimal(1).scaleb(-expression.output_field.decimal_places)

            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value).quantize(quantize_value, context=expression.output_field.context)
        else:
            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value)
        return converter

    def convert_uuidfield_value(self, value, expression, connection):
        """Build a uuid.UUID from value unless it is None."""
        if value is not None:
            value = uuid.UUID(value)
        return value

    def convert_booleanfield_value(self, value, expression, connection):
        """Map 1/0 to True/False; return any other value unchanged."""
        return bool(value) if value in (1, 0) else value

    def bulk_insert_sql(self, fields, placeholder_rows):
        """Return one SELECT per placeholder row, joined with UNION ALL."""
        return " UNION ALL ".join(
            "SELECT %s" % ", ".join(row)
            for row in placeholder_rows
        )

    def combine_expression(self, connector, sub_expressions):
        """Combine sub-expressions with connector, mapping '^' to POWER()."""
        # SQLite doesn't have a ^ operator, so use the user-defined POWER
        # function that's registered in connect().
        if connector == '^':
            return 'POWER(%s)' % ','.join(sub_expressions)
        return super().combine_expression(connector, sub_expressions)

    def combine_duration_expression(self, connector, sub_expressions):
        """
        Return SQL calling django_format_dtdelta() with the connector and the
        sub-expressions.

        Raise DatabaseError for connectors other than '+' and '-', and
        ValueError when more than two sub-expressions are given.
        """
        if connector not in ['+', '-']:
            raise utils.DatabaseError('Invalid connector for timedelta: %s.' % connector)
        # Unpacking accepts any iterable of sub-expressions, not just a list.
        fn_params = ["'%s'" % connector, *sub_expressions]
        if len(fn_params) > 3:
            raise ValueError('Too many params for timedelta operations.')
        return "django_format_dtdelta(%s)" % ', '.join(fn_params)

    def integer_field_range(self, internal_type):
        """Return (None, None): no bounds apply to any integer field type."""
        # SQLite doesn't enforce any integer constraints
        return (None, None)

    def subtract_temporals(self, internal_type, lhs, rhs):
        """
        Return an (sql, params) pair for lhs - rhs, using django_time_diff()
        for TimeField and django_timestamp_diff() for everything else.
        """
        lhs_sql, lhs_params = lhs
        rhs_sql, rhs_params = rhs
        if internal_type == 'TimeField':
            return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), lhs_params + rhs_params
        return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), lhs_params + rhs_params

    def insert_statement(self, ignore_conflicts=False):
        """Return the INSERT prefix, using INSERT OR IGNORE for ignore_conflicts."""
        return 'INSERT OR IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)
|