Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

jsonb.py 5.5KB

(lines 1–183)
  1. import json
  2. from psycopg2.extras import Json
  3. from django.contrib.postgres import forms, lookups
  4. from django.core import exceptions
  5. from django.db.models import (
  6. Field, TextField, Transform, lookups as builtin_lookups,
  7. )
  8. from django.utils.translation import gettext_lazy as _
  9. __all__ = ['JSONField']
  10. class JsonAdapter(Json):
  11. """
  12. Customized psycopg2.extras.Json to allow for a custom encoder.
  13. """
  14. def __init__(self, adapted, dumps=None, encoder=None):
  15. self.encoder = encoder
  16. super().__init__(adapted, dumps=dumps)
  17. def dumps(self, obj):
  18. options = {'cls': self.encoder} if self.encoder else {}
  19. return json.dumps(obj, **options)
  20. class JSONField(Field):
  21. empty_strings_allowed = False
  22. description = _('A JSON object')
  23. default_error_messages = {
  24. 'invalid': _("Value must be valid JSON."),
  25. }
  26. def __init__(self, verbose_name=None, name=None, encoder=None, **kwargs):
  27. if encoder and not callable(encoder):
  28. raise ValueError("The encoder parameter must be a callable object.")
  29. self.encoder = encoder
  30. super().__init__(verbose_name, name, **kwargs)
  31. def db_type(self, connection):
  32. return 'jsonb'
  33. def deconstruct(self):
  34. name, path, args, kwargs = super().deconstruct()
  35. if self.encoder is not None:
  36. kwargs['encoder'] = self.encoder
  37. return name, path, args, kwargs
  38. def get_transform(self, name):
  39. transform = super().get_transform(name)
  40. if transform:
  41. return transform
  42. return KeyTransformFactory(name)
  43. def get_prep_value(self, value):
  44. if value is not None:
  45. return JsonAdapter(value, encoder=self.encoder)
  46. return value
  47. def validate(self, value, model_instance):
  48. super().validate(value, model_instance)
  49. options = {'cls': self.encoder} if self.encoder else {}
  50. try:
  51. json.dumps(value, **options)
  52. except TypeError:
  53. raise exceptions.ValidationError(
  54. self.error_messages['invalid'],
  55. code='invalid',
  56. params={'value': value},
  57. )
  58. def value_to_string(self, obj):
  59. return self.value_from_object(obj)
  60. def formfield(self, **kwargs):
  61. defaults = {'form_class': forms.JSONField}
  62. defaults.update(kwargs)
  63. return super().formfield(**defaults)
  64. JSONField.register_lookup(lookups.DataContains)
  65. JSONField.register_lookup(lookups.ContainedBy)
  66. JSONField.register_lookup(lookups.HasKey)
  67. JSONField.register_lookup(lookups.HasKeys)
  68. JSONField.register_lookup(lookups.HasAnyKeys)
  69. class KeyTransform(Transform):
  70. operator = '->'
  71. nested_operator = '#>'
  72. def __init__(self, key_name, *args, **kwargs):
  73. super().__init__(*args, **kwargs)
  74. self.key_name = key_name
  75. def as_sql(self, compiler, connection):
  76. key_transforms = [self.key_name]
  77. previous = self.lhs
  78. while isinstance(previous, KeyTransform):
  79. key_transforms.insert(0, previous.key_name)
  80. previous = previous.lhs
  81. lhs, params = compiler.compile(previous)
  82. if len(key_transforms) > 1:
  83. return "(%s %s %%s)" % (lhs, self.nested_operator), [key_transforms] + params
  84. try:
  85. int(self.key_name)
  86. except ValueError:
  87. lookup = "'%s'" % self.key_name
  88. else:
  89. lookup = "%s" % self.key_name
  90. return "(%s %s %s)" % (lhs, self.operator, lookup), params
  91. class KeyTextTransform(KeyTransform):
  92. operator = '->>'
  93. nested_operator = '#>>'
  94. output_field = TextField()
  95. class KeyTransformTextLookupMixin:
  96. """
  97. Mixin for combining with a lookup expecting a text lhs from a JSONField
  98. key lookup. Make use of the ->> operator instead of casting key values to
  99. text and performing the lookup on the resulting representation.
  100. """
  101. def __init__(self, key_transform, *args, **kwargs):
  102. assert isinstance(key_transform, KeyTransform)
  103. key_text_transform = KeyTextTransform(
  104. key_transform.key_name, *key_transform.source_expressions, **key_transform.extra
  105. )
  106. super().__init__(key_text_transform, *args, **kwargs)
  107. class KeyTransformIExact(KeyTransformTextLookupMixin, builtin_lookups.IExact):
  108. pass
  109. class KeyTransformIContains(KeyTransformTextLookupMixin, builtin_lookups.IContains):
  110. pass
  111. class KeyTransformStartsWith(KeyTransformTextLookupMixin, builtin_lookups.StartsWith):
  112. pass
  113. class KeyTransformIStartsWith(KeyTransformTextLookupMixin, builtin_lookups.IStartsWith):
  114. pass
  115. class KeyTransformEndsWith(KeyTransformTextLookupMixin, builtin_lookups.EndsWith):
  116. pass
  117. class KeyTransformIEndsWith(KeyTransformTextLookupMixin, builtin_lookups.IEndsWith):
  118. pass
  119. class KeyTransformRegex(KeyTransformTextLookupMixin, builtin_lookups.Regex):
  120. pass
  121. class KeyTransformIRegex(KeyTransformTextLookupMixin, builtin_lookups.IRegex):
  122. pass
  123. KeyTransform.register_lookup(KeyTransformIExact)
  124. KeyTransform.register_lookup(KeyTransformIContains)
  125. KeyTransform.register_lookup(KeyTransformStartsWith)
  126. KeyTransform.register_lookup(KeyTransformIStartsWith)
  127. KeyTransform.register_lookup(KeyTransformEndsWith)
  128. KeyTransform.register_lookup(KeyTransformIEndsWith)
  129. KeyTransform.register_lookup(KeyTransformRegex)
  130. KeyTransform.register_lookup(KeyTransformIRegex)
  131. class KeyTransformFactory:
  132. def __init__(self, key_name):
  133. self.key_name = key_name
  134. def __call__(self, *args, **kwargs):
  135. return KeyTransform(self.key_name, *args, **kwargs)