You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

jsonb.py 5.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. import json
  2. from psycopg2.extras import Json
  3. from django.contrib.postgres import forms, lookups
  4. from django.core import exceptions
  5. from django.db.models import (
  6. Field, TextField, Transform, lookups as builtin_lookups,
  7. )
  8. from django.utils.translation import gettext_lazy as _
  9. from .mixins import CheckFieldDefaultMixin
  10. __all__ = ['JSONField']
  11. class JsonAdapter(Json):
  12. """
  13. Customized psycopg2.extras.Json to allow for a custom encoder.
  14. """
  15. def __init__(self, adapted, dumps=None, encoder=None):
  16. self.encoder = encoder
  17. super().__init__(adapted, dumps=dumps)
  18. def dumps(self, obj):
  19. options = {'cls': self.encoder} if self.encoder else {}
  20. return json.dumps(obj, **options)
  21. class JSONField(CheckFieldDefaultMixin, Field):
  22. empty_strings_allowed = False
  23. description = _('A JSON object')
  24. default_error_messages = {
  25. 'invalid': _("Value must be valid JSON."),
  26. }
  27. _default_hint = ('dict', '{}')
  28. def __init__(self, verbose_name=None, name=None, encoder=None, **kwargs):
  29. if encoder and not callable(encoder):
  30. raise ValueError("The encoder parameter must be a callable object.")
  31. self.encoder = encoder
  32. super().__init__(verbose_name, name, **kwargs)
  33. def db_type(self, connection):
  34. return 'jsonb'
  35. def deconstruct(self):
  36. name, path, args, kwargs = super().deconstruct()
  37. if self.encoder is not None:
  38. kwargs['encoder'] = self.encoder
  39. return name, path, args, kwargs
  40. def get_transform(self, name):
  41. transform = super().get_transform(name)
  42. if transform:
  43. return transform
  44. return KeyTransformFactory(name)
  45. def get_prep_value(self, value):
  46. if value is not None:
  47. return JsonAdapter(value, encoder=self.encoder)
  48. return value
  49. def validate(self, value, model_instance):
  50. super().validate(value, model_instance)
  51. options = {'cls': self.encoder} if self.encoder else {}
  52. try:
  53. json.dumps(value, **options)
  54. except TypeError:
  55. raise exceptions.ValidationError(
  56. self.error_messages['invalid'],
  57. code='invalid',
  58. params={'value': value},
  59. )
  60. def value_to_string(self, obj):
  61. return self.value_from_object(obj)
  62. def formfield(self, **kwargs):
  63. return super().formfield(**{
  64. 'form_class': forms.JSONField,
  65. **kwargs,
  66. })
  67. JSONField.register_lookup(lookups.DataContains)
  68. JSONField.register_lookup(lookups.ContainedBy)
  69. JSONField.register_lookup(lookups.HasKey)
  70. JSONField.register_lookup(lookups.HasKeys)
  71. JSONField.register_lookup(lookups.HasAnyKeys)
  72. JSONField.register_lookup(lookups.JSONExact)
  73. class KeyTransform(Transform):
  74. operator = '->'
  75. nested_operator = '#>'
  76. def __init__(self, key_name, *args, **kwargs):
  77. super().__init__(*args, **kwargs)
  78. self.key_name = key_name
  79. def as_sql(self, compiler, connection):
  80. key_transforms = [self.key_name]
  81. previous = self.lhs
  82. while isinstance(previous, KeyTransform):
  83. key_transforms.insert(0, previous.key_name)
  84. previous = previous.lhs
  85. lhs, params = compiler.compile(previous)
  86. if len(key_transforms) > 1:
  87. return "(%s %s %%s)" % (lhs, self.nested_operator), [key_transforms] + params
  88. try:
  89. lookup = int(self.key_name)
  90. except ValueError:
  91. lookup = self.key_name
  92. return '(%s %s %%s)' % (lhs, self.operator), tuple(params) + (lookup,)
  93. class KeyTextTransform(KeyTransform):
  94. operator = '->>'
  95. nested_operator = '#>>'
  96. output_field = TextField()
  97. class KeyTransformTextLookupMixin:
  98. """
  99. Mixin for combining with a lookup expecting a text lhs from a JSONField
  100. key lookup. Make use of the ->> operator instead of casting key values to
  101. text and performing the lookup on the resulting representation.
  102. """
  103. def __init__(self, key_transform, *args, **kwargs):
  104. assert isinstance(key_transform, KeyTransform)
  105. key_text_transform = KeyTextTransform(
  106. key_transform.key_name, *key_transform.source_expressions, **key_transform.extra
  107. )
  108. super().__init__(key_text_transform, *args, **kwargs)
  109. class KeyTransformIExact(KeyTransformTextLookupMixin, builtin_lookups.IExact):
  110. pass
  111. class KeyTransformIContains(KeyTransformTextLookupMixin, builtin_lookups.IContains):
  112. pass
  113. class KeyTransformStartsWith(KeyTransformTextLookupMixin, builtin_lookups.StartsWith):
  114. pass
  115. class KeyTransformIStartsWith(KeyTransformTextLookupMixin, builtin_lookups.IStartsWith):
  116. pass
  117. class KeyTransformEndsWith(KeyTransformTextLookupMixin, builtin_lookups.EndsWith):
  118. pass
  119. class KeyTransformIEndsWith(KeyTransformTextLookupMixin, builtin_lookups.IEndsWith):
  120. pass
  121. class KeyTransformRegex(KeyTransformTextLookupMixin, builtin_lookups.Regex):
  122. pass
  123. class KeyTransformIRegex(KeyTransformTextLookupMixin, builtin_lookups.IRegex):
  124. pass
  125. KeyTransform.register_lookup(KeyTransformIExact)
  126. KeyTransform.register_lookup(KeyTransformIContains)
  127. KeyTransform.register_lookup(KeyTransformStartsWith)
  128. KeyTransform.register_lookup(KeyTransformIStartsWith)
  129. KeyTransform.register_lookup(KeyTransformEndsWith)
  130. KeyTransform.register_lookup(KeyTransformIEndsWith)
  131. KeyTransform.register_lookup(KeyTransformRegex)
  132. KeyTransform.register_lookup(KeyTransformIRegex)
  133. class KeyTransformFactory:
  134. def __init__(self, key_name):
  135. self.key_name = key_name
  136. def __call__(self, *args, **kwargs):
  137. return KeyTransform(self.key_name, *args, **kwargs)