Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

mongodb.py 9.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. """
  2. kombu.transport.mongodb
  3. =======================
  4. MongoDB transport.
  5. :copyright: (c) 2010 - 2013 by Flavio Percoco Premoli.
  6. :license: BSD, see LICENSE for more details.
  7. """
  8. from __future__ import absolute_import
  9. import pymongo
  10. from pymongo import errors
  11. from anyjson import loads, dumps
  12. from pymongo import MongoClient, uri_parser
  13. from kombu.five import Empty
  14. from kombu.syn import _detect_environment
  15. from kombu.utils.encoding import bytes_to_str
  16. from . import virtual
  17. try:
  18. from pymongo.cursor import CursorType
  19. except ImportError:
  20. class CursorType(object): # noqa
  21. pass
  22. DEFAULT_HOST = '127.0.0.1'
  23. DEFAULT_PORT = 27017
  24. DEFAULT_MESSAGES_COLLECTION = 'messages'
  25. DEFAULT_ROUTING_COLLECTION = 'messages.routing'
  26. DEFAULT_BROADCAST_COLLECTION = 'messages.broadcast'
  27. class BroadcastCursor(object):
  28. """Cursor for broadcast queues."""
  29. def __init__(self, cursor):
  30. self._cursor = cursor
  31. self.purge(rewind=False)
  32. def get_size(self):
  33. return self._cursor.count() - self._offset
  34. def close(self):
  35. self._cursor.close()
  36. def purge(self, rewind=True):
  37. if rewind:
  38. self._cursor.rewind()
  39. # Fast forward the cursor past old events
  40. self._offset = self._cursor.count()
  41. self._cursor = self._cursor.skip(self._offset)
  42. def __iter__(self):
  43. return self
  44. def __next__(self):
  45. while True:
  46. try:
  47. msg = next(self._cursor)
  48. except pymongo.errors.OperationFailure as exc:
  49. # In some cases tailed cursor can become invalid
  50. # and have to be reinitalized
  51. if 'not valid at server' in exc.message:
  52. self.purge()
  53. continue
  54. raise
  55. else:
  56. break
  57. self._offset += 1
  58. return msg
  59. next = __next__
  60. class Channel(virtual.Channel):
  61. _client = None
  62. supports_fanout = True
  63. _fanout_queues = {}
  64. def __init__(self, *vargs, **kwargs):
  65. super(Channel, self).__init__(*vargs, **kwargs)
  66. self._broadcast_cursors = {}
  67. # Evaluate connection
  68. self._create_client()
  69. def _new_queue(self, queue, **kwargs):
  70. pass
  71. def _get(self, queue):
  72. if queue in self._fanout_queues:
  73. try:
  74. msg = next(self.get_broadcast_cursor(queue))
  75. except StopIteration:
  76. msg = None
  77. else:
  78. msg = self.get_messages().find_and_modify(
  79. query={'queue': queue},
  80. sort={'_id': pymongo.ASCENDING},
  81. remove=True,
  82. )
  83. if msg is None:
  84. raise Empty()
  85. return loads(bytes_to_str(msg['payload']))
  86. def _size(self, queue):
  87. if queue in self._fanout_queues:
  88. return self.get_broadcast_cursor(queue).get_size()
  89. return self.get_messages().find({'queue': queue}).count()
  90. def _put(self, queue, message, **kwargs):
  91. self.get_messages().insert({'payload': dumps(message),
  92. 'queue': queue})
  93. def _purge(self, queue):
  94. size = self._size(queue)
  95. if queue in self._fanout_queues:
  96. self.get_broadcaset_cursor(queue).purge()
  97. else:
  98. self.get_messages().remove({'queue': queue})
  99. return size
  100. def _parse_uri(self, scheme='mongodb://'):
  101. # See mongodb uri documentation:
  102. # http://docs.mongodb.org/manual/reference/connection-string/
  103. client = self.connection.client
  104. hostname = client.hostname
  105. if not hostname.startswith(scheme):
  106. hostname = scheme + hostname
  107. if not hostname[len(scheme):]:
  108. hostname += DEFAULT_HOST
  109. if client.userid and '@' not in hostname:
  110. head, tail = hostname.split('://')
  111. credentials = client.userid
  112. if client.password:
  113. credentials += ':' + client.password
  114. hostname = head + '://' + credentials + '@' + tail
  115. port = client.port if client.port is not None else DEFAULT_PORT
  116. parsed = uri_parser.parse_uri(hostname, port)
  117. dbname = parsed['database'] or client.virtual_host
  118. if dbname in ('/', None):
  119. dbname = 'kombu_default'
  120. options = {
  121. 'auto_start_request': True,
  122. 'ssl': client.ssl,
  123. 'connectTimeoutMS': (int(client.connect_timeout * 1000)
  124. if client.connect_timeout else None),
  125. }
  126. options.update(client.transport_options)
  127. options.update(parsed['options'])
  128. return hostname, dbname, options
  129. def _prepare_client_options(self, options):
  130. if pymongo.version_tuple >= (3, ):
  131. options.pop('auto_start_request', None)
  132. return options
  133. def _open(self, scheme='mongodb://'):
  134. hostname, dbname, options = self._parse_uri(scheme=scheme)
  135. conf = self._prepare_client_options(options)
  136. conf['host'] = hostname
  137. env = _detect_environment()
  138. if env == 'gevent':
  139. from gevent import monkey
  140. monkey.patch_all()
  141. elif env == 'eventlet':
  142. from eventlet import monkey_patch
  143. monkey_patch()
  144. mongoconn = MongoClient(**conf)
  145. database = mongoconn[dbname]
  146. version = mongoconn.server_info()['version']
  147. if tuple(map(int, version.split('.')[:2])) < (1, 3):
  148. raise NotImplementedError(
  149. 'Kombu requires MongoDB version 1.3+ (server is {0})'.format(
  150. version))
  151. self._create_broadcast(database, options)
  152. self._client = database
  153. def _create_broadcast(self, database, options):
  154. '''Create capped collection for broadcast messages.'''
  155. if DEFAULT_BROADCAST_COLLECTION in database.collection_names():
  156. return
  157. capsize = options.get('capped_queue_size') or 100000
  158. database.create_collection(DEFAULT_BROADCAST_COLLECTION,
  159. size=capsize, capped=True)
  160. def _ensure_indexes(self):
  161. '''Ensure indexes on collections.'''
  162. self.get_messages().ensure_index(
  163. [('queue', 1), ('_id', 1)], background=True,
  164. )
  165. self.get_broadcast().ensure_index([('queue', 1)])
  166. self.get_routing().ensure_index([('queue', 1), ('exchange', 1)])
  167. # TODO Store a more complete exchange metatable in the routing collection
  168. def get_table(self, exchange):
  169. """Get table of bindings for ``exchange``."""
  170. localRoutes = frozenset(self.state.exchanges[exchange]['table'])
  171. brokerRoutes = self.get_messages().routing.find(
  172. {'exchange': exchange}
  173. )
  174. return localRoutes | frozenset((r['routing_key'],
  175. r['pattern'],
  176. r['queue']) for r in brokerRoutes)
  177. def _put_fanout(self, exchange, message, routing_key, **kwargs):
  178. """Deliver fanout message."""
  179. self.get_broadcast().insert({'payload': dumps(message),
  180. 'queue': exchange})
  181. def _queue_bind(self, exchange, routing_key, pattern, queue):
  182. if self.typeof(exchange).type == 'fanout':
  183. self.create_broadcast_cursor(exchange, routing_key, pattern, queue)
  184. self._fanout_queues[queue] = exchange
  185. meta = {'exchange': exchange,
  186. 'queue': queue,
  187. 'routing_key': routing_key,
  188. 'pattern': pattern}
  189. self.get_routing().update(meta, meta, upsert=True)
  190. def queue_delete(self, queue, **kwargs):
  191. self.get_routing().remove({'queue': queue})
  192. super(Channel, self).queue_delete(queue, **kwargs)
  193. if queue in self._fanout_queues:
  194. try:
  195. cursor = self._broadcast_cursors.pop(queue)
  196. except KeyError:
  197. pass
  198. else:
  199. cursor.close()
  200. self._fanout_queues.pop(queue)
  201. def _create_client(self):
  202. self._open()
  203. self._ensure_indexes()
  204. @property
  205. def client(self):
  206. if self._client is None:
  207. self._create_client()
  208. return self._client
  209. def get_messages(self):
  210. return self.client[DEFAULT_MESSAGES_COLLECTION]
  211. def get_routing(self):
  212. return self.client[DEFAULT_ROUTING_COLLECTION]
  213. def get_broadcast(self):
  214. return self.client[DEFAULT_BROADCAST_COLLECTION]
  215. def get_broadcast_cursor(self, queue):
  216. try:
  217. return self._broadcast_cursors[queue]
  218. except KeyError:
  219. # Cursor may be absent when Channel created more than once.
  220. # _fanout_queues is a class-level mutable attribute so it's
  221. # shared over all Channel instances.
  222. return self.create_broadcast_cursor(
  223. self._fanout_queues[queue], None, None, queue,
  224. )
  225. def create_broadcast_cursor(self, exchange, routing_key, pattern, queue):
  226. if pymongo.version_tuple >= (3, ):
  227. query = dict(filter={'queue': exchange},
  228. sort=[('$natural', 1)],
  229. cursor_type=CursorType.TAILABLE
  230. )
  231. else:
  232. query = dict(query={'queue': exchange},
  233. sort=[('$natural', 1)],
  234. tailable=True
  235. )
  236. cursor = self.get_broadcast().find(**query)
  237. ret = self._broadcast_cursors[queue] = BroadcastCursor(cursor)
  238. return ret
  239. class Transport(virtual.Transport):
  240. Channel = Channel
  241. can_parse_url = True
  242. polling_interval = 1
  243. default_port = DEFAULT_PORT
  244. connection_errors = (
  245. virtual.Transport.connection_errors + (errors.ConnectionFailure, )
  246. )
  247. channel_errors = (
  248. virtual.Transport.channel_errors + (
  249. errors.ConnectionFailure,
  250. errors.OperationFailure)
  251. )
  252. driver_type = 'mongodb'
  253. driver_name = 'pymongo'
  254. def driver_version(self):
  255. return pymongo.version