123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- import functools
-
- from asgiref.sync import async_to_sync
-
- from . import DEFAULT_CHANNEL_LAYER
- from .db import database_sync_to_async
- from .exceptions import StopConsumer
- from .layers import get_channel_layer
- from .utils import await_many_dispatch
-
-
def get_handler_name(message):
    """
    Looks at a message, checks it has a sensible type, and returns the
    handler name for that type.

    The handler name is the message's ``type`` with every ``.`` replaced by
    ``_`` (e.g. ``"websocket.connect"`` -> ``"websocket_connect"``).

    Raises ValueError if the message has no ``type`` key, if the type is not
    a string, or if the derived handler name starts with an underscore.
    """
    # Check message looks OK
    if "type" not in message:
        raise ValueError("Incoming message has no 'type' attribute")
    message_type = message["type"]
    if not isinstance(message_type, str):
        raise ValueError(
            "Malformed type in message (expected str, got %s)"
            % type(message_type).__name__
        )
    # Extract type and replace . with _
    handler_name = message_type.replace(".", "_")
    # Validate the *derived* name rather than the raw type: a type such as
    # ".sync" has no leading underscore itself but maps to "_sync", which
    # would otherwise let a message reach private attributes of the consumer.
    if handler_name.startswith("_"):
        raise ValueError("Malformed type in message (leading underscore)")
    return handler_name
-
-
class AsyncConsumer:
    """
    Base consumer class.

    Instances are constructed with an ASGI ``scope`` and then called with the
    ``receive``/``send`` pair. On top of that interface the class manages an
    optional channel layer and routes each incoming event to the method whose
    name ``get_handler_name`` derives from the event's ``type``.
    """

    _sync = False
    channel_layer_alias = DEFAULT_CHANNEL_LAYER

    def __init__(self, scope):
        self.scope = scope

    async def __call__(self, receive, send):
        """
        Feed messages from the client (and from the channel layer, when one
        is configured for ``channel_layer_alias``) into ``dispatch`` via
        ``await_many_dispatch``.

        A ``StopConsumer`` raised while dispatching is swallowed, so the call
        returns normally in that case.
        """
        self.channel_layer = layer = get_channel_layer(self.channel_layer_alias)
        if layer is not None:
            # Allocate a channel of our own plus a receive callable bound to it.
            self.channel_name = await layer.new_channel()
            self.channel_receive = functools.partial(
                layer.receive, self.channel_name
            )

        # When _sync is set, handlers call base_send from synchronous code,
        # so the ASGI send callable is wrapped with async_to_sync.
        self.base_send = async_to_sync(send) if self._sync else send

        message_sources = [receive]
        if self.channel_layer is not None:
            message_sources.append(self.channel_receive)

        try:
            await await_many_dispatch(message_sources, self.dispatch)
        except StopConsumer:
            # A handler requested shutdown; returning normally is the clean exit.
            pass

    async def dispatch(self, message):
        """
        Route ``message`` to its handler coroutine and await it.

        Raises ValueError when no truthy attribute matches the handler name
        derived from the message type.
        """
        handler_name = get_handler_name(message)
        handler = getattr(self, handler_name, None)
        if not handler:
            raise ValueError("No handler for message type %s" % message["type"])
        await handler(message)

    async def send(self, message):
        """
        Send ``message`` through ``base_send``.

        Subclasses may override this or call it directly.
        """
        await self.base_send(message)
-
-
class SyncConsumer(AsyncConsumer):
    """
    Synchronous version of the consumer, which is what we write most of the
    generic consumers against (for now).

    Handlers are plain (non-async) methods. ``dispatch`` is wrapped with
    ``database_sync_to_async`` so the async base class can still await it.
    Because ``_sync`` is True, the base class wraps the ASGI ``send`` callable
    with ``async_to_sync``, so ``send`` can be called synchronously from
    handlers. (``database_sync_to_async`` is presumably a thread-offloading
    sync-to-async adapter; confirm against ``.db``.)

    It would have been possible to have "mixed" consumers and auto-detect
    if a handler was awaitable or not, but that would have made the API
    for user-called methods very confusing as there'd be two types of each.
    """

    # Tells the base class to wrap the ASGI send callable with async_to_sync.
    _sync = True

    @database_sync_to_async
    def dispatch(self, message):
        """
        Look up the handler for ``message`` and call it synchronously.

        The decorator turns this into an awaitable for the base class's
        dispatch loop. Raises ValueError if no truthy attribute matches the
        handler name derived from the message type.
        """
        # Get and execute the handler
        handler = getattr(self, get_handler_name(message), None)
        if handler:
            handler(message)
        else:
            raise ValueError("No handler for message type %s" % message["type"])

    def send(self, message):
        """
        Overrideable/callable-by-subclasses send method.

        Calls ``base_send``, which for this class is the ``async_to_sync``-
        wrapped ASGI send callable, so it is invoked without ``await``.
        """
        self.base_send(message)
|