You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

fields.py 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. from django.core.exceptions import FieldDoesNotExist
  2. from django.db.models.fields import NOT_PROVIDED
  3. from django.utils.functional import cached_property
  4. from .base import Operation
  5. from .utils import (
  6. ModelTuple, field_references_model, is_referenced_by_foreign_key,
  7. )
  8. class FieldOperation(Operation):
  9. def __init__(self, model_name, name, field=None):
  10. self.model_name = model_name
  11. self.name = name
  12. self.field = field
  13. @cached_property
  14. def model_name_lower(self):
  15. return self.model_name.lower()
  16. @cached_property
  17. def name_lower(self):
  18. return self.name.lower()
  19. def is_same_model_operation(self, operation):
  20. return self.model_name_lower == operation.model_name_lower
  21. def is_same_field_operation(self, operation):
  22. return self.is_same_model_operation(operation) and self.name_lower == operation.name_lower
  23. def references_model(self, name, app_label=None):
  24. name_lower = name.lower()
  25. if name_lower == self.model_name_lower:
  26. return True
  27. if self.field:
  28. return field_references_model(self.field, ModelTuple(app_label, name_lower))
  29. return False
  30. def references_field(self, model_name, name, app_label=None):
  31. model_name_lower = model_name.lower()
  32. # Check if this operation locally references the field.
  33. if model_name_lower == self.model_name_lower:
  34. if name == self.name:
  35. return True
  36. elif self.field and hasattr(self.field, 'from_fields') and name in self.field.from_fields:
  37. return True
  38. # Check if this operation remotely references the field.
  39. if self.field:
  40. model_tuple = ModelTuple(app_label, model_name_lower)
  41. remote_field = self.field.remote_field
  42. if remote_field:
  43. if (ModelTuple.from_model(remote_field.model) == model_tuple and
  44. (not hasattr(self.field, 'to_fields') or
  45. name in self.field.to_fields or None in self.field.to_fields)):
  46. return True
  47. through = getattr(remote_field, 'through', None)
  48. if (through and ModelTuple.from_model(through) == model_tuple and
  49. (getattr(remote_field, 'through_fields', None) is None or
  50. name in remote_field.through_fields)):
  51. return True
  52. return False
  53. def reduce(self, operation, app_label=None):
  54. return (
  55. super().reduce(operation, app_label=app_label) or
  56. not operation.references_field(self.model_name, self.name, app_label)
  57. )
  58. class AddField(FieldOperation):
  59. """Add a field to a model."""
  60. def __init__(self, model_name, name, field, preserve_default=True):
  61. self.preserve_default = preserve_default
  62. super().__init__(model_name, name, field)
  63. def deconstruct(self):
  64. kwargs = {
  65. 'model_name': self.model_name,
  66. 'name': self.name,
  67. 'field': self.field,
  68. }
  69. if self.preserve_default is not True:
  70. kwargs['preserve_default'] = self.preserve_default
  71. return (
  72. self.__class__.__name__,
  73. [],
  74. kwargs
  75. )
  76. def state_forwards(self, app_label, state):
  77. # If preserve default is off, don't use the default for future state
  78. if not self.preserve_default:
  79. field = self.field.clone()
  80. field.default = NOT_PROVIDED
  81. else:
  82. field = self.field
  83. state.models[app_label, self.model_name_lower].fields.append((self.name, field))
  84. # Delay rendering of relationships if it's not a relational field
  85. delay = not field.is_relation
  86. state.reload_model(app_label, self.model_name_lower, delay=delay)
  87. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  88. to_model = to_state.apps.get_model(app_label, self.model_name)
  89. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  90. from_model = from_state.apps.get_model(app_label, self.model_name)
  91. field = to_model._meta.get_field(self.name)
  92. if not self.preserve_default:
  93. field.default = self.field.default
  94. schema_editor.add_field(
  95. from_model,
  96. field,
  97. )
  98. if not self.preserve_default:
  99. field.default = NOT_PROVIDED
  100. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  101. from_model = from_state.apps.get_model(app_label, self.model_name)
  102. if self.allow_migrate_model(schema_editor.connection.alias, from_model):
  103. schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))
  104. def describe(self):
  105. return "Add field %s to %s" % (self.name, self.model_name)
  106. def reduce(self, operation, app_label=None):
  107. if isinstance(operation, FieldOperation) and self.is_same_field_operation(operation):
  108. if isinstance(operation, AlterField):
  109. return [
  110. AddField(
  111. model_name=self.model_name,
  112. name=operation.name,
  113. field=operation.field,
  114. ),
  115. ]
  116. elif isinstance(operation, RemoveField):
  117. return []
  118. elif isinstance(operation, RenameField):
  119. return [
  120. AddField(
  121. model_name=self.model_name,
  122. name=operation.new_name,
  123. field=self.field,
  124. ),
  125. ]
  126. return super().reduce(operation, app_label=app_label)
  127. class RemoveField(FieldOperation):
  128. """Remove a field from a model."""
  129. def deconstruct(self):
  130. kwargs = {
  131. 'model_name': self.model_name,
  132. 'name': self.name,
  133. }
  134. return (
  135. self.__class__.__name__,
  136. [],
  137. kwargs
  138. )
  139. def state_forwards(self, app_label, state):
  140. new_fields = []
  141. old_field = None
  142. for name, instance in state.models[app_label, self.model_name_lower].fields:
  143. if name != self.name:
  144. new_fields.append((name, instance))
  145. else:
  146. old_field = instance
  147. state.models[app_label, self.model_name_lower].fields = new_fields
  148. # Delay rendering of relationships if it's not a relational field
  149. delay = not old_field.is_relation
  150. state.reload_model(app_label, self.model_name_lower, delay=delay)
  151. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  152. from_model = from_state.apps.get_model(app_label, self.model_name)
  153. if self.allow_migrate_model(schema_editor.connection.alias, from_model):
  154. schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))
  155. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  156. to_model = to_state.apps.get_model(app_label, self.model_name)
  157. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  158. from_model = from_state.apps.get_model(app_label, self.model_name)
  159. schema_editor.add_field(from_model, to_model._meta.get_field(self.name))
  160. def describe(self):
  161. return "Remove field %s from %s" % (self.name, self.model_name)
  162. def reduce(self, operation, app_label=None):
  163. from .models import DeleteModel
  164. if isinstance(operation, DeleteModel) and operation.name_lower == self.model_name_lower:
  165. return [operation]
  166. return super().reduce(operation, app_label=app_label)
  167. class AlterField(FieldOperation):
  168. """
  169. Alter a field's database column (e.g. null, max_length) to the provided
  170. new field.
  171. """
  172. def __init__(self, model_name, name, field, preserve_default=True):
  173. self.preserve_default = preserve_default
  174. super().__init__(model_name, name, field)
  175. def deconstruct(self):
  176. kwargs = {
  177. 'model_name': self.model_name,
  178. 'name': self.name,
  179. 'field': self.field,
  180. }
  181. if self.preserve_default is not True:
  182. kwargs['preserve_default'] = self.preserve_default
  183. return (
  184. self.__class__.__name__,
  185. [],
  186. kwargs
  187. )
  188. def state_forwards(self, app_label, state):
  189. if not self.preserve_default:
  190. field = self.field.clone()
  191. field.default = NOT_PROVIDED
  192. else:
  193. field = self.field
  194. state.models[app_label, self.model_name_lower].fields = [
  195. (n, field if n == self.name else f)
  196. for n, f in
  197. state.models[app_label, self.model_name_lower].fields
  198. ]
  199. # TODO: investigate if old relational fields must be reloaded or if it's
  200. # sufficient if the new field is (#27737).
  201. # Delay rendering of relationships if it's not a relational field and
  202. # not referenced by a foreign key.
  203. delay = (
  204. not field.is_relation and
  205. not is_referenced_by_foreign_key(state, self.model_name_lower, self.field, self.name)
  206. )
  207. state.reload_model(app_label, self.model_name_lower, delay=delay)
  208. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  209. to_model = to_state.apps.get_model(app_label, self.model_name)
  210. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  211. from_model = from_state.apps.get_model(app_label, self.model_name)
  212. from_field = from_model._meta.get_field(self.name)
  213. to_field = to_model._meta.get_field(self.name)
  214. if not self.preserve_default:
  215. to_field.default = self.field.default
  216. schema_editor.alter_field(from_model, from_field, to_field)
  217. if not self.preserve_default:
  218. to_field.default = NOT_PROVIDED
  219. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  220. self.database_forwards(app_label, schema_editor, from_state, to_state)
  221. def describe(self):
  222. return "Alter field %s on %s" % (self.name, self.model_name)
  223. def reduce(self, operation, app_label=None):
  224. if isinstance(operation, RemoveField) and self.is_same_field_operation(operation):
  225. return [operation]
  226. elif isinstance(operation, RenameField) and self.is_same_field_operation(operation):
  227. return [
  228. operation,
  229. AlterField(
  230. model_name=self.model_name,
  231. name=operation.new_name,
  232. field=self.field,
  233. ),
  234. ]
  235. return super().reduce(operation, app_label=app_label)
  236. class RenameField(FieldOperation):
  237. """Rename a field on the model. Might affect db_column too."""
  238. def __init__(self, model_name, old_name, new_name):
  239. self.old_name = old_name
  240. self.new_name = new_name
  241. super().__init__(model_name, old_name)
  242. @cached_property
  243. def old_name_lower(self):
  244. return self.old_name.lower()
  245. @cached_property
  246. def new_name_lower(self):
  247. return self.new_name.lower()
  248. def deconstruct(self):
  249. kwargs = {
  250. 'model_name': self.model_name,
  251. 'old_name': self.old_name,
  252. 'new_name': self.new_name,
  253. }
  254. return (
  255. self.__class__.__name__,
  256. [],
  257. kwargs
  258. )
  259. def state_forwards(self, app_label, state):
  260. model_state = state.models[app_label, self.model_name_lower]
  261. # Rename the field
  262. fields = model_state.fields
  263. found = False
  264. delay = True
  265. for index, (name, field) in enumerate(fields):
  266. if not found and name == self.old_name:
  267. fields[index] = (self.new_name, field)
  268. found = True
  269. # Fix from_fields to refer to the new field.
  270. from_fields = getattr(field, 'from_fields', None)
  271. if from_fields:
  272. field.from_fields = tuple([
  273. self.new_name if from_field_name == self.old_name else from_field_name
  274. for from_field_name in from_fields
  275. ])
  276. # Delay rendering of relationships if it's not a relational
  277. # field and not referenced by a foreign key.
  278. delay = delay and (
  279. not field.is_relation and
  280. not is_referenced_by_foreign_key(state, self.model_name_lower, field, self.name)
  281. )
  282. if not found:
  283. raise FieldDoesNotExist(
  284. "%s.%s has no field named '%s'" % (app_label, self.model_name, self.old_name)
  285. )
  286. # Fix index/unique_together to refer to the new field
  287. options = model_state.options
  288. for option in ('index_together', 'unique_together'):
  289. if option in options:
  290. options[option] = [
  291. [self.new_name if n == self.old_name else n for n in together]
  292. for together in options[option]
  293. ]
  294. # Fix to_fields to refer to the new field.
  295. model_tuple = app_label, self.model_name_lower
  296. for (model_app_label, model_name), model_state in state.models.items():
  297. for index, (name, field) in enumerate(model_state.fields):
  298. remote_field = field.remote_field
  299. if remote_field:
  300. remote_model_tuple = self._get_model_tuple(
  301. remote_field.model, model_app_label, model_name
  302. )
  303. if remote_model_tuple == model_tuple:
  304. if getattr(remote_field, 'field_name', None) == self.old_name:
  305. remote_field.field_name = self.new_name
  306. to_fields = getattr(field, 'to_fields', None)
  307. if to_fields:
  308. field.to_fields = tuple([
  309. self.new_name if to_field_name == self.old_name else to_field_name
  310. for to_field_name in to_fields
  311. ])
  312. state.reload_model(app_label, self.model_name_lower, delay=delay)
  313. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  314. to_model = to_state.apps.get_model(app_label, self.model_name)
  315. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  316. from_model = from_state.apps.get_model(app_label, self.model_name)
  317. schema_editor.alter_field(
  318. from_model,
  319. from_model._meta.get_field(self.old_name),
  320. to_model._meta.get_field(self.new_name),
  321. )
  322. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  323. to_model = to_state.apps.get_model(app_label, self.model_name)
  324. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  325. from_model = from_state.apps.get_model(app_label, self.model_name)
  326. schema_editor.alter_field(
  327. from_model,
  328. from_model._meta.get_field(self.new_name),
  329. to_model._meta.get_field(self.old_name),
  330. )
  331. def describe(self):
  332. return "Rename field %s on %s to %s" % (self.old_name, self.model_name, self.new_name)
  333. def references_field(self, model_name, name, app_label=None):
  334. return self.references_model(model_name) and (
  335. name.lower() == self.old_name_lower or
  336. name.lower() == self.new_name_lower
  337. )
  338. def reduce(self, operation, app_label=None):
  339. if (isinstance(operation, RenameField) and
  340. self.is_same_model_operation(operation) and
  341. self.new_name_lower == operation.old_name_lower):
  342. return [
  343. RenameField(
  344. self.model_name,
  345. self.old_name,
  346. operation.new_name,
  347. ),
  348. ]
  349. # Skip `FieldOperation.reduce` as we want to run `references_field`
  350. # against self.new_name.
  351. return (
  352. super(FieldOperation, self).reduce(operation, app_label=app_label) or
  353. not operation.references_field(self.model_name, self.new_name, app_label)
  354. )