Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

fields.py 15KB


  1. from django.core.exceptions import FieldDoesNotExist
  2. from django.db.models.fields import NOT_PROVIDED
  3. from django.utils.functional import cached_property
  4. from .base import Operation
  5. from .utils import is_referenced_by_foreign_key
  class FieldOperation(Operation):
      """
      Base class for operations that act on a single named field of a model.

      Store the model name and field name as given and expose lowercased
      variants of both, which the comparison helpers below use.
      """

      def __init__(self, model_name, name):
          self.model_name = model_name
          self.name = name

      @cached_property
      def model_name_lower(self):
          """Return ``model_name`` lowercased."""
          return self.model_name.lower()

      @cached_property
      def name_lower(self):
          """Return ``name`` lowercased."""
          return self.name.lower()

      def is_same_model_operation(self, operation):
          """Return whether `operation` has the same lowercased model name."""
          return self.model_name_lower == operation.model_name_lower

      def is_same_field_operation(self, operation):
          """Return whether `operation` has the same model and field names."""
          if not self.is_same_model_operation(operation):
              return False
          return self.name_lower == operation.name_lower

      def references_model(self, name, app_label=None):
          """Return whether `name` equals this model name, ignoring case."""
          return name.lower() == self.model_name_lower

      def references_field(self, model_name, name, app_label=None):
          """
          Return whether `model_name` and `name` both match this operation's
          model and field, ignoring case.
          """
          matches_model = self.references_model(model_name)
          return matches_model and name.lower() == self.name_lower

      def reduce(self, operation, in_between, app_label=None):
          """
          Return the parent class's reduction if it is truthy; otherwise
          return True when `operation` doesn't reference this field and
          False when it does.
          """
          reduced = super().reduce(operation, in_between, app_label=app_label)
          return reduced or not operation.references_field(self.model_name, self.name, app_label)
  class AddField(FieldOperation):
      """Add a field to a model."""

      def __init__(self, model_name, name, field, preserve_default=True):
          self.field = field
          self.preserve_default = preserve_default
          super().__init__(model_name, name)

      def deconstruct(self):
          """
          Return a ``(class name, positional args, keyword args)`` tuple.

          ``preserve_default`` is included in the keyword arguments only when
          it differs from its default of True.
          """
          kwargs = {
              'model_name': self.model_name,
              'name': self.name,
              'field': self.field,
          }
          if self.preserve_default is not True:
              kwargs['preserve_default'] = self.preserve_default
          return (
              self.__class__.__name__,
              [],
              kwargs
          )

      def state_forwards(self, app_label, state):
          """Append the new field to the model's state and reload the model."""
          # If preserve default is off, don't use the default for future state
          if not self.preserve_default:
              field = self.field.clone()
              field.default = NOT_PROVIDED
          else:
              field = self.field
          state.models[app_label, self.model_name_lower].fields.append((self.name, field))
          # Delay rendering of relationships if it's not a relational field
          delay = not field.is_relation
          state.reload_model(app_label, self.model_name_lower, delay=delay)

      def database_forwards(self, app_label, schema_editor, from_state, to_state):
          """
          Add the field through `schema_editor` if migrating the model is
          allowed on the editor's connection.

          When ``preserve_default`` is off, the operation's default is put on
          the state field only for the duration of the ``add_field()`` call.
          """
          to_model = to_state.apps.get_model(app_label, self.model_name)
          if self.allow_migrate_model(schema_editor.connection.alias, to_model):
              from_model = from_state.apps.get_model(app_label, self.model_name)
              field = to_model._meta.get_field(self.name)
              if not self.preserve_default:
                  field.default = self.field.default
              try:
                  schema_editor.add_field(
                      from_model,
                      field,
                  )
              finally:
                  # Reset the default even if add_field() raises, so the
                  # one-off default never stays on the state's field.
                  if not self.preserve_default:
                      field.default = NOT_PROVIDED

      def database_backwards(self, app_label, schema_editor, from_state, to_state):
          """Remove the field through `schema_editor` if migration is allowed."""
          from_model = from_state.apps.get_model(app_label, self.model_name)
          if self.allow_migrate_model(schema_editor.connection.alias, from_model):
              schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))

      def describe(self):
          return "Add field %s to %s" % (self.name, self.model_name)

      def reduce(self, operation, in_between, app_label=None):
          """
          Combine this operation with a later one on the same field.

          A following AlterField yields an AddField using the altered field,
          a RemoveField yields an empty list, and a RenameField yields an
          AddField under the new name. Anything else is left to the parent
          class.
          """
          if isinstance(operation, FieldOperation) and self.is_same_field_operation(operation):
              if isinstance(operation, AlterField):
                  return [
                      AddField(
                          model_name=self.model_name,
                          name=operation.name,
                          field=operation.field,
                      ),
                  ]
              elif isinstance(operation, RemoveField):
                  return []
              elif isinstance(operation, RenameField):
                  return [
                      AddField(
                          model_name=self.model_name,
                          name=operation.new_name,
                          field=self.field,
                      ),
                  ]
          return super().reduce(operation, in_between, app_label=app_label)
  class RemoveField(FieldOperation):
      """Remove a field from a model."""

      def deconstruct(self):
          """Return a ``(class name, positional args, keyword args)`` tuple."""
          kwargs = {
              'model_name': self.model_name,
              'name': self.name,
          }
          return (
              self.__class__.__name__,
              [],
              kwargs
          )

      def state_forwards(self, app_label, state):
          """
          Drop the field from the model's state and reload the model.

          Raise FieldDoesNotExist if the model state has no field with this
          operation's name.
          """
          model_state = state.models[app_label, self.model_name_lower]
          new_fields = []
          old_field = None
          for name, instance in model_state.fields:
              if name != self.name:
                  new_fields.append((name, instance))
              else:
                  old_field = instance
          if old_field is None:
              # Fail with an explicit error rather than an AttributeError on
              # None when the field to remove isn't in the state.
              raise FieldDoesNotExist(
                  "%s.%s has no field named '%s'" % (app_label, self.model_name, self.name)
              )
          model_state.fields = new_fields
          # Delay rendering of relationships if it's not a relational field
          delay = not old_field.is_relation
          state.reload_model(app_label, self.model_name_lower, delay=delay)

      def database_forwards(self, app_label, schema_editor, from_state, to_state):
          """Remove the field through `schema_editor` if migration is allowed."""
          from_model = from_state.apps.get_model(app_label, self.model_name)
          if self.allow_migrate_model(schema_editor.connection.alias, from_model):
              schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))

      def database_backwards(self, app_label, schema_editor, from_state, to_state):
          """Re-add the field through `schema_editor` if migration is allowed."""
          to_model = to_state.apps.get_model(app_label, self.model_name)
          if self.allow_migrate_model(schema_editor.connection.alias, to_model):
              from_model = from_state.apps.get_model(app_label, self.model_name)
              schema_editor.add_field(from_model, to_model._meta.get_field(self.name))

      def describe(self):
          return "Remove field %s from %s" % (self.name, self.model_name)
  class AlterField(FieldOperation):
      """
      Alter a field's database column (e.g. null, max_length) to the provided
      new field.
      """

      def __init__(self, model_name, name, field, preserve_default=True):
          self.field = field
          self.preserve_default = preserve_default
          super().__init__(model_name, name)

      def deconstruct(self):
          """
          Return a ``(class name, positional args, keyword args)`` tuple.

          ``preserve_default`` is included in the keyword arguments only when
          it differs from its default of True.
          """
          kwargs = {
              'model_name': self.model_name,
              'name': self.name,
              'field': self.field,
          }
          if self.preserve_default is not True:
              kwargs['preserve_default'] = self.preserve_default
          return (
              self.__class__.__name__,
              [],
              kwargs
          )

      def state_forwards(self, app_label, state):
          """Swap the field in the model's state and reload the model."""
          # If preserve default is off, don't use the default for future state
          if not self.preserve_default:
              field = self.field.clone()
              field.default = NOT_PROVIDED
          else:
              field = self.field
          state.models[app_label, self.model_name_lower].fields = [
              (n, field if n == self.name else f)
              for n, f in
              state.models[app_label, self.model_name_lower].fields
          ]
          # TODO: investigate if old relational fields must be reloaded or if it's
          # sufficient if the new field is (#27737).
          # Delay rendering of relationships if it's not a relational field and
          # not referenced by a foreign key.
          delay = (
              not field.is_relation and
              not is_referenced_by_foreign_key(state, self.model_name_lower, self.field, self.name)
          )
          state.reload_model(app_label, self.model_name_lower, delay=delay)

      def database_forwards(self, app_label, schema_editor, from_state, to_state):
          """
          Alter the field through `schema_editor` if migrating the model is
          allowed on the editor's connection.

          When ``preserve_default`` is off, the operation's default is put on
          the target field only for the duration of the ``alter_field()`` call.
          """
          to_model = to_state.apps.get_model(app_label, self.model_name)
          if self.allow_migrate_model(schema_editor.connection.alias, to_model):
              from_model = from_state.apps.get_model(app_label, self.model_name)
              from_field = from_model._meta.get_field(self.name)
              to_field = to_model._meta.get_field(self.name)
              if not self.preserve_default:
                  to_field.default = self.field.default
              try:
                  schema_editor.alter_field(from_model, from_field, to_field)
              finally:
                  # Reset the default even if alter_field() raises, so the
                  # one-off default never stays on the state's field.
                  if not self.preserve_default:
                      to_field.default = NOT_PROVIDED

      def database_backwards(self, app_label, schema_editor, from_state, to_state):
          """Run the forwards alteration with the states as passed in."""
          self.database_forwards(app_label, schema_editor, from_state, to_state)

      def describe(self):
          return "Alter field %s on %s" % (self.name, self.model_name)

      def reduce(self, operation, in_between, app_label=None):
          """
          Combine this operation with a later one on the same field.

          A following RemoveField replaces this operation. A following
          RenameField is returned first, followed by an AlterField that
          targets the new name. Anything else is left to the parent class.
          """
          if isinstance(operation, RemoveField) and self.is_same_field_operation(operation):
              return [operation]
          elif isinstance(operation, RenameField) and self.is_same_field_operation(operation):
              return [
                  operation,
                  AlterField(
                      model_name=self.model_name,
                      name=operation.new_name,
                      field=self.field,
                  ),
              ]
          return super().reduce(operation, in_between, app_label=app_label)
  class RenameField(FieldOperation):
      """Rename a field on the model. Might affect db_column too."""

      def __init__(self, model_name, old_name, new_name):
          self.old_name = old_name
          self.new_name = new_name
          # The parent class tracks the field under its old name.
          super().__init__(model_name, old_name)

      @cached_property
      def old_name_lower(self):
          """Return ``old_name`` lowercased."""
          return self.old_name.lower()

      @cached_property
      def new_name_lower(self):
          """Return ``new_name`` lowercased."""
          return self.new_name.lower()

      def deconstruct(self):
          """Return a ``(class name, positional args, keyword args)`` tuple."""
          kwargs = {
              'model_name': self.model_name,
              'old_name': self.old_name,
              'new_name': self.new_name,
          }
          return (self.__class__.__name__, [], kwargs)

      def state_forwards(self, app_label, state):
          """
          Rename the field in the model's state and update every reference to
          the old name (from_fields, index/unique_together, to_fields and
          remote field names), then reload the model.

          Raise FieldDoesNotExist if the model has no field named `old_name`.
          """
          old_name, new_name = self.old_name, self.new_name
          model_state = state.models[app_label, self.model_name_lower]
          fields = model_state.fields
          renamed = False
          delay = True
          for position, (field_name, field) in enumerate(fields):
              # Only the first field carrying the old name is renamed.
              if not renamed and field_name == old_name:
                  fields[position] = (new_name, field)
                  renamed = True
              # Point from_fields at the new name.
              local_refs = getattr(field, 'from_fields', None)
              if local_refs:
                  field.from_fields = tuple(
                      new_name if ref == old_name else ref
                      for ref in local_refs
                  )
              # Rendering of relationships can only be delayed while no field
              # seen so far is relational or referenced by a foreign key.
              if delay:
                  delay = not (
                      field.is_relation or
                      is_referenced_by_foreign_key(state, self.model_name_lower, field, self.name)
                  )
          if not renamed:
              raise FieldDoesNotExist(
                  "%s.%s has no field named '%s'" % (app_label, self.model_name, self.old_name)
              )
          # Point index/unique_together at the new name.
          options = model_state.options
          for option in ('index_together', 'unique_together'):
              if option not in options:
                  continue
              options[option] = [
                  [new_name if member == old_name else member for member in together]
                  for together in options[option]
              ]
          # Point to_fields and remote field names of relations targeting this
          # model at the new name.
          model_tuple = app_label, self.model_name_lower
          for (model_app_label, model_name), other_state in state.models.items():
              for _, field in other_state.fields:
                  remote_field = field.remote_field
                  if not remote_field:
                      continue
                  remote_model_tuple = self._get_model_tuple(
                      remote_field.model, model_app_label, model_name
                  )
                  if remote_model_tuple != model_tuple:
                      continue
                  if getattr(remote_field, 'field_name', None) == old_name:
                      remote_field.field_name = new_name
                  to_fields = getattr(field, 'to_fields', None)
                  if to_fields:
                      field.to_fields = tuple(
                          new_name if ref == old_name else ref
                          for ref in to_fields
                      )
          state.reload_model(app_label, self.model_name_lower, delay=delay)

      def database_forwards(self, app_label, schema_editor, from_state, to_state):
          """Alter the old-named field into the new-named one if allowed."""
          to_model = to_state.apps.get_model(app_label, self.model_name)
          if not self.allow_migrate_model(schema_editor.connection.alias, to_model):
              return
          from_model = from_state.apps.get_model(app_label, self.model_name)
          old_field = from_model._meta.get_field(self.old_name)
          new_field = to_model._meta.get_field(self.new_name)
          schema_editor.alter_field(from_model, old_field, new_field)

      def database_backwards(self, app_label, schema_editor, from_state, to_state):
          """Alter the new-named field back into the old-named one if allowed."""
          to_model = to_state.apps.get_model(app_label, self.model_name)
          if not self.allow_migrate_model(schema_editor.connection.alias, to_model):
              return
          from_model = from_state.apps.get_model(app_label, self.model_name)
          current_field = from_model._meta.get_field(self.new_name)
          restored_field = to_model._meta.get_field(self.old_name)
          schema_editor.alter_field(from_model, current_field, restored_field)

      def describe(self):
          return "Rename field %s on %s to %s" % (self.old_name, self.model_name, self.new_name)

      def references_field(self, model_name, name, app_label=None):
          """
          Return whether `model_name` matches this model and `name` matches
          either the old or the new field name, ignoring case.
          """
          return self.references_model(model_name) and (
              name.lower() in (self.old_name_lower, self.new_name_lower)
          )

      def reduce(self, operation, in_between, app_label=None):
          """
          Collapse two chained renames on the same model into a single one;
          otherwise fall back to the reduction of FieldOperation's parent,
          then to whether `operation` references the new field name.
          """
          chained_rename = (
              isinstance(operation, RenameField) and
              self.is_same_model_operation(operation) and
              self.new_name_lower == operation.old_name_lower
          )
          if chained_rename:
              return [RenameField(self.model_name, self.old_name, operation.new_name)]
          # Skip `FieldOperation.reduce` as we want to run `references_field`
          # against self.new_name.
          reduced = super(FieldOperation, self).reduce(operation, in_between, app_label=app_label)
          return reduced or not operation.references_field(self.model_name, self.new_name, app_label)