|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363 |
- import asyncio
- import fnmatch
- import random
- import re
- import string
- import time
- from copy import deepcopy
-
- from django.conf import settings
- from django.core.signals import setting_changed
- from django.utils.module_loading import import_string
-
- from channels import DEFAULT_CHANNEL_LAYER
-
- from .exceptions import ChannelFull, InvalidChannelLayerError
-
-
class ChannelLayerManager:
    """
    Takes a settings dictionary of backends and initialises them on request.

    Constructed layers are cached per alias in ``self.backends``; the cache is
    emptied whenever Django reports that the ``CHANNEL_LAYERS`` setting changed.
    """

    def __init__(self):
        self.backends = {}
        setting_changed.connect(self._reset_backends)

    def _reset_backends(self, setting, **kwargs):
        """
        Removes cached channel layers when the CHANNEL_LAYERS setting changes.
        """
        if setting == "CHANNEL_LAYERS":
            self.backends = {}

    @property
    def configs(self):
        """
        The ``CHANNEL_LAYERS`` setting (alias -> layer config), or ``{}``.
        """
        # Lazy load settings so we can be imported
        return getattr(settings, "CHANNEL_LAYERS", {})

    def make_backend(self, name):
        """
        Instantiate the channel layer configured under alias ``name``, passing
        its optional ``CONFIG`` dict as keyword arguments.

        Raises KeyError if ``name`` is not present in ``CHANNEL_LAYERS``.
        """
        config = self.configs[name].get("CONFIG", {})
        return self._make_backend(name, config)

    def make_test_backend(self, name):
        """
        Instantiate channel layer using its ``TEST_CONFIG``.

        Raises InvalidChannelLayerError if no ``TEST_CONFIG`` is available.
        """
        try:
            config = self.configs[name]["TEST_CONFIG"]
        except KeyError:
            raise InvalidChannelLayerError(
                "No TEST_CONFIG specified for %s" % name
            ) from None
        return self._make_backend(name, config)

    def _make_backend(self, name, config):
        """
        Import the ``BACKEND`` class configured for ``name`` and construct it
        with ``config`` as keyword arguments.

        Raises InvalidChannelLayerError for legacy ``ROUTING`` configs, a
        missing ``BACKEND`` key, or a ``BACKEND`` path that cannot be imported.
        """
        layer_config = self.configs[name]
        # Check for old format config
        if "ROUTING" in layer_config:
            raise InvalidChannelLayerError(
                "ROUTING key found for %s - this is no longer needed in Channels 2."
                % name
            )
        # Keep each try narrow so a KeyError/ImportError can only come from
        # the lookup it is meant to guard.
        try:
            backend_path = layer_config["BACKEND"]
        except KeyError:
            raise InvalidChannelLayerError(
                "No BACKEND specified for %s" % name
            ) from None
        try:
            backend_class = import_string(backend_path)
        except ImportError as err:
            raise InvalidChannelLayerError(
                "Cannot import BACKEND %r specified for %s" % (backend_path, name)
            ) from err
        # Initialise and pass config
        return backend_class(**config)

    def __getitem__(self, key):
        """
        Return the (cached) layer for alias ``key``, building it on first use.
        """
        if key not in self.backends:
            self.backends[key] = self.make_backend(key)
        return self.backends[key]

    def __contains__(self, key):
        """
        True if ``key`` is configured in ``CHANNEL_LAYERS``.
        """
        return key in self.configs

    def set(self, key, layer):
        """
        Sets an alias to point to a new channel layer instance, and returns
        the old one that it replaced (or None). Useful for swapping out the
        backend during tests.
        """
        old = self.backends.get(key, None)
        self.backends[key] = layer
        return old
-
-
class BaseChannelLayer:
    """
    Base channel layer class that others can inherit from, with useful
    common functionality: per-channel capacity lookup and channel/group
    name validation.
    """

    MAX_NAME_LENGTH = 100

    def __init__(self, expiry=60, capacity=100, channel_capacity=None):
        """
        :param expiry: Stored as ``self.expiry`` for subclasses to use as the
            message lifetime.
        :param capacity: Default maximum number of messages per channel.
        :param channel_capacity: Optional per-channel capacity overrides: a
            mapping of glob string / compiled regex -> capacity, or an
            iterable of ``(pattern, capacity)`` pairs.
        """
        self.expiry = expiry
        self.capacity = capacity
        # get_capacity() iterates (regex, capacity) pairs, so the overrides
        # must be compiled here; iterating a raw dict would yield bare keys.
        self.channel_capacity = self.compile_capacities(channel_capacity or {})

    def compile_capacities(self, channel_capacity):
        """
        Compile capacity overrides into the list of ``(regex, capacity)``
        pairs that ``get_capacity`` scans.

        Accepts either a mapping of pattern -> capacity or an iterable of
        ``(pattern, capacity)`` pairs (e.g. an already-compiled list).
        Patterns that already have a ``match`` method (precompiled regexes)
        are kept as-is; anything else is interpreted as a shell-style glob.
        """
        if hasattr(channel_capacity, "items"):
            pairs = channel_capacity.items()
        else:
            pairs = channel_capacity
        result = []
        for pattern, value in pairs:
            if not hasattr(pattern, "match"):
                pattern = re.compile(fnmatch.translate(pattern))
            result.append((pattern, value))
        return result

    def get_capacity(self, channel):
        """
        Gets the correct capacity for the given channel; either the default,
        or a matching result from channel_capacity. Returns the first matching
        result; if you want to control the order of matches, use an ordered dict
        as input.
        """
        for pattern, capacity in self.channel_capacity:
            if pattern.match(channel):
                return capacity
        return self.capacity

    def match_type_and_length(self, name):
        """
        True if ``name`` is a str shorter than MAX_NAME_LENGTH.
        """
        if isinstance(name, str) and (len(name) < self.MAX_NAME_LENGTH):
            return True
        return False

    # Name validation functions
    # ``\Z`` rather than ``$``: ``$`` also matches just before a trailing
    # newline, which would let names such as "group\n" pass validation.
    channel_name_regex = re.compile(r"^[a-zA-Z\d\-_.]+(\![\d\w\-_.]*)?\Z")
    group_name_regex = re.compile(r"^[a-zA-Z\d\-_.]+\Z")
    invalid_name_error = (
        "{} name must be a valid unicode string "
        + "with length < {} ".format(MAX_NAME_LENGTH)
        + "containing only ASCII alphanumerics, hyphens, underscores, or periods, "
        + "not {}"
    )

    def valid_channel_name(self, name, receive=False):
        """
        Return True if ``name`` is a valid channel name, else raise TypeError.

        With ``receive=True``, a process-specific name (containing ``!``) must
        end at the ``!``.
        """
        if self.match_type_and_length(name):
            if bool(self.channel_name_regex.match(name)):
                # Check cases for special channels
                if "!" in name and not name.endswith("!") and receive:
                    raise TypeError(
                        "Specific channel names in receive() must end at the !"
                    )
                return True
        raise TypeError(self.invalid_name_error.format("Channel", name))

    def valid_group_name(self, name):
        """
        Return True if ``name`` is a valid group name, else raise TypeError.
        """
        if self.match_type_and_length(name):
            if bool(self.group_name_regex.match(name)):
                return True
        raise TypeError(self.invalid_name_error.format("Group", name))

    def valid_channel_names(self, names, receive=False):
        """
        Validate a non-empty list of channel names; returns True.
        """
        _non_empty_list = True if names else False
        _names_type = isinstance(names, list)
        assert _non_empty_list and _names_type, "names must be a non-empty list"

        assert all(
            self.valid_channel_name(channel, receive=receive) for channel in names
        )
        return True

    def non_local_name(self, name):
        """
        Given a channel name, returns the "non-local" part. If the channel name
        is a process-specific channel (contains !) this means the part up to
        and including the !; if it is anything else, this means the full name.
        """
        if "!" in name:
            return name[: name.find("!") + 1]
        else:
            return name
-
-
class InMemoryChannelLayer(BaseChannelLayer):
    """
    In-memory channel layer implementation.

    Messages are held in one ``asyncio.Queue`` per channel in
    ``self.channels``; group membership is stored in ``self.groups`` as
    group name -> {channel name: join timestamp}.
    """

    def __init__(
        self,
        expiry=60,
        group_expiry=86400,
        capacity=100,
        channel_capacity=None,
        **kwargs
    ):
        super().__init__(
            expiry=expiry,
            capacity=capacity,
            channel_capacity=channel_capacity,
            **kwargs
        )
        # Ensure per-channel overrides are in the compiled (regex, capacity)
        # form that get_capacity() iterates over, whether or not the base
        # class already compiled them.
        if hasattr(channel_capacity, "items"):
            self.channel_capacity = self.compile_capacities(channel_capacity)
        self.channels = {}
        self.groups = {}
        self.group_expiry = group_expiry

    # Channel layer API

    extensions = ["groups", "flush"]

    async def send(self, channel, message):
        """
        Send a message onto a (general or specific) channel.

        The message is deep-copied before queueing. Raises ChannelFull when
        the channel already holds as many messages as ``get_capacity`` allows.
        """
        # Typecheck
        assert isinstance(message, dict), "message is not a dict"
        assert self.valid_channel_name(channel), "Channel name not valid"
        # Senders may not set the "__asgi_channel__" key themselves.
        assert "__asgi_channel__" not in message

        queue = self.channels.get(channel)
        if queue is None:
            queue = self.channels[channel] = asyncio.Queue()
        # Are we full? Honour any per-channel channel_capacity override.
        if queue.qsize() >= self.get_capacity(channel):
            raise ChannelFull(channel)

        # Add message together with its absolute expiry time
        await queue.put((time.time() + self.expiry, deepcopy(message)))

    async def receive(self, channel):
        """
        Receive the first message that arrives on the channel.
        If more than one coroutine waits on the same channel, each message
        is delivered to only one of them.
        """
        assert self.valid_channel_name(channel)
        self._clean_expired()

        queue = self.channels.get(channel)
        if queue is None:
            queue = self.channels[channel] = asyncio.Queue()

        # Do a plain direct receive
        try:
            _, message = await queue.get()
        finally:
            # Drop the channel once drained, but only if:
            # - it is still this exact queue (flush(), _clean_expired() or
            #   another receiver may have removed/replaced the entry while we
            #   waited; an unconditional del would raise KeyError), and
            # - no other coroutine is still waiting on it (otherwise later
            #   sends would go to a fresh queue and strand those waiters).
            #   ``_getters`` is asyncio.Queue's private waiter list.
            if (
                queue.empty()
                and not queue._getters
                and self.channels.get(channel) is queue
            ):
                del self.channels[channel]

        return message

    async def new_channel(self, prefix="specific."):
        """
        Returns a new channel name that can be used by something in our
        process as a specific channel.
        """
        return "%s.inmemory!%s" % (
            prefix,
            "".join(random.choice(string.ascii_letters) for i in range(12)),
        )

    # Expire cleanup

    def _clean_expired(self):
        """
        Goes through all messages and groups and removes those that are expired.
        Any channel with an expired message is removed from all groups, and
        groups left without members are dropped.
        """
        now = time.time()
        # Channel cleanup
        for channel, queue in list(self.channels.items()):
            # Queue entries are (expiry_time, message) in send order, so
            # expired entries sit at the head.
            expired = False
            while not queue.empty() and queue._queue[0][0] < now:
                queue.get_nowait()
                expired = True
            if expired:
                # Any removal prompts group discard
                self._remove_from_groups(channel)
                # Is the channel now empty and needs deleting?
                if queue.empty():
                    del self.channels[channel]

        # Group expiration: memberships older than group_expiry end.
        timeout = int(now) - self.group_expiry
        for group, members in list(self.groups.items()):
            for channel, joined_at in list(members.items()):
                if joined_at and int(joined_at) < timeout:
                    # Delete from group
                    del members[channel]
            # Match group_discard(): don't keep empty groups around.
            if not members:
                del self.groups[group]

    # Flush extension

    async def flush(self):
        """
        Drop all queued messages and all group memberships.
        """
        self.channels = {}
        self.groups = {}

    async def close(self):
        # Nothing to do
        pass

    def _remove_from_groups(self, channel):
        """
        Removes a channel from all groups. Used when a message on it expires.
        """
        for channels in self.groups.values():
            if channel in channels:
                del channels[channel]

    # Groups extension

    async def group_add(self, group, channel):
        """
        Adds the channel name to a group, recording the join time.
        """
        # Check the inputs
        assert self.valid_group_name(group), "Group name not valid"
        assert self.valid_channel_name(channel), "Channel name not valid"
        # Add to group dict
        self.groups.setdefault(group, {})[channel] = time.time()

    async def group_discard(self, group, channel):
        """
        Removes the channel name from a group, dropping the group if empty.
        """
        # Both should be text and valid
        assert self.valid_channel_name(channel), "Invalid channel name"
        assert self.valid_group_name(group), "Invalid group name"
        # Remove from group set
        if group in self.groups:
            if channel in self.groups[group]:
                del self.groups[group][channel]
            if not self.groups[group]:
                del self.groups[group]

    async def group_send(self, group, message):
        """
        Send ``message`` to every channel in ``group``; full channels are
        skipped.
        """
        # Check types
        assert isinstance(message, dict), "Message is not a dict"
        assert self.valid_group_name(group), "Invalid group name"
        # Run clean
        self._clean_expired()
        # Snapshot membership: send() is awaited, so the group could change
        # (or be removed) while we are suspended.
        for channel in list(self.groups.get(group, {})):
            try:
                await self.send(channel, message)
            except ChannelFull:
                # Best-effort delivery: skip members whose channel is full.
                pass
-
def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER):
    """
    Look up the channel layer registered under ``alias``.

    :param alias: Key into the global ``channel_layers`` manager; defaults to
        the project's default layer alias.
    :returns: The layer instance, or ``None`` when the lookup raises KeyError
        (e.g. the alias is not configured).
    """
    try:
        layer = channel_layers[alias]
    except KeyError:
        layer = None
    return layer
-
-
# Process-wide manager instance; get_channel_layer() resolves aliases through it.
channel_layers = ChannelLayerManager()
|