Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_common.py 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. from __future__ import absolute_import
  2. import socket
  3. from amqp import RecoverableConnectionError
  4. from kombu import common
  5. from kombu.common import (
  6. Broadcast, maybe_declare,
  7. send_reply, collect_replies,
  8. declaration_cached, ignore_errors,
  9. QoS, PREFETCH_COUNT_MAX,
  10. )
  11. from .case import Case, ContextMock, Mock, MockPool, patch
  12. class test_ignore_errors(Case):
  13. def test_ignored(self):
  14. connection = Mock()
  15. connection.channel_errors = (KeyError, )
  16. connection.connection_errors = (KeyError, )
  17. with ignore_errors(connection):
  18. raise KeyError()
  19. def raising():
  20. raise KeyError()
  21. ignore_errors(connection, raising)
  22. connection.channel_errors = connection.connection_errors = \
  23. ()
  24. with self.assertRaises(KeyError):
  25. with ignore_errors(connection):
  26. raise KeyError()
  27. class test_declaration_cached(Case):
  28. def test_when_cached(self):
  29. chan = Mock()
  30. chan.connection.client.declared_entities = ['foo']
  31. self.assertTrue(declaration_cached('foo', chan))
  32. def test_when_not_cached(self):
  33. chan = Mock()
  34. chan.connection.client.declared_entities = ['bar']
  35. self.assertFalse(declaration_cached('foo', chan))
  36. class test_Broadcast(Case):
  37. def test_arguments(self):
  38. q = Broadcast(name='test_Broadcast')
  39. self.assertTrue(q.name.startswith('bcast.'))
  40. self.assertEqual(q.alias, 'test_Broadcast')
  41. self.assertTrue(q.auto_delete)
  42. self.assertEqual(q.exchange.name, 'test_Broadcast')
  43. self.assertEqual(q.exchange.type, 'fanout')
  44. q = Broadcast('test_Broadcast', 'explicit_queue_name')
  45. self.assertEqual(q.name, 'explicit_queue_name')
  46. self.assertEqual(q.exchange.name, 'test_Broadcast')
  47. q2 = q(Mock())
  48. self.assertEqual(q2.name, q.name)
  49. class test_maybe_declare(Case):
  50. def test_cacheable(self):
  51. channel = Mock()
  52. client = channel.connection.client = Mock()
  53. client.declared_entities = set()
  54. entity = Mock()
  55. entity.can_cache_declaration = True
  56. entity.auto_delete = False
  57. entity.is_bound = True
  58. entity.channel = channel
  59. maybe_declare(entity, channel)
  60. self.assertEqual(entity.declare.call_count, 1)
  61. self.assertIn(
  62. hash(entity), channel.connection.client.declared_entities,
  63. )
  64. maybe_declare(entity, channel)
  65. self.assertEqual(entity.declare.call_count, 1)
  66. entity.channel.connection = None
  67. with self.assertRaises(RecoverableConnectionError):
  68. maybe_declare(entity)
  69. def test_binds_entities(self):
  70. channel = Mock()
  71. channel.connection.client.declared_entities = set()
  72. entity = Mock()
  73. entity.can_cache_declaration = True
  74. entity.is_bound = False
  75. entity.bind.return_value = entity
  76. entity.bind.return_value.channel = channel
  77. maybe_declare(entity, channel)
  78. entity.bind.assert_called_with(channel)
  79. def test_with_retry(self):
  80. channel = Mock()
  81. client = channel.connection.client = Mock()
  82. client.declared_entities = set()
  83. entity = Mock()
  84. entity.can_cache_declaration = True
  85. entity.is_bound = True
  86. entity.channel = channel
  87. maybe_declare(entity, channel, retry=True)
  88. self.assertTrue(channel.connection.client.ensure.call_count)
  89. class test_replies(Case):
  90. def test_send_reply(self):
  91. req = Mock()
  92. req.content_type = 'application/json'
  93. req.content_encoding = 'binary'
  94. req.properties = {'reply_to': 'hello',
  95. 'correlation_id': 'world'}
  96. channel = Mock()
  97. exchange = Mock()
  98. exchange.is_bound = True
  99. exchange.channel = channel
  100. producer = Mock()
  101. producer.channel = channel
  102. producer.channel.connection.client.declared_entities = set()
  103. send_reply(exchange, req, {'hello': 'world'}, producer)
  104. self.assertTrue(producer.publish.call_count)
  105. args = producer.publish.call_args
  106. self.assertDictEqual(args[0][0], {'hello': 'world'})
  107. self.assertDictEqual(args[1], {'exchange': exchange,
  108. 'routing_key': 'hello',
  109. 'correlation_id': 'world',
  110. 'serializer': 'json',
  111. 'retry': False,
  112. 'retry_policy': None,
  113. 'content_encoding': 'binary'})
  114. @patch('kombu.common.itermessages')
  115. def test_collect_replies_with_ack(self, itermessages):
  116. conn, channel, queue = Mock(), Mock(), Mock()
  117. body, message = Mock(), Mock()
  118. itermessages.return_value = [(body, message)]
  119. it = collect_replies(conn, channel, queue, no_ack=False)
  120. m = next(it)
  121. self.assertIs(m, body)
  122. itermessages.assert_called_with(conn, channel, queue, no_ack=False)
  123. message.ack.assert_called_with()
  124. with self.assertRaises(StopIteration):
  125. next(it)
  126. channel.after_reply_message_received.assert_called_with(queue.name)
  127. @patch('kombu.common.itermessages')
  128. def test_collect_replies_no_ack(self, itermessages):
  129. conn, channel, queue = Mock(), Mock(), Mock()
  130. body, message = Mock(), Mock()
  131. itermessages.return_value = [(body, message)]
  132. it = collect_replies(conn, channel, queue)
  133. m = next(it)
  134. self.assertIs(m, body)
  135. itermessages.assert_called_with(conn, channel, queue, no_ack=True)
  136. self.assertFalse(message.ack.called)
  137. @patch('kombu.common.itermessages')
  138. def test_collect_replies_no_replies(self, itermessages):
  139. conn, channel, queue = Mock(), Mock(), Mock()
  140. itermessages.return_value = []
  141. it = collect_replies(conn, channel, queue)
  142. with self.assertRaises(StopIteration):
  143. next(it)
  144. self.assertFalse(channel.after_reply_message_received.called)
  145. class test_insured(Case):
  146. @patch('kombu.common.logger')
  147. def test_ensure_errback(self, logger):
  148. common._ensure_errback('foo', 30)
  149. self.assertTrue(logger.error.called)
  150. def test_revive_connection(self):
  151. on_revive = Mock()
  152. channel = Mock()
  153. common.revive_connection(Mock(), channel, on_revive)
  154. on_revive.assert_called_with(channel)
  155. common.revive_connection(Mock(), channel, None)
  156. def get_insured_mocks(self, insured_returns=('works', 'ignored')):
  157. conn = ContextMock()
  158. pool = MockPool(conn)
  159. fun = Mock()
  160. insured = conn.autoretry.return_value = Mock()
  161. insured.return_value = insured_returns
  162. return conn, pool, fun, insured
  163. def test_insured(self):
  164. conn, pool, fun, insured = self.get_insured_mocks()
  165. ret = common.insured(pool, fun, (2, 2), {'foo': 'bar'})
  166. self.assertEqual(ret, 'works')
  167. conn.ensure_connection.assert_called_with(
  168. errback=common._ensure_errback,
  169. )
  170. self.assertTrue(insured.called)
  171. i_args, i_kwargs = insured.call_args
  172. self.assertTupleEqual(i_args, (2, 2))
  173. self.assertDictEqual(i_kwargs, {'foo': 'bar',
  174. 'connection': conn})
  175. self.assertTrue(conn.autoretry.called)
  176. ar_args, ar_kwargs = conn.autoretry.call_args
  177. self.assertTupleEqual(ar_args, (fun, conn.default_channel))
  178. self.assertTrue(ar_kwargs.get('on_revive'))
  179. self.assertTrue(ar_kwargs.get('errback'))
  180. def test_insured_custom_errback(self):
  181. conn, pool, fun, insured = self.get_insured_mocks()
  182. custom_errback = Mock()
  183. common.insured(pool, fun, (2, 2), {'foo': 'bar'},
  184. errback=custom_errback)
  185. conn.ensure_connection.assert_called_with(errback=custom_errback)
  186. class MockConsumer(object):
  187. consumers = set()
  188. def __init__(self, channel, queues=None, callbacks=None, **kwargs):
  189. self.channel = channel
  190. self.queues = queues
  191. self.callbacks = callbacks
  192. def __enter__(self):
  193. self.consumers.add(self)
  194. return self
  195. def __exit__(self, *exc_info):
  196. self.consumers.discard(self)
  197. class test_itermessages(Case):
  198. class MockConnection(object):
  199. should_raise_timeout = False
  200. def drain_events(self, **kwargs):
  201. if self.should_raise_timeout:
  202. raise socket.timeout()
  203. for consumer in MockConsumer.consumers:
  204. for callback in consumer.callbacks:
  205. callback('body', 'message')
  206. def test_default(self):
  207. conn = self.MockConnection()
  208. channel = Mock()
  209. channel.connection.client = conn
  210. conn.Consumer = MockConsumer
  211. it = common.itermessages(conn, channel, 'q', limit=1)
  212. ret = next(it)
  213. self.assertTupleEqual(ret, ('body', 'message'))
  214. with self.assertRaises(StopIteration):
  215. next(it)
  216. def test_when_raises_socket_timeout(self):
  217. conn = self.MockConnection()
  218. conn.should_raise_timeout = True
  219. channel = Mock()
  220. channel.connection.client = conn
  221. conn.Consumer = MockConsumer
  222. it = common.itermessages(conn, channel, 'q', limit=1)
  223. with self.assertRaises(StopIteration):
  224. next(it)
  225. @patch('kombu.common.deque')
  226. def test_when_raises_IndexError(self, deque):
  227. deque_instance = deque.return_value = Mock()
  228. deque_instance.popleft.side_effect = IndexError()
  229. conn = self.MockConnection()
  230. channel = Mock()
  231. conn.Consumer = MockConsumer
  232. it = common.itermessages(conn, channel, 'q', limit=1)
  233. with self.assertRaises(StopIteration):
  234. next(it)
  235. class test_QoS(Case):
  236. class _QoS(QoS):
  237. def __init__(self, value):
  238. self.value = value
  239. QoS.__init__(self, None, value)
  240. def set(self, value):
  241. return value
  242. def test_qos_exceeds_16bit(self):
  243. with patch('kombu.common.logger') as logger:
  244. callback = Mock()
  245. qos = QoS(callback, 10)
  246. qos.prev = 100
  247. # cannot use 2 ** 32 because of a bug on OSX Py2.5:
  248. # https://jira.mongodb.org/browse/PYTHON-389
  249. qos.set(4294967296)
  250. self.assertTrue(logger.warn.called)
  251. callback.assert_called_with(prefetch_count=0)
  252. def test_qos_increment_decrement(self):
  253. qos = self._QoS(10)
  254. self.assertEqual(qos.increment_eventually(), 11)
  255. self.assertEqual(qos.increment_eventually(3), 14)
  256. self.assertEqual(qos.increment_eventually(-30), 14)
  257. self.assertEqual(qos.decrement_eventually(7), 7)
  258. self.assertEqual(qos.decrement_eventually(), 6)
  259. def test_qos_disabled_increment_decrement(self):
  260. qos = self._QoS(0)
  261. self.assertEqual(qos.increment_eventually(), 0)
  262. self.assertEqual(qos.increment_eventually(3), 0)
  263. self.assertEqual(qos.increment_eventually(-30), 0)
  264. self.assertEqual(qos.decrement_eventually(7), 0)
  265. self.assertEqual(qos.decrement_eventually(), 0)
  266. self.assertEqual(qos.decrement_eventually(10), 0)
  267. def test_qos_thread_safe(self):
  268. qos = self._QoS(10)
  269. def add():
  270. for i in range(1000):
  271. qos.increment_eventually()
  272. def sub():
  273. for i in range(1000):
  274. qos.decrement_eventually()
  275. def threaded(funs):
  276. from threading import Thread
  277. threads = [Thread(target=fun) for fun in funs]
  278. for thread in threads:
  279. thread.start()
  280. for thread in threads:
  281. thread.join()
  282. threaded([add, add])
  283. self.assertEqual(qos.value, 2010)
  284. qos.value = 1000
  285. threaded([add, sub]) # n = 2
  286. self.assertEqual(qos.value, 1000)
  287. def test_exceeds_short(self):
  288. qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1)
  289. qos.update()
  290. self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
  291. qos.increment_eventually()
  292. self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
  293. qos.increment_eventually()
  294. self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1)
  295. qos.decrement_eventually()
  296. self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
  297. qos.decrement_eventually()
  298. self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
  299. def test_consumer_increment_decrement(self):
  300. mconsumer = Mock()
  301. qos = QoS(mconsumer.qos, 10)
  302. qos.update()
  303. self.assertEqual(qos.value, 10)
  304. mconsumer.qos.assert_called_with(prefetch_count=10)
  305. qos.decrement_eventually()
  306. qos.update()
  307. self.assertEqual(qos.value, 9)
  308. mconsumer.qos.assert_called_with(prefetch_count=9)
  309. qos.decrement_eventually()
  310. self.assertEqual(qos.value, 8)
  311. mconsumer.qos.assert_called_with(prefetch_count=9)
  312. self.assertIn({'prefetch_count': 9}, mconsumer.qos.call_args)
  313. # Does not decrement 0 value
  314. qos.value = 0
  315. qos.decrement_eventually()
  316. self.assertEqual(qos.value, 0)
  317. qos.increment_eventually()
  318. self.assertEqual(qos.value, 0)
  319. def test_consumer_decrement_eventually(self):
  320. mconsumer = Mock()
  321. qos = QoS(mconsumer.qos, 10)
  322. qos.decrement_eventually()
  323. self.assertEqual(qos.value, 9)
  324. qos.value = 0
  325. qos.decrement_eventually()
  326. self.assertEqual(qos.value, 0)
  327. def test_set(self):
  328. mconsumer = Mock()
  329. qos = QoS(mconsumer.qos, 10)
  330. qos.set(12)
  331. self.assertEqual(qos.prev, 12)
  332. qos.set(qos.prev)