123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284 |
- from psycopg2.extras import Inet
-
- from django.conf import settings
- from django.db import NotSupportedError
- from django.db.backends.base.operations import BaseDatabaseOperations
-
-
class DatabaseOperations(BaseDatabaseOperations):
    """
    PostgreSQL implementation of the BaseDatabaseOperations hooks.

    Methods return PostgreSQL-dialect SQL fragments (or adapt Python values,
    e.g. IP addresses to psycopg2's ``Inet``) for the generic ORM code.
    """
    cast_char_field_without_max_length = 'varchar'
    explain_prefix = 'EXPLAIN'
    cast_data_types = {
        'AutoField': 'integer',
        'BigAutoField': 'bigint',
    }

    def _validate_lookup_type(self, lookup_type):
        """
        Raise ValueError unless ``lookup_type`` is a plain word.

        ``lookup_type`` is interpolated directly into the SQL string (it is
        not passed as a query parameter), so only letters, digits and
        underscores are accepted. This blocks SQL injection through
        untrusted date-part names (cf. CVE-2022-34265).
        """
        if not lookup_type or not all(ch.isalnum() or ch == '_' for ch in lookup_type):
            raise ValueError('Invalid lookup type: %r' % (lookup_type,))

    def unification_cast_sql(self, output_field):
        """
        Return a SQL template that casts a UNION member to the type of
        ``output_field`` when PostgreSQL can't infer it, else plain '%s'.
        """
        internal_type = output_field.get_internal_type()
        if internal_type in ("GenericIPAddressField", "IPAddressField", "TimeField", "UUIDField"):
            # PostgreSQL will resolve a union as type 'text' if input types are
            # 'unknown'.
            # https://www.postgresql.org/docs/current/typeconv-union-case.html
            # These fields cannot be implicitly cast back in the default
            # PostgreSQL configuration so we need to explicitly cast them.
            # We must also remove components of the type within brackets:
            # varchar(255) -> varchar.
            return 'CAST(%%s AS %s)' % output_field.db_type(self.connection).split('(')[0]
        return '%s'

    def date_extract_sql(self, lookup_type, field_name):
        """
        Return SQL extracting the ``lookup_type`` part from ``field_name``.

        'week_day' and 'iso_year' are mapped to PostgreSQL's 'dow' and
        'isoyear'. Any other ``lookup_type`` is passed through after
        validation; raise ValueError if it isn't a plain word.
        """
        # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
        if lookup_type == 'week_day':
            # For consistency across backends, we return Sunday=1, Saturday=7.
            return "EXTRACT('dow' FROM %s) + 1" % field_name
        if lookup_type == 'iso_year':
            return "EXTRACT('isoyear' FROM %s)" % field_name
        self._validate_lookup_type(lookup_type)
        return "EXTRACT('%s' FROM %s)" % (lookup_type, field_name)

    def date_trunc_sql(self, lookup_type, field_name):
        """
        Return SQL truncating ``field_name`` to ``lookup_type`` precision.

        Raise ValueError if ``lookup_type`` isn't a plain word.
        """
        # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
        self._validate_lookup_type(lookup_type)
        return "DATE_TRUNC('%s', %s)" % (lookup_type, field_name)

    def _convert_field_to_tz(self, field_name, tzname):
        """
        Wrap ``field_name`` in ``AT TIME ZONE '<tzname>'`` when
        settings.USE_TZ is enabled; otherwise return it unchanged.
        """
        # NOTE(review): tzname is interpolated verbatim inside single quotes,
        # so callers must pass a trusted time zone name, never raw user input.
        if settings.USE_TZ:
            field_name = "%s AT TIME ZONE '%s'" % (field_name, tzname)
        return field_name

    def datetime_cast_date_sql(self, field_name, tzname):
        """Return SQL casting ``field_name`` (tz-converted) to a date."""
        field_name = self._convert_field_to_tz(field_name, tzname)
        return '(%s)::date' % field_name

    def datetime_cast_time_sql(self, field_name, tzname):
        """Return SQL casting ``field_name`` (tz-converted) to a time."""
        field_name = self._convert_field_to_tz(field_name, tzname)
        return '(%s)::time' % field_name

    def datetime_extract_sql(self, lookup_type, field_name, tzname):
        """Like date_extract_sql(), after converting ``field_name`` to ``tzname``."""
        field_name = self._convert_field_to_tz(field_name, tzname)
        return self.date_extract_sql(lookup_type, field_name)

    def datetime_trunc_sql(self, lookup_type, field_name, tzname):
        """Like date_trunc_sql(), after converting ``field_name`` to ``tzname``."""
        field_name = self._convert_field_to_tz(field_name, tzname)
        # Delegate so the SQL template and lookup_type validation live in
        # one place, mirroring datetime_extract_sql().
        return self.date_trunc_sql(lookup_type, field_name)

    def time_trunc_sql(self, lookup_type, field_name):
        """
        Return SQL truncating ``field_name`` to ``lookup_type`` precision,
        cast back to time. Raise ValueError if ``lookup_type`` isn't a plain
        word.
        """
        self._validate_lookup_type(lookup_type)
        return "DATE_TRUNC('%s', %s)::time" % (lookup_type, field_name)

    def deferrable_sql(self):
        """Return the constraint suffix that defers checks to commit time."""
        return " DEFERRABLE INITIALLY DEFERRED"

    def fetch_returned_insert_ids(self, cursor):
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table that has an auto-incrementing ID, return the
        list of newly created IDs.
        """
        return [item[0] for item in cursor.fetchall()]

    def lookup_cast(self, lookup_type, internal_type=None):
        """
        Return the SQL template (containing one '%s') applied to the
        left-hand side of a ``lookup_type`` lookup on an ``internal_type``
        field.
        """
        lookup = '%s'

        # Cast text lookups to text to allow things like filter(x__contains=4)
        if lookup_type in ('iexact', 'contains', 'icontains', 'startswith',
                           'istartswith', 'endswith', 'iendswith', 'regex', 'iregex'):
            if internal_type in ('IPAddressField', 'GenericIPAddressField'):
                lookup = "HOST(%s)"
            elif internal_type in ('CICharField', 'CIEmailField', 'CITextField'):
                lookup = '%s::citext'
            else:
                lookup = "%s::text"

        # Use UPPER(x) for case-insensitive lookups; it's faster.
        if lookup_type in ('iexact', 'icontains', 'istartswith', 'iendswith'):
            lookup = 'UPPER(%s)' % lookup

        return lookup

    def no_limit_value(self):
        """Return None: no placeholder LIMIT is emitted for unlimited queries."""
        return None

    def prepare_sql_script(self, sql):
        """Return ``sql`` unsplit, as a single-statement list."""
        return [sql]

    def quote_name(self, name):
        """Double-quote ``name`` unless it is already wrapped in double quotes."""
        if name.startswith('"') and name.endswith('"'):
            return name  # Quoting once is enough.
        return '"%s"' % name

    def set_time_zone_sql(self):
        """Return the parameterized statement that sets the session time zone."""
        return "SET TIME ZONE %s"

    def sql_flush(self, style, tables, sequences, allow_cascade=False):
        """
        Return SQL statements that empty ``tables`` (optionally with
        CASCADE) and then reset ``sequences``. Return [] if ``tables`` is
        empty.
        """
        if not tables:
            return []
        # Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows
        # us to truncate tables referenced by a foreign key in any other
        # table.
        tables_sql = ', '.join(
            style.SQL_FIELD(self.quote_name(table)) for table in tables)
        truncate_parts = [style.SQL_KEYWORD('TRUNCATE'), tables_sql]
        if allow_cascade:
            truncate_parts.append(style.SQL_KEYWORD('CASCADE'))
        sql = ['%s;' % ' '.join(truncate_parts)]
        sql.extend(self.sequence_reset_by_name_sql(style, sequences))
        return sql

    def sequence_reset_by_name_sql(self, style, sequences):
        """
        Return one ``SELECT setval(..., 1, false);`` statement per entry in
        ``sequences`` (dicts with 'table' and 'column' keys), so each
        sequence starts again at 1.
        """
        # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
        # to reset sequence indices
        sql = []
        for sequence_info in sequences:
            table_name = sequence_info['table']
            # 'id' will be the case if it's an m2m using an autogenerated
            # intermediate table (see BaseDatabaseIntrospection.sequence_list).
            column_name = sequence_info['column'] or 'id'
            sql.append("%s setval(pg_get_serial_sequence('%s','%s'), 1, false);" % (
                style.SQL_KEYWORD('SELECT'),
                style.SQL_TABLE(self.quote_name(table_name)),
                style.SQL_FIELD(column_name),
            ))
        return sql

    def tablespace_sql(self, tablespace, inline=False):
        """Return the TABLESPACE clause; ``inline`` selects the index form."""
        if inline:
            return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace)
        else:
            return "TABLESPACE %s" % self.quote_name(tablespace)

    def _sequence_reset_to_max_sql(self, style, table, column):
        """
        Return a statement that points the serial sequence of
        ``table``.``column`` at the column's current max value.
        """
        # Use `coalesce` to set the sequence to the max pk value if there are
        # records, or 1 if there are none. Set the `is_called` property (the
        # third argument to `setval`) to true if there are records (as the
        # max pk value is already in use), otherwise set it to false.
        # pg_get_serial_sequence finds the underlying sequence name from the
        # table name and column name.
        qn = self.quote_name
        return (
            "%s setval(pg_get_serial_sequence('%s','%s'), "
            "coalesce(max(%s), 1), max(%s) %s null) %s %s;" % (
                style.SQL_KEYWORD('SELECT'),
                style.SQL_TABLE(qn(table)),
                style.SQL_FIELD(column),
                style.SQL_FIELD(qn(column)),
                style.SQL_FIELD(qn(column)),
                style.SQL_KEYWORD('IS NOT'),
                style.SQL_KEYWORD('FROM'),
                style.SQL_TABLE(qn(table)),
            )
        )

    def sequence_reset_sql(self, style, model_list):
        """
        Return statements syncing the sequences of each model's AutoField
        (and of m2m tables without an explicit ``through``) with the
        current max value in the table.
        """
        from django.db import models
        output = []
        for model in model_list:
            for f in model._meta.local_fields:
                if isinstance(f, models.AutoField):
                    output.append(self._sequence_reset_to_max_sql(
                        style, model._meta.db_table, f.column,
                    ))
                    break  # Only one AutoField is allowed per model, so don't bother continuing.
            for f in model._meta.many_to_many:
                if not f.remote_field.through:
                    output.append(self._sequence_reset_to_max_sql(
                        style, f.m2m_db_table(), 'id',
                    ))
        return output

    def prep_for_iexact_query(self, x):
        """Return ``x`` unchanged."""
        return x

    def max_name_length(self):
        """
        Return the maximum length of an identifier.

        The maximum length of an identifier is 63 by default, but can be
        changed by recompiling PostgreSQL after editing the NAMEDATALEN
        macro in src/include/pg_config_manual.h.

        This implementation returns 63, but can be overridden by a custom
        database backend that inherits most of its behavior from this one.
        """
        return 63

    def distinct_sql(self, fields, params):
        """
        Return ``(sql_parts, params)``: ``DISTINCT ON (fields)`` with the
        flattened ``params`` when ``fields`` is given, else plain DISTINCT.
        """
        if fields:
            params = [param for param_list in params for param in param_list]
            return (['DISTINCT ON (%s)' % ', '.join(fields)], params)
        else:
            return ['DISTINCT'], []

    def last_executed_query(self, cursor, sql, params):
        """Return the query text psycopg2 actually sent, or None."""
        # http://initd.org/psycopg/docs/cursor.html#cursor.query
        # The query attribute is a Psycopg extension to the DB API 2.0.
        if cursor.query is not None:
            return cursor.query.decode()
        return None

    def return_insert_id(self):
        """Return the RETURNING clause template and its (empty) params."""
        return "RETURNING %s", ()

    def bulk_insert_sql(self, fields, placeholder_rows):
        """Return a multi-row ``VALUES (...), (...)`` clause."""
        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
        values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
        return "VALUES " + values_sql

    def adapt_datefield_value(self, value):
        """Return ``value`` unchanged; the driver adapts dates itself."""
        return value

    def adapt_datetimefield_value(self, value):
        """Return ``value`` unchanged; the driver adapts datetimes itself."""
        return value

    def adapt_timefield_value(self, value):
        """Return ``value`` unchanged; the driver adapts times itself."""
        return value

    def adapt_ipaddressfield_value(self, value):
        """Wrap a truthy ``value`` in psycopg2's Inet; map falsy values to None."""
        if value:
            return Inet(value)
        return None

    def subtract_temporals(self, internal_type, lhs, rhs):
        """
        Return ``(sql, params)`` for ``lhs - rhs``. For DateField, the
        integer day difference is scaled by ``interval '1 day'`` so the
        result is an interval; other types defer to the base class.
        """
        if internal_type == 'DateField':
            lhs_sql, lhs_params = lhs
            rhs_sql, rhs_params = rhs
            return "(interval '1 day' * (%s - %s))" % (lhs_sql, rhs_sql), lhs_params + rhs_params
        return super().subtract_temporals(internal_type, lhs, rhs)

    def window_frame_range_start_end(self, start=None, end=None):
        """
        Return the RANGE frame bounds from the base class. Raise
        NotSupportedError for a negative ``start`` or positive ``end``
        (a numeric offset).
        """
        start_, end_ = super().window_frame_range_start_end(start, end)
        if (start and start < 0) or (end and end > 0):
            raise NotSupportedError(
                'PostgreSQL only supports UNBOUNDED together with PRECEDING '
                'and FOLLOWING.'
            )
        return start_, end_

    def explain_query_prefix(self, format=None, **options):
        """
        Return the EXPLAIN prefix with a parenthesized option list: FORMAT
        (if given) plus each extra option upper-cased as 'true'/'false'.
        """
        prefix = super().explain_query_prefix(format)
        extra = {}
        if format:
            extra['FORMAT'] = format
        if options:
            extra.update({
                name.upper(): 'true' if value else 'false'
                for name, value in options.items()
            })
        if extra:
            prefix += ' (%s)' % ', '.join('%s %s' % i for i in extra.items())
        return prefix

    def ignore_conflicts_suffix_sql(self, ignore_conflicts=None):
        """Return 'ON CONFLICT DO NOTHING' when ignoring conflicts, else the base suffix."""
        if ignore_conflicts:
            return 'ON CONFLICT DO NOTHING'
        return super().ignore_conflicts_suffix_sql(ignore_conflicts)
|