|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
- from django.db.backends.utils import names_digest, split_identifier
- from django.db.models.expressions import Col, ExpressionList, F, Func, OrderBy
- from django.db.models.functions import Collate
- from django.db.models.query_utils import Q
- from django.db.models.sql import Query
- from django.utils.functional import partition
-
- __all__ = ["Index"]
-
-
class Index:
    """
    Describe a database index declared on a model.

    An index is defined either by field names (``fields``) or by positional
    ``expressions``; the two are mutually exclusive. A field name prefixed
    with "-" is recorded with a "DESC" ordering.
    """

    suffix = "idx"
    # Longest allowed index name; kept at 30 characters for cross-database
    # compatibility with Oracle.
    max_name_length = 30

    def __init__(
        self,
        *expressions,
        fields=(),
        name=None,
        db_tablespace=None,
        opclasses=(),
        condition=None,
        include=None,
    ):
        """
        Validate the arguments and store them on the instance.

        Raise ValueError on the first failed check. The checks run in the
        order written below, so that order decides which message is reported
        when several arguments are invalid at once.
        """
        if opclasses and not name:
            raise ValueError("An index must be named to use opclasses.")
        if condition is not None and not isinstance(condition, Q):
            raise ValueError("Index.condition must be a Q instance.")
        if condition and not name:
            raise ValueError("An index must be named to use condition.")
        if not isinstance(fields, (list, tuple)):
            raise ValueError("Index.fields must be a list or tuple.")
        if not isinstance(opclasses, (list, tuple)):
            raise ValueError("Index.opclasses must be a list or tuple.")
        if not (expressions or fields):
            raise ValueError(
                "At least one field or expression is required to define an index."
            )
        if expressions and fields:
            raise ValueError(
                "Index.fields and expressions are mutually exclusive.",
            )
        if expressions and not name:
            raise ValueError("An index must be named to use expressions.")
        if expressions and opclasses:
            raise ValueError(
                "Index.opclasses cannot be used with expressions. Use "
                "django.contrib.postgres.indexes.OpClass() instead."
            )
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "Index.fields and Index.opclasses must have the same number of "
                "elements."
            )
        if fields and any(not isinstance(entry, str) for entry in fields):
            raise ValueError("Index.fields must contain only strings with field names.")
        if include and not name:
            raise ValueError("A covering index must be named.")
        if include is not None and not isinstance(include, (list, tuple)):
            raise ValueError("Index.include must be a list or tuple.")

        self.fields = list(fields)
        # Pairs of (field name without the "-" prefix, "DESC" or "").
        orders = []
        for raw_name in self.fields:
            if raw_name.startswith("-"):
                orders.append((raw_name[1:], "DESC"))
            else:
                orders.append((raw_name, ""))
        self.fields_orders = orders
        self.name = name or ""
        self.db_tablespace = db_tablespace
        self.opclasses = opclasses
        self.condition = condition
        self.include = tuple(include) if include else ()
        # Plain strings given as expressions are treated as field references.
        self.expressions = tuple(
            F(item) if isinstance(item, str) else item for item in expressions
        )

    @property
    def contains_expressions(self):
        """Return True when the index is defined by expressions."""
        return True if self.expressions else False

    def _get_condition_sql(self, model, schema_editor):
        """
        Return the condition rendered as SQL with its parameters quoted
        inline, or None when the index has no condition.
        """
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        connection = schema_editor.connection
        compiler = query.get_compiler(connection=connection)
        sql, params = where.as_sql(compiler, connection)
        # Parameters are interpolated into the statement after quoting.
        quoted_params = tuple(schema_editor.quote_value(param) for param in params)
        return sql % quoted_params

    def create_sql(self, model, schema_editor, using="", **kwargs):
        """
        Return whatever ``schema_editor._create_index_sql()`` produces for
        this index on ``model``.

        Extra keyword arguments are forwarded to the schema editor untouched.
        """
        include = [model._meta.get_field(name).column for name in self.include]
        condition = self._get_condition_sql(model, schema_editor)
        connection = schema_editor.connection
        # Exactly one of ``expressions`` or (``fields``, ``col_suffixes``) is
        # filled in below; the other side stays None.
        fields = col_suffixes = expressions = None
        if self.expressions:
            wrapped = []
            for expression in self.expressions:
                wrapped_expression = IndexExpression(expression)
                wrapped_expression.set_wrapper_classes(connection)
                wrapped.append(wrapped_expression)
            expressions = ExpressionList(*wrapped).resolve_expression(
                Query(model, alias_cols=False),
            )
        else:
            fields = [
                model._meta.get_field(field_name)
                for field_name, _ in self.fields_orders
            ]
            if connection.features.supports_index_column_ordering:
                col_suffixes = [entry[1] for entry in self.fields_orders]
            else:
                # Backends without column ordering get empty suffixes.
                col_suffixes = [""] * len(self.fields_orders)
        return schema_editor._create_index_sql(
            model,
            fields=fields,
            name=self.name,
            using=using,
            db_tablespace=self.db_tablespace,
            col_suffixes=col_suffixes,
            opclasses=self.opclasses,
            condition=condition,
            include=include,
            expressions=expressions,
            **kwargs,
        )

    def remove_sql(self, model, schema_editor, **kwargs):
        """Return the schema editor's statement for dropping this index."""
        statement = schema_editor._delete_index_sql(model, self.name, **kwargs)
        return statement

    def deconstruct(self):
        """
        Return a ``(path, args, kwargs)`` triple from which an equal index
        can be rebuilt.

        ``name`` is always present in kwargs; the other options are included
        only when set.
        """
        klass = self.__class__
        path = "%s.%s" % (klass.__module__, klass.__name__)
        path = path.replace("django.db.models.indexes", "django.db.models")
        kwargs = {"name": self.name}
        # (key, value, whether to emit it), listed in output order.
        optional = (
            ("fields", self.fields, bool(self.fields)),
            ("db_tablespace", self.db_tablespace, self.db_tablespace is not None),
            ("opclasses", self.opclasses, bool(self.opclasses)),
            ("condition", self.condition, bool(self.condition)),
            ("include", self.include, bool(self.include)),
        )
        for key, value, wanted in optional:
            if wanted:
                kwargs[key] = value
        return (path, self.expressions, kwargs)

    def clone(self):
        """Create a copy of this Index."""
        _path, positional, keywords = self.deconstruct()
        return self.__class__(*positional, **keywords)

    def set_name_with_model(self, model):
        """
        Generate a unique name for the index.

        The name is divided into 3 parts - table name (12 chars), field name
        (8 chars) and unique hash + suffix (10 chars). Each part is made to
        fit its size by truncating the excess length.

        Raise ValueError if the resulting name exceeds ``max_name_length``.
        """
        _, table_name = split_identifier(model._meta.db_table)
        column_names = []
        ordered_columns = []
        for field_name, order in self.fields_orders:
            column = model._meta.get_field(field_name).column
            column_names.append(column)
            # Descending columns are hashed with a "-" prefix.
            ordered_columns.append(("-%s" if order else "%s") % column)
        # The length of the parts of the name is based on the default max
        # length of 30 characters.
        table_part = table_name[:11]
        column_part = column_names[0][:7]
        digest = names_digest(table_name, *ordered_columns, self.suffix, length=6)
        hashed_suffix = "%s_%s" % (digest, self.suffix)
        self.name = "%s_%s_%s" % (table_part, column_part, hashed_suffix)
        if len(self.name) > self.max_name_length:
            raise ValueError(
                "Index too long for multiple database support. Is self.suffix "
                "longer than 3 characters?"
            )
        # A leading underscore or digit is replaced with "D".
        first_char = self.name[0]
        if first_char == "_" or first_char.isdigit():
            self.name = "D%s" % self.name[1:]

    def __repr__(self):
        """Return a debug representation listing only the options set."""
        fields_part = " fields=%s" % repr(self.fields) if self.fields else ""
        if self.expressions:
            expressions_part = " expressions=%s" % repr(self.expressions)
        else:
            expressions_part = ""
        name_part = " name=%s" % repr(self.name) if self.name else ""
        if self.db_tablespace is None:
            tablespace_part = ""
        else:
            tablespace_part = " db_tablespace=%s" % repr(self.db_tablespace)
        if self.condition is None:
            condition_part = ""
        else:
            condition_part = " condition=%s" % self.condition
        include_part = " include=%s" % repr(self.include) if self.include else ""
        if self.opclasses:
            opclasses_part = " opclasses=%s" % repr(self.opclasses)
        else:
            opclasses_part = ""
        return "<%s:%s%s%s%s%s%s%s>" % (
            self.__class__.__qualname__,
            fields_part,
            expressions_part,
            name_part,
            tablespace_part,
            condition_part,
            include_part,
            opclasses_part,
        )

    def __eq__(self, other):
        """Compare by deconstructed form; defer for objects of another class."""
        if not (self.__class__ == other.__class__):
            return NotImplemented
        return self.deconstruct() == other.deconstruct()
-
-
class IndexExpression(Func):
    """Order and wrap expressions for CREATE INDEX statements."""

    template = "%(expressions)s"
    # Expression types treated as wrappers around the indexed expression;
    # their position here is the nesting order used when they are re-chained.
    wrapper_classes = (OrderBy, Collate)

    def set_wrapper_classes(self, connection=None):
        """Drop Collate from this instance's wrappers where the backend
        indexes COLLATE as part of the expression."""
        if not connection:
            return
        # Some databases (e.g. MySQL) treat COLLATE as an indexed expression.
        if not connection.features.collate_as_index_expression:
            return
        self.wrapper_classes = tuple(
            candidate for candidate in self.wrapper_classes if candidate is not Collate
        )

    @classmethod
    def register_wrappers(cls, *wrapper_classes):
        """Replace the class-level tuple of wrapper classes."""
        cls.wrapper_classes = wrapper_classes

    def resolve_expression(
        self,
        query=None,
        allow_joins=True,
        reuse=None,
        summarize=False,
        for_save=False,
    ):
        """
        Rebuild the source expressions as ordered wrappers around the root
        expression, then resolve through the parent class.

        Raise ValueError when a wrapper type occurs more than once, or when
        the wrappers are not the topmost expressions.
        """
        flattened = list(self.flatten())
        # Split expressions and wrappers.
        index_expressions, wrappers = partition(
            lambda node: isinstance(node, self.wrapper_classes),
            flattened,
        )
        distinct_types = set(map(type, wrappers))
        if len(wrappers) != len(distinct_types):
            raise ValueError(
                "Multiple references to %s can't be used in an indexed "
                "expression."
                % ", ".join([klass.__qualname__ for klass in self.wrapper_classes])
            )
        # Position 0 of the flattened list is skipped; the wrappers must
        # occupy the positions directly after it.
        if flattened[1 : len(wrappers) + 1] != wrappers:
            raise ValueError(
                "%s must be topmost expressions in an indexed expression."
                % ", ".join([klass.__qualname__ for klass in self.wrapper_classes])
            )
        # Wrap expressions in parentheses if they are not column references.
        root_expression = index_expressions[1]
        resolved_root = root_expression.resolve_expression(
            query,
            allow_joins,
            reuse,
            summarize,
            for_save,
        )
        if not isinstance(resolved_root, Col):
            root_expression = Func(root_expression, template="(%(expressions)s)")

        top_expression = root_expression
        if wrappers:
            # Order wrappers by their position in wrapper_classes and work
            # on copies so the originals are not modified.
            ordered = [
                wrapper.copy()
                for wrapper in sorted(
                    wrappers,
                    key=lambda w: self.wrapper_classes.index(type(w)),
                )
            ]
            # Nest each wrapper around the one that follows it.
            for outer, inner in zip(ordered, ordered[1:]):
                outer.set_source_expressions([inner])
            # Set the root expression on the deepest wrapper.
            ordered[-1].set_source_expressions([root_expression])
            top_expression = ordered[0]
        # Without wrappers the root expression is used directly.
        self.set_source_expressions([top_expression])
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        """Compile through the plain ``as_sql()`` path."""
        # Casting to numeric is unnecessary.
        return self.as_sql(compiler, connection, **extra_context)
|