123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338 |
- """
- kombu.transport.mongodb
- =======================
-
- MongoDB transport.
-
- :copyright: (c) 2010 - 2013 by Flavio Percoco Premoli.
- :license: BSD, see LICENSE for more details.
-
- """
- from __future__ import absolute_import
-
- import pymongo
-
- from pymongo import errors
- from anyjson import loads, dumps
- from pymongo import MongoClient, uri_parser
-
- from kombu.five import Empty
- from kombu.syn import _detect_environment
- from kombu.utils.encoding import bytes_to_str
-
- from . import virtual
-
try:
    from pymongo.cursor import CursorType
except ImportError:
    # Placeholder so the name ``CursorType`` is always defined at module
    # level.  It is only dereferenced (``CursorType.TAILABLE``) on the
    # ``pymongo.version_tuple >= (3, )`` branch of
    # ``Channel.create_broadcast_cursor``.
    class CursorType(object):  # noqa
        pass

# Host appended to the connection URI when the configured hostname is
# empty, and port used when the connection does not specify one.
DEFAULT_HOST = '127.0.0.1'
DEFAULT_PORT = 27017

# Names of the collections used by the channel: regular queue messages,
# persisted queue bindings, and the capped collection for fanout messages.
DEFAULT_MESSAGES_COLLECTION = 'messages'
DEFAULT_ROUTING_COLLECTION = 'messages.routing'
DEFAULT_BROADCAST_COLLECTION = 'messages.broadcast'
-
-
class BroadcastCursor(object):
    """Cursor for broadcast queues.

    Wraps the cursor handed in by ``Channel.create_broadcast_cursor`` and
    keeps track of how many documents have already been skipped or
    consumed (``_offset``), so that only messages published after the
    cursor was created (or last purged) are delivered.
    """

    def __init__(self, cursor):
        """Wrap ``cursor`` and skip every document already present."""
        self._cursor = cursor

        # A freshly created cursor needs no rewind; just fast-forward it.
        self.purge(rewind=False)

    def get_size(self):
        """Return the number of documents not yet consumed."""
        return self._cursor.count() - self._offset

    def close(self):
        """Close the underlying cursor."""
        self._cursor.close()

    def purge(self, rewind=True):
        """Discard all pending documents.

        :param rewind: when true, rewind the underlying cursor before
            skipping ahead (needed once the cursor has been iterated).
        """
        if rewind:
            self._cursor.rewind()

        # Fast forward the cursor past old events
        self._offset = self._cursor.count()
        self._cursor = self._cursor.skip(self._offset)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next pending document.

        :raises StopIteration: propagated from the underlying cursor when
            nothing is pending.
        :raises pymongo.errors.OperationFailure: for any failure other
            than an invalidated cursor.
        """
        while True:
            try:
                msg = next(self._cursor)
            except pymongo.errors.OperationFailure as exc:
                # In some cases tailed cursor can become invalid
                # and have to be reinitialized.  ``str(exc)`` is used
                # because exceptions have no ``.message`` attribute on
                # Python 3.
                if 'not valid at server' in str(exc):
                    self.purge()

                    continue

                raise
            else:
                break

        self._offset += 1

        return msg
    next = __next__  # Python 2 iterator protocol
-
-
class Channel(virtual.Channel):
    """Virtual channel that keeps its state in MongoDB collections.

    Regular queue messages are stored in the ``messages`` collection (one
    document per message, tagged with the queue name), queue bindings in
    ``messages.routing`` and fanout messages in the capped
    ``messages.broadcast`` collection, which is read through
    :class:`BroadcastCursor` objects.
    """

    # Database handle; populated by ``_open()``.
    _client = None
    supports_fanout = True
    # Maps queue name -> fanout exchange name.  This is a class-level
    # mutable attribute, so it is shared by all Channel instances (see
    # ``get_broadcast_cursor``).
    _fanout_queues = {}

    def __init__(self, *vargs, **kwargs):
        super(Channel, self).__init__(*vargs, **kwargs)

        # Per-instance cache: queue name -> BroadcastCursor.
        self._broadcast_cursors = {}

        # Evaluate connection
        self._create_client()

    def _new_queue(self, queue, **kwargs):
        """Do nothing; no work is performed when a queue is declared."""
        pass

    def _get(self, queue):
        """Pop and return the next decoded message from ``queue``.

        Fanout queues read from their broadcast cursor; other queues
        atomically find-and-remove the oldest document (ascending ``_id``).

        :raises Empty: when no message is available.
        """
        if queue in self._fanout_queues:
            try:
                msg = next(self.get_broadcast_cursor(queue))
            except StopIteration:
                msg = None
        else:
            msg = self.get_messages().find_and_modify(
                query={'queue': queue},
                sort={'_id': pymongo.ASCENDING},
                remove=True,
            )

        if msg is None:
            raise Empty()

        return loads(bytes_to_str(msg['payload']))

    def _size(self, queue):
        """Return the number of messages waiting in ``queue``."""
        if queue in self._fanout_queues:
            return self.get_broadcast_cursor(queue).get_size()

        return self.get_messages().find({'queue': queue}).count()

    def _put(self, queue, message, **kwargs):
        """Serialize ``message`` and store it for ``queue``."""
        self.get_messages().insert({'payload': dumps(message),
                                    'queue': queue})

    def _purge(self, queue):
        """Drop all pending messages of ``queue``.

        :returns: the queue size measured just before purging.
        """
        size = self._size(queue)

        if queue in self._fanout_queues:
            # Fixed: this used to call the misspelled
            # ``get_broadcaset_cursor`` and raised AttributeError.
            self.get_broadcast_cursor(queue).purge()
        else:
            self.get_messages().remove({'queue': queue})

        return size

    def _parse_uri(self, scheme='mongodb://'):
        """Derive connection settings from the kombu connection.

        :param scheme: URI scheme prefix that the hostname must carry.
        :returns: ``(hostname, dbname, options)`` where ``hostname`` is the
            full MongoDB URI, ``dbname`` the database to use and
            ``options`` the keyword options for the client.
        """
        # See mongodb uri documentation:
        # http://docs.mongodb.org/manual/reference/connection-string/
        client = self.connection.client
        hostname = client.hostname

        if not hostname.startswith(scheme):
            hostname = scheme + hostname

        # Nothing after the scheme: fall back to the default host.
        if not hostname[len(scheme):]:
            hostname += DEFAULT_HOST

        # Inject credentials only when the URI does not already carry them.
        if client.userid and '@' not in hostname:
            head, tail = hostname.split('://')

            credentials = client.userid
            if client.password:
                credentials += ':' + client.password

            hostname = head + '://' + credentials + '@' + tail

        port = client.port if client.port is not None else DEFAULT_PORT

        parsed = uri_parser.parse_uri(hostname, port)

        # A database named in the URI wins over the virtual host.
        dbname = parsed['database'] or client.virtual_host

        if dbname in ('/', None):
            dbname = 'kombu_default'

        options = {
            'auto_start_request': True,
            'ssl': client.ssl,
            'connectTimeoutMS': (int(client.connect_timeout * 1000)
                                 if client.connect_timeout else None),
        }
        # Later updates win: transport options override the defaults above
        # and options parsed from the URI override both.
        options.update(client.transport_options)
        options.update(parsed['options'])

        return hostname, dbname, options

    def _prepare_client_options(self, options):
        """Adapt ``options`` to the installed pymongo version.

        ``auto_start_request`` is removed when running on pymongo 3 or
        later.  The dict is modified in place and returned.
        """
        if pymongo.version_tuple >= (3, ):
            options.pop('auto_start_request', None)
        return options

    def _open(self, scheme='mongodb://'):
        """Connect to the server and store the database handle.

        :raises NotImplementedError: if the server reports a version
            older than 1.3.
        """
        hostname, dbname, options = self._parse_uri(scheme=scheme)

        # NOTE: ``conf`` is the same dict object as ``options``.
        conf = self._prepare_client_options(options)
        conf['host'] = hostname

        # Apply the monkey patches of the detected green-thread
        # environment before the client is created.
        env = _detect_environment()
        if env == 'gevent':
            from gevent import monkey
            monkey.patch_all()
        elif env == 'eventlet':
            from eventlet import monkey_patch
            monkey_patch()

        mongoconn = MongoClient(**conf)
        database = mongoconn[dbname]

        version = mongoconn.server_info()['version']
        if tuple(map(int, version.split('.')[:2])) < (1, 3):
            raise NotImplementedError(
                'Kombu requires MongoDB version 1.3+ (server is {0})'.format(
                    version))

        self._create_broadcast(database, options)

        self._client = database

    def _create_broadcast(self, database, options):
        """Create capped collection for broadcast messages.

        Does nothing if the collection already exists.  The size comes
        from the ``capped_queue_size`` option and defaults to 100000.
        """
        if DEFAULT_BROADCAST_COLLECTION in database.collection_names():
            return

        capsize = options.get('capped_queue_size') or 100000
        database.create_collection(DEFAULT_BROADCAST_COLLECTION,
                                   size=capsize, capped=True)

    def _ensure_indexes(self):
        """Ensure indexes on collections."""
        self.get_messages().ensure_index(
            [('queue', 1), ('_id', 1)], background=True,
        )
        self.get_broadcast().ensure_index([('queue', 1)])
        self.get_routing().ensure_index([('queue', 1), ('exchange', 1)])

    # TODO Store a more complete exchange metatable in the routing collection
    def get_table(self, exchange):
        """Get table of bindings for ``exchange``.

        :returns: a frozenset combining the bindings known to this
            process with the ``(routing_key, pattern, queue)`` tuples
            stored in the routing collection.
        """
        local_routes = frozenset(self.state.exchanges[exchange]['table'])
        # Uses the same accessor as the other methods; this is the same
        # ``messages.routing`` collection that was previously reached
        # through ``get_messages().routing``.
        broker_routes = self.get_routing().find({'exchange': exchange})

        return local_routes | frozenset(
            (r['routing_key'], r['pattern'], r['queue'])
            for r in broker_routes
        )

    def _put_fanout(self, exchange, message, routing_key, **kwargs):
        """Deliver fanout message."""
        self.get_broadcast().insert({'payload': dumps(message),
                                     'queue': exchange})

    def _queue_bind(self, exchange, routing_key, pattern, queue):
        """Bind ``queue`` to ``exchange`` and persist the binding.

        For fanout exchanges a broadcast cursor is created and the queue
        is registered in ``_fanout_queues``.
        """
        if self.typeof(exchange).type == 'fanout':
            self.create_broadcast_cursor(exchange, routing_key, pattern, queue)
            self._fanout_queues[queue] = exchange

        meta = {'exchange': exchange,
                'queue': queue,
                'routing_key': routing_key,
                'pattern': pattern}
        # Upsert keyed on the full binding so re-binding adds no duplicate.
        self.get_routing().update(meta, meta, upsert=True)

    def queue_delete(self, queue, **kwargs):
        """Delete ``queue``, its stored bindings and any fanout state."""
        self.get_routing().remove({'queue': queue})

        super(Channel, self).queue_delete(queue, **kwargs)

        if queue in self._fanout_queues:
            # This instance may have no cursor for the queue, because
            # ``_fanout_queues`` is shared between instances.
            try:
                cursor = self._broadcast_cursors.pop(queue)
            except KeyError:
                pass
            else:
                cursor.close()

            self._fanout_queues.pop(queue)

    def _create_client(self):
        """Open the connection and make sure the indexes exist."""
        self._open()
        self._ensure_indexes()

    @property
    def client(self):
        """Database handle, connecting first if not yet connected."""
        if self._client is None:
            self._create_client()
        return self._client

    def get_messages(self):
        """Return the collection holding regular queue messages."""
        return self.client[DEFAULT_MESSAGES_COLLECTION]

    def get_routing(self):
        """Return the collection holding queue bindings."""
        return self.client[DEFAULT_ROUTING_COLLECTION]

    def get_broadcast(self):
        """Return the capped collection holding fanout messages."""
        return self.client[DEFAULT_BROADCAST_COLLECTION]

    def get_broadcast_cursor(self, queue):
        """Return the broadcast cursor for ``queue``, creating it if needed."""
        try:
            return self._broadcast_cursors[queue]
        except KeyError:
            # Cursor may be absent when Channel created more than once.
            # _fanout_queues is a class-level mutable attribute so it's
            # shared over all Channel instances.
            return self.create_broadcast_cursor(
                self._fanout_queues[queue], None, None, queue,
            )

    def create_broadcast_cursor(self, exchange, routing_key, pattern, queue):
        """Create, cache and return a broadcast cursor for ``queue``.

        The cursor is a tailable cursor in natural order over the
        broadcast documents of ``exchange``.  ``routing_key`` and
        ``pattern`` are accepted but not used here.
        """
        # The keyword arguments of ``find`` differ between pymongo
        # versions, so build them according to the installed one.
        if pymongo.version_tuple >= (3, ):
            query = dict(filter={'queue': exchange},
                         sort=[('$natural', 1)],
                         cursor_type=CursorType.TAILABLE)
        else:
            query = dict(query={'queue': exchange},
                         sort=[('$natural', 1)],
                         tailable=True)

        cursor = self.get_broadcast().find(**query)
        ret = self._broadcast_cursors[queue] = BroadcastCursor(cursor)
        return ret
-
-
class Transport(virtual.Transport):
    """MongoDB transport built on the virtual transport base class."""

    Channel = Channel

    # Driver identification.
    driver_type = 'mongodb'
    driver_name = 'pymongo'

    # Connection behaviour.
    default_port = DEFAULT_PORT
    polling_interval = 1
    can_parse_url = True

    # Extend the error tuples inherited from the virtual transport with
    # the corresponding pymongo exceptions.
    connection_errors = virtual.Transport.connection_errors + (
        errors.ConnectionFailure,
    )
    channel_errors = virtual.Transport.channel_errors + (
        errors.ConnectionFailure,
        errors.OperationFailure,
    )

    def driver_version(self):
        """Return the version string of the installed pymongo package."""
        return pymongo.version
|