You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 4.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import json
  2. import threading
  3. import importlib
  4. import six
  5. from django.conf import settings
  6. from django.http import HttpResponse
  7. from django.core.serializers.json import DjangoJSONEncoder
  8. from gripcontrol import HttpStreamFormat
  9. try:
  10. from urllib import quote
  11. except ImportError:
  12. from urllib.parse import quote
# Per-thread storage: get_class() keeps its cache of loaded instances in
# ``tlocal.loaded``, so each thread builds and reuses its own objects.
tlocal = threading.local()
  14. def have_channels():
  15. try:
  16. from channels.generic.http import AsyncHttpConsumer
  17. return True
  18. except ImportError:
  19. return False
  20. # return dict of (channel, last-id)
  21. def parse_last_event_id(s):
  22. out = {}
  23. parts = s.split(',')
  24. for part in parts:
  25. channel, last_id = part.split(':')
  26. out[channel] = last_id
  27. return out
  28. def make_id(ids):
  29. id_parts = []
  30. for channel, id in six.iteritems(ids):
  31. enc_channel = quote(channel)
  32. id_parts.append('%s:%s' % (enc_channel, id))
  33. return ','.join(id_parts)
  34. def build_id_escape(s):
  35. out = ''
  36. for c in s:
  37. if c == '%':
  38. out += '%%'
  39. else:
  40. out += c
  41. return out
  42. def sse_encode_event(event_type, data, event_id=None, escape=False):
  43. data_str = json.dumps(data, cls=DjangoJSONEncoder)
  44. if escape:
  45. event_type = build_id_escape(event_type)
  46. data_str = build_id_escape(data_str)
  47. out = 'event: %s\n' % event_type
  48. if event_id:
  49. out += 'id: %s\n' % event_id
  50. out += 'data: %s\n\n' % data_str
  51. return out
  52. def sse_error_response(condition, text, extra=None):
  53. if extra is None:
  54. extra = {}
  55. data = {'condition': condition, 'text': text}
  56. for k, v in six.iteritems(extra):
  57. data[k] = v
  58. body = sse_encode_event('stream-error', data, event_id='error')
  59. return HttpResponse(body, content_type='text/event-stream')
  60. def publish_event(channel, event_type, data, pub_id, pub_prev_id,
  61. skip_user_ids=None):
  62. from django_grip import publish
  63. if skip_user_ids is None:
  64. skip_user_ids = []
  65. content_filters = []
  66. if pub_id:
  67. event_id = '%I'
  68. content_filters.append('build-id')
  69. else:
  70. event_id = None
  71. content = sse_encode_event(event_type, data, event_id=event_id, escape=bool(pub_id))
  72. meta = {}
  73. if skip_user_ids:
  74. meta['skip_users'] = ','.join(skip_user_ids)
  75. publish(
  76. 'events-%s' % quote(channel),
  77. HttpStreamFormat(content, content_filters=content_filters),
  78. id=pub_id,
  79. prev_id=pub_prev_id,
  80. meta=meta)
  81. def publish_kick(user_id, channel):
  82. from django_grip import publish
  83. msg = 'Permission denied to channels: %s' % channel
  84. data = {'condition': 'forbidden', 'text': msg, 'channels': [channel]}
  85. content = sse_encode_event('stream-error', data, event_id='error')
  86. meta = {'require_sub': 'events-%s' % channel}
  87. publish(
  88. 'user-%s' % user_id,
  89. HttpStreamFormat(content),
  90. id='kick-1',
  91. meta=meta)
  92. publish(
  93. 'user-%s' % user_id,
  94. HttpStreamFormat(close=True),
  95. id='kick-2',
  96. prev_id='kick-1',
  97. meta=meta)
  98. def load_class(name):
  99. at = name.rfind('.')
  100. if at == -1:
  101. raise ValueError('class name contains no \'.\'')
  102. module_name = name[0:at]
  103. class_name = name[at + 1:]
  104. return getattr(importlib.import_module(module_name), class_name)()
  105. # load and keep in thread local storage
  106. def get_class(name):
  107. if not hasattr(tlocal, 'loaded'):
  108. tlocal.loaded = {}
  109. c = tlocal.loaded.get(name)
  110. if c is None:
  111. c = load_class(name)
  112. tlocal.loaded[name] = c
  113. return c
  114. def get_class_from_setting(setting_name, default=None):
  115. if hasattr(settings, setting_name):
  116. return get_class(getattr(settings, setting_name))
  117. elif default:
  118. return get_class(default)
  119. else:
  120. return None
  121. def get_storage():
  122. return get_class_from_setting('EVENTSTREAM_STORAGE_CLASS')
  123. def get_channelmanager():
  124. return get_class_from_setting(
  125. 'EVENTSTREAM_CHANNELMANAGER_CLASS',
  126. 'django_eventstream.channelmanager.DefaultChannelManager')
  127. def add_default_headers(headers):
  128. headers['Cache-Control'] = 'no-cache'
  129. headers['X-Accel-Buffering'] = 'no'
  130. augment_cors_headers(headers)
  131. def augment_cors_headers(headers):
  132. cors_origin = ''
  133. if hasattr(settings, 'EVENTSTREAM_ALLOW_ORIGIN'):
  134. cors_origin = settings.EVENTSTREAM_ALLOW_ORIGIN
  135. if cors_origin:
  136. headers['Access-Control-Allow-Origin'] = cors_origin
  137. allow_credentials = False
  138. if hasattr(settings, 'EVENTSTREAM_ALLOW_CREDENTIALS'):
  139. allow_credentials = settings.EVENTSTREAM_ALLOW_CREDENTIALS
  140. if allow_credentials:
  141. headers['Access-Control-Allow-Credentials'] = 'true'