Funktionierender Prototyp des Serious Games zur Vermittlung von Wissen zu Software-Engineering-Arbeitsmodellen.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

layers.py 12KB

1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. import asyncio
  2. import fnmatch
  3. import random
  4. import re
  5. import string
  6. import time
  7. from copy import deepcopy
  8. from django.conf import settings
  9. from django.core.signals import setting_changed
  10. from django.utils.module_loading import import_string
  11. from channels import DEFAULT_CHANNEL_LAYER
  12. from .exceptions import ChannelFull, InvalidChannelLayerError
  13. class ChannelLayerManager:
  14. """
  15. Takes a settings dictionary of backends and initialises them on request.
  16. """
  17. def __init__(self):
  18. self.backends = {}
  19. setting_changed.connect(self._reset_backends)
  20. def _reset_backends(self, setting, **kwargs):
  21. """
  22. Removes cached channel layers when the CHANNEL_LAYERS setting changes.
  23. """
  24. if setting == "CHANNEL_LAYERS":
  25. self.backends = {}
  26. @property
  27. def configs(self):
  28. # Lazy load settings so we can be imported
  29. return getattr(settings, "CHANNEL_LAYERS", {})
  30. def make_backend(self, name):
  31. """
  32. Instantiate channel layer.
  33. """
  34. config = self.configs[name].get("CONFIG", {})
  35. return self._make_backend(name, config)
  36. def make_test_backend(self, name):
  37. """
  38. Instantiate channel layer using its test config.
  39. """
  40. try:
  41. config = self.configs[name]["TEST_CONFIG"]
  42. except KeyError:
  43. raise InvalidChannelLayerError("No TEST_CONFIG specified for %s" % name)
  44. return self._make_backend(name, config)
  45. def _make_backend(self, name, config):
  46. # Check for old format config
  47. if "ROUTING" in self.configs[name]:
  48. raise InvalidChannelLayerError(
  49. "ROUTING key found for %s - this is no longer needed in Channels 2."
  50. % name
  51. )
  52. # Load the backend class
  53. try:
  54. backend_class = import_string(self.configs[name]["BACKEND"])
  55. except KeyError:
  56. raise InvalidChannelLayerError("No BACKEND specified for %s" % name)
  57. except ImportError:
  58. raise InvalidChannelLayerError(
  59. "Cannot import BACKEND %r specified for %s"
  60. % (self.configs[name]["BACKEND"], name)
  61. )
  62. # Initialise and pass config
  63. return backend_class(**config)
  64. def __getitem__(self, key):
  65. if key not in self.backends:
  66. self.backends[key] = self.make_backend(key)
  67. return self.backends[key]
  68. def __contains__(self, key):
  69. return key in self.configs
  70. def set(self, key, layer):
  71. """
  72. Sets an alias to point to a new ChannelLayerWrapper instance, and
  73. returns the old one that it replaced. Useful for swapping out the
  74. backend during tests.
  75. """
  76. old = self.backends.get(key, None)
  77. self.backends[key] = layer
  78. return old
  79. class BaseChannelLayer:
  80. """
  81. Base channel layer class that others can inherit from, with useful
  82. common functionality.
  83. """
  84. MAX_NAME_LENGTH = 100
  85. def __init__(self, expiry=60, capacity=100, channel_capacity=None):
  86. self.expiry = expiry
  87. self.capacity = capacity
  88. self.channel_capacity = channel_capacity or {}
  89. def compile_capacities(self, channel_capacity):
  90. """
  91. Takes an input channel_capacity dict and returns the compiled list
  92. of regexes that get_capacity will look for as self.channel_capacity
  93. """
  94. result = []
  95. for pattern, value in channel_capacity.items():
  96. # If they passed in a precompiled regex, leave it, else interpret
  97. # it as a glob.
  98. if hasattr(pattern, "match"):
  99. result.append((pattern, value))
  100. else:
  101. result.append((re.compile(fnmatch.translate(pattern)), value))
  102. return result
  103. def get_capacity(self, channel):
  104. """
  105. Gets the correct capacity for the given channel; either the default,
  106. or a matching result from channel_capacity. Returns the first matching
  107. result; if you want to control the order of matches, use an ordered dict
  108. as input.
  109. """
  110. for pattern, capacity in self.channel_capacity:
  111. if pattern.match(channel):
  112. return capacity
  113. return self.capacity
  114. def match_type_and_length(self, name):
  115. if isinstance(name, str) and (len(name) < self.MAX_NAME_LENGTH):
  116. return True
  117. return False
  118. # Name validation functions
  119. channel_name_regex = re.compile(r"^[a-zA-Z\d\-_.]+(\![\d\w\-_.]*)?$")
  120. group_name_regex = re.compile(r"^[a-zA-Z\d\-_.]+$")
  121. invalid_name_error = (
  122. "{} name must be a valid unicode string "
  123. + "with length < {} ".format(MAX_NAME_LENGTH)
  124. + "containing only ASCII alphanumerics, hyphens, underscores, or periods, "
  125. + "not {}"
  126. )
  127. def valid_channel_name(self, name, receive=False):
  128. if self.match_type_and_length(name):
  129. if bool(self.channel_name_regex.match(name)):
  130. # Check cases for special channels
  131. if "!" in name and not name.endswith("!") and receive:
  132. raise TypeError(
  133. "Specific channel names in receive() must end at the !"
  134. )
  135. return True
  136. raise TypeError(self.invalid_name_error.format("Channel", name))
  137. def valid_group_name(self, name):
  138. if self.match_type_and_length(name):
  139. if bool(self.group_name_regex.match(name)):
  140. return True
  141. raise TypeError(self.invalid_name_error.format("Group", name))
  142. def valid_channel_names(self, names, receive=False):
  143. _non_empty_list = True if names else False
  144. _names_type = isinstance(names, list)
  145. assert _non_empty_list and _names_type, "names must be a non-empty list"
  146. assert all(
  147. self.valid_channel_name(channel, receive=receive) for channel in names
  148. )
  149. return True
  150. def non_local_name(self, name):
  151. """
  152. Given a channel name, returns the "non-local" part. If the channel name
  153. is a process-specific channel (contains !) this means the part up to
  154. and including the !; if it is anything else, this means the full name.
  155. """
  156. if "!" in name:
  157. return name[: name.find("!") + 1]
  158. else:
  159. return name
  160. class InMemoryChannelLayer(BaseChannelLayer):
  161. """
  162. In-memory channel layer implementation
  163. """
  164. def __init__(
  165. self,
  166. expiry=60,
  167. group_expiry=86400,
  168. capacity=100,
  169. channel_capacity=None,
  170. **kwargs
  171. ):
  172. super().__init__(
  173. expiry=expiry,
  174. capacity=capacity,
  175. channel_capacity=channel_capacity,
  176. **kwargs
  177. )
  178. self.channels = {}
  179. self.groups = {}
  180. self.group_expiry = group_expiry
  181. # Channel layer API
  182. extensions = ["groups", "flush"]
  183. async def send(self, channel, message):
  184. """
  185. Send a message onto a (general or specific) channel.
  186. """
  187. # Typecheck
  188. assert isinstance(message, dict), "message is not a dict"
  189. assert self.valid_channel_name(channel), "Channel name not valid"
  190. # If it's a process-local channel, strip off local part and stick full
  191. # name in message
  192. assert "__asgi_channel__" not in message
  193. queue = self.channels.setdefault(channel, asyncio.Queue())
  194. # Are we full
  195. if queue.qsize() >= self.capacity:
  196. raise ChannelFull(channel)
  197. # Add message
  198. await queue.put((time.time() + self.expiry, deepcopy(message)))
  199. async def receive(self, channel):
  200. """
  201. Receive the first message that arrives on the channel.
  202. If more than one coroutine waits on the same channel, a random one
  203. of the waiting coroutines will get the result.
  204. """
  205. assert self.valid_channel_name(channel)
  206. self._clean_expired()
  207. queue = self.channels.setdefault(channel, asyncio.Queue())
  208. # Do a plain direct receive
  209. try:
  210. _, message = await queue.get()
  211. finally:
  212. if queue.empty():
  213. del self.channels[channel]
  214. return message
  215. async def new_channel(self, prefix="specific."):
  216. """
  217. Returns a new channel name that can be used by something in our
  218. process as a specific channel.
  219. """
  220. return "%s.inmemory!%s" % (
  221. prefix,
  222. "".join(random.choice(string.ascii_letters) for i in range(12)),
  223. )
  224. # Expire cleanup
  225. def _clean_expired(self):
  226. """
  227. Goes through all messages and groups and removes those that are expired.
  228. Any channel with an expired message is removed from all groups.
  229. """
  230. # Channel cleanup
  231. for channel, queue in list(self.channels.items()):
  232. # See if it's expired
  233. while not queue.empty() and queue._queue[0][0] < time.time():
  234. queue.get_nowait()
  235. # Any removal prompts group discard
  236. self._remove_from_groups(channel)
  237. # Is the channel now empty and needs deleting?
  238. if queue.empty():
  239. del self.channels[channel]
  240. # Group Expiration
  241. timeout = int(time.time()) - self.group_expiry
  242. for group in self.groups:
  243. for channel in list(self.groups.get(group, set())):
  244. # If join time is older than group_expiry end the group membership
  245. if (
  246. self.groups[group][channel]
  247. and int(self.groups[group][channel]) < timeout
  248. ):
  249. # Delete from group
  250. del self.groups[group][channel]
  251. # Flush extension
  252. async def flush(self):
  253. self.channels = {}
  254. self.groups = {}
  255. async def close(self):
  256. # Nothing to go
  257. pass
  258. def _remove_from_groups(self, channel):
  259. """
  260. Removes a channel from all groups. Used when a message on it expires.
  261. """
  262. for channels in self.groups.values():
  263. if channel in channels:
  264. del channels[channel]
  265. # Groups extension
  266. async def group_add(self, group, channel):
  267. """
  268. Adds the channel name to a group.
  269. """
  270. # Check the inputs
  271. assert self.valid_group_name(group), "Group name not valid"
  272. assert self.valid_channel_name(channel), "Channel name not valid"
  273. # Add to group dict
  274. self.groups.setdefault(group, {})
  275. self.groups[group][channel] = time.time()
  276. async def group_discard(self, group, channel):
  277. # Both should be text and valid
  278. assert self.valid_channel_name(channel), "Invalid channel name"
  279. assert self.valid_group_name(group), "Invalid group name"
  280. # Remove from group set
  281. if group in self.groups:
  282. if channel in self.groups[group]:
  283. del self.groups[group][channel]
  284. if not self.groups[group]:
  285. del self.groups[group]
  286. async def group_send(self, group, message):
  287. # Check types
  288. assert isinstance(message, dict), "Message is not a dict"
  289. assert self.valid_group_name(group), "Invalid group name"
  290. # Run clean
  291. self._clean_expired()
  292. # Send to each channel
  293. for channel in self.groups.get(group, set()):
  294. try:
  295. await self.send(channel, message)
  296. except ChannelFull:
  297. pass
  298. def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER):
  299. """
  300. Returns a channel layer by alias, or None if it is not configured.
  301. """
  302. try:
  303. return channel_layers[alias]
  304. except KeyError:
  305. return None
  306. # Default global instance of the channel layer manager
  307. channel_layers = ChannelLayerManager()