Development of an internal social media platform with personalised dashboards for students
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

consumer.py 29KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887
  1. # -*- coding: utf-8 -*-
  2. """
  3. celery.worker.consumer
  4. ~~~~~~~~~~~~~~~~~~~~~~
  5. This module contains the components responsible for consuming messages
  6. from the broker, processing the messages and keeping the broker connections
  7. up and running.
  8. """
  9. from __future__ import absolute_import
  10. import errno
  11. import kombu
  12. import logging
  13. import os
  14. import socket
  15. from collections import defaultdict
  16. from functools import partial
  17. from heapq import heappush
  18. from operator import itemgetter
  19. from time import sleep
  20. from billiard.common import restart_state
  21. from billiard.exceptions import RestartFreqExceeded
  22. from kombu.async.semaphore import DummyLock
  23. from kombu.common import QoS, ignore_errors
  24. from kombu.syn import _detect_environment
  25. from kombu.utils.compat import get_errno
  26. from kombu.utils.encoding import safe_repr, bytes_t
  27. from kombu.utils.limits import TokenBucket
  28. from celery import chain
  29. from celery import bootsteps
  30. from celery.app.trace import build_tracer
  31. from celery.canvas import signature
  32. from celery.exceptions import InvalidTaskError
  33. from celery.five import items, values
  34. from celery.utils.functional import noop
  35. from celery.utils.log import get_logger
  36. from celery.utils.objects import Bunch
  37. from celery.utils.text import truncate
  38. from celery.utils.timeutils import humanize_seconds, rate
  39. from . import heartbeat, loops, pidbox
  40. from .state import task_reserved, maybe_shutdown, revoked, reserved_requests
  41. try:
  42. buffer_t = buffer
  43. except NameError: # pragma: no cover
  44. # Py3 does not have buffer, but we only need isinstance.
  45. class buffer_t(object): # noqa
  46. pass
  47. __all__ = [
  48. 'Consumer', 'Connection', 'Events', 'Heart', 'Control',
  49. 'Tasks', 'Evloop', 'Agent', 'Mingle', 'Gossip', 'dump_body',
  50. ]
  51. CLOSE = bootsteps.CLOSE
  52. logger = get_logger(__name__)
  53. debug, info, warn, error, crit = (logger.debug, logger.info, logger.warning,
  54. logger.error, logger.critical)
  55. CONNECTION_RETRY = """\
  56. consumer: Connection to broker lost. \
  57. Trying to re-establish the connection...\
  58. """
  59. CONNECTION_RETRY_STEP = """\
  60. Trying again {when}...\
  61. """
  62. CONNECTION_ERROR = """\
  63. consumer: Cannot connect to %s: %s.
  64. %s
  65. """
  66. CONNECTION_FAILOVER = """\
  67. Will retry using next failover.\
  68. """
  69. UNKNOWN_FORMAT = """\
  70. Received and deleted unknown message. Wrong destination?!?
  71. The full contents of the message body was: %s
  72. """
  73. #: Error message for when an unregistered task is received.
  74. UNKNOWN_TASK_ERROR = """\
  75. Received unregistered task of type %s.
  76. The message has been ignored and discarded.
  77. Did you remember to import the module containing this task?
  78. Or maybe you are using relative imports?
  79. Please see http://bit.ly/gLye1c for more information.
  80. The full contents of the message body was:
  81. %s
  82. """
  83. #: Error message for when an invalid task message is received.
  84. INVALID_TASK_ERROR = """\
  85. Received invalid task message: %s
  86. The message has been ignored and discarded.
  87. Please ensure your message conforms to the task
  88. message protocol as described here: http://bit.ly/hYj41y
  89. The full contents of the message body was:
  90. %s
  91. """
  92. MESSAGE_DECODE_ERROR = """\
  93. Can't decode message body: %r [type:%r encoding:%r headers:%s]
  94. body: %s
  95. """
  96. MESSAGE_REPORT = """\
  97. body: {0}
  98. {{content_type:{1} content_encoding:{2}
  99. delivery_info:{3} headers={4}}}
  100. """
  101. MINGLE_GET_FIELDS = itemgetter('clock', 'revoked')
  102. def dump_body(m, body):
  103. if isinstance(body, buffer_t):
  104. body = bytes_t(body)
  105. return '{0} ({1}b)'.format(truncate(safe_repr(body), 1024),
  106. len(m.body))
  107. class Consumer(object):
  108. Strategies = dict
  109. #: set when consumer is shutting down.
  110. in_shutdown = False
  111. #: Optional callback called the first time the worker
  112. #: is ready to receive tasks.
  113. init_callback = None
  114. #: The current worker pool instance.
  115. pool = None
  116. #: A timer used for high-priority internal tasks, such
  117. #: as sending heartbeats.
  118. timer = None
  119. restart_count = -1 # first start is the same as a restart
  120. class Blueprint(bootsteps.Blueprint):
  121. name = 'Consumer'
  122. default_steps = [
  123. 'celery.worker.consumer:Connection',
  124. 'celery.worker.consumer:Mingle',
  125. 'celery.worker.consumer:Events',
  126. 'celery.worker.consumer:Gossip',
  127. 'celery.worker.consumer:Heart',
  128. 'celery.worker.consumer:Control',
  129. 'celery.worker.consumer:Tasks',
  130. 'celery.worker.consumer:Evloop',
  131. 'celery.worker.consumer:Agent',
  132. ]
  133. def shutdown(self, parent):
  134. self.send_all(parent, 'shutdown')
  135. def __init__(self, on_task_request,
  136. init_callback=noop, hostname=None,
  137. pool=None, app=None,
  138. timer=None, controller=None, hub=None, amqheartbeat=None,
  139. worker_options=None, disable_rate_limits=False,
  140. initial_prefetch_count=2, prefetch_multiplier=1, **kwargs):
  141. self.app = app
  142. self.controller = controller
  143. self.init_callback = init_callback
  144. self.hostname = hostname or socket.gethostname()
  145. self.pid = os.getpid()
  146. self.pool = pool
  147. self.timer = timer
  148. self.strategies = self.Strategies()
  149. conninfo = self.app.connection()
  150. self.connection_errors = conninfo.connection_errors
  151. self.channel_errors = conninfo.channel_errors
  152. self._restart_state = restart_state(maxR=5, maxT=1)
  153. self._does_info = logger.isEnabledFor(logging.INFO)
  154. self.on_task_request = on_task_request
  155. self.on_task_message = set()
  156. self.amqheartbeat_rate = self.app.conf.BROKER_HEARTBEAT_CHECKRATE
  157. self.disable_rate_limits = disable_rate_limits
  158. self.initial_prefetch_count = initial_prefetch_count
  159. self.prefetch_multiplier = prefetch_multiplier
  160. # this contains a tokenbucket for each task type by name, used for
  161. # rate limits, or None if rate limits are disabled for that task.
  162. self.task_buckets = defaultdict(lambda: None)
  163. self.reset_rate_limits()
  164. self.hub = hub
  165. if self.hub:
  166. self.amqheartbeat = amqheartbeat
  167. if self.amqheartbeat is None:
  168. self.amqheartbeat = self.app.conf.BROKER_HEARTBEAT
  169. else:
  170. self.amqheartbeat = 0
  171. if not hasattr(self, 'loop'):
  172. self.loop = loops.asynloop if hub else loops.synloop
  173. if _detect_environment() == 'gevent':
  174. # there's a gevent bug that causes timeouts to not be reset,
  175. # so if the connection timeout is exceeded once, it can NEVER
  176. # connect again.
  177. self.app.conf.BROKER_CONNECTION_TIMEOUT = None
  178. self.steps = []
  179. self.blueprint = self.Blueprint(
  180. app=self.app, on_close=self.on_close,
  181. )
  182. self.blueprint.apply(self, **dict(worker_options or {}, **kwargs))
  183. def bucket_for_task(self, type):
  184. limit = rate(getattr(type, 'rate_limit', None))
  185. return TokenBucket(limit, capacity=1) if limit else None
  186. def reset_rate_limits(self):
  187. self.task_buckets.update(
  188. (n, self.bucket_for_task(t)) for n, t in items(self.app.tasks)
  189. )
  190. def _update_prefetch_count(self, index=0):
  191. """Update prefetch count after pool/shrink grow operations.
  192. Index must be the change in number of processes as a positive
  193. (increasing) or negative (decreasing) number.
  194. .. note::
  195. Currently pool grow operations will end up with an offset
  196. of +1 if the initial size of the pool was 0 (e.g.
  197. ``--autoscale=1,0``).
  198. """
  199. num_processes = self.pool.num_processes
  200. if not self.initial_prefetch_count or not num_processes:
  201. return # prefetch disabled
  202. self.initial_prefetch_count = (
  203. self.pool.num_processes * self.prefetch_multiplier
  204. )
  205. return self._update_qos_eventually(index)
  206. def _update_qos_eventually(self, index):
  207. return (self.qos.decrement_eventually if index < 0
  208. else self.qos.increment_eventually)(
  209. abs(index) * self.prefetch_multiplier)
  210. def _limit_task(self, request, bucket, tokens):
  211. if not bucket.can_consume(tokens):
  212. hold = bucket.expected_time(tokens)
  213. self.timer.call_after(
  214. hold, self._limit_task, (request, bucket, tokens),
  215. )
  216. else:
  217. task_reserved(request)
  218. self.on_task_request(request)
  219. def start(self):
  220. blueprint = self.blueprint
  221. while blueprint.state != CLOSE:
  222. self.restart_count += 1
  223. maybe_shutdown()
  224. try:
  225. blueprint.start(self)
  226. except self.connection_errors as exc:
  227. if isinstance(exc, OSError) and get_errno(exc) == errno.EMFILE:
  228. raise # Too many open files
  229. maybe_shutdown()
  230. try:
  231. self._restart_state.step()
  232. except RestartFreqExceeded as exc:
  233. crit('Frequent restarts detected: %r', exc, exc_info=1)
  234. sleep(1)
  235. if blueprint.state != CLOSE and self.connection:
  236. warn(CONNECTION_RETRY, exc_info=True)
  237. try:
  238. self.connection.collect()
  239. except Exception:
  240. pass
  241. self.on_close()
  242. blueprint.restart(self)
  243. def register_with_event_loop(self, hub):
  244. self.blueprint.send_all(
  245. self, 'register_with_event_loop', args=(hub, ),
  246. description='Hub.register',
  247. )
  248. def shutdown(self):
  249. self.in_shutdown = True
  250. self.blueprint.shutdown(self)
  251. def stop(self):
  252. self.blueprint.stop(self)
  253. def on_ready(self):
  254. callback, self.init_callback = self.init_callback, None
  255. if callback:
  256. callback(self)
  257. def loop_args(self):
  258. return (self, self.connection, self.task_consumer,
  259. self.blueprint, self.hub, self.qos, self.amqheartbeat,
  260. self.app.clock, self.amqheartbeat_rate)
  261. def on_decode_error(self, message, exc):
  262. """Callback called if an error occurs while decoding
  263. a message received.
  264. Simply logs the error and acknowledges the message so it
  265. doesn't enter a loop.
  266. :param message: The message with errors.
  267. :param exc: The original exception instance.
  268. """
  269. crit(MESSAGE_DECODE_ERROR,
  270. exc, message.content_type, message.content_encoding,
  271. safe_repr(message.headers), dump_body(message, message.body),
  272. exc_info=1)
  273. message.ack()
  274. def on_close(self):
  275. # Clear internal queues to get rid of old messages.
  276. # They can't be acked anyway, as a delivery tag is specific
  277. # to the current channel.
  278. if self.controller and self.controller.semaphore:
  279. self.controller.semaphore.clear()
  280. if self.timer:
  281. self.timer.clear()
  282. reserved_requests.clear()
  283. if self.pool and self.pool.flush:
  284. self.pool.flush()
  285. def connect(self):
  286. """Establish the broker connection.
  287. Will retry establishing the connection if the
  288. :setting:`BROKER_CONNECTION_RETRY` setting is enabled
  289. """
  290. conn = self.app.connection(heartbeat=self.amqheartbeat)
  291. # Callback called for each retry while the connection
  292. # can't be established.
  293. def _error_handler(exc, interval, next_step=CONNECTION_RETRY_STEP):
  294. if getattr(conn, 'alt', None) and interval == 0:
  295. next_step = CONNECTION_FAILOVER
  296. error(CONNECTION_ERROR, conn.as_uri(), exc,
  297. next_step.format(when=humanize_seconds(interval, 'in', ' ')))
  298. # remember that the connection is lazy, it won't establish
  299. # until needed.
  300. if not self.app.conf.BROKER_CONNECTION_RETRY:
  301. # retry disabled, just call connect directly.
  302. conn.connect()
  303. return conn
  304. conn = conn.ensure_connection(
  305. _error_handler, self.app.conf.BROKER_CONNECTION_MAX_RETRIES,
  306. callback=maybe_shutdown,
  307. )
  308. if self.hub:
  309. conn.transport.register_with_event_loop(conn.connection, self.hub)
  310. return conn
  311. def add_task_queue(self, queue, exchange=None, exchange_type=None,
  312. routing_key=None, **options):
  313. cset = self.task_consumer
  314. queues = self.app.amqp.queues
  315. # Must use in' here, as __missing__ will automatically
  316. # create queues when CELERY_CREATE_MISSING_QUEUES is enabled.
  317. # (Issue #1079)
  318. if queue in queues:
  319. q = queues[queue]
  320. else:
  321. exchange = queue if exchange is None else exchange
  322. exchange_type = ('direct' if exchange_type is None
  323. else exchange_type)
  324. q = queues.select_add(queue,
  325. exchange=exchange,
  326. exchange_type=exchange_type,
  327. routing_key=routing_key, **options)
  328. if not cset.consuming_from(queue):
  329. cset.add_queue(q)
  330. cset.consume()
  331. info('Started consuming from %s', queue)
  332. def cancel_task_queue(self, queue):
  333. info('Canceling queue %s', queue)
  334. self.app.amqp.queues.deselect(queue)
  335. self.task_consumer.cancel_by_queue(queue)
  336. def apply_eta_task(self, task):
  337. """Method called by the timer to apply a task with an
  338. ETA/countdown."""
  339. task_reserved(task)
  340. self.on_task_request(task)
  341. self.qos.decrement_eventually()
  342. def _message_report(self, body, message):
  343. return MESSAGE_REPORT.format(dump_body(message, body),
  344. safe_repr(message.content_type),
  345. safe_repr(message.content_encoding),
  346. safe_repr(message.delivery_info),
  347. safe_repr(message.headers))
  348. def on_unknown_message(self, body, message):
  349. warn(UNKNOWN_FORMAT, self._message_report(body, message))
  350. message.reject_log_error(logger, self.connection_errors)
  351. def on_unknown_task(self, body, message, exc):
  352. error(UNKNOWN_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
  353. message.reject_log_error(logger, self.connection_errors)
  354. def on_invalid_task(self, body, message, exc):
  355. error(INVALID_TASK_ERROR, exc, dump_body(message, body), exc_info=True)
  356. message.reject_log_error(logger, self.connection_errors)
  357. def update_strategies(self):
  358. loader = self.app.loader
  359. for name, task in items(self.app.tasks):
  360. self.strategies[name] = task.start_strategy(self.app, self)
  361. task.__trace__ = build_tracer(name, task, loader, self.hostname,
  362. app=self.app)
  363. def create_task_handler(self):
  364. strategies = self.strategies
  365. on_unknown_message = self.on_unknown_message
  366. on_unknown_task = self.on_unknown_task
  367. on_invalid_task = self.on_invalid_task
  368. callbacks = self.on_task_message
  369. def on_task_received(body, message):
  370. headers = message.headers
  371. try:
  372. type_, is_proto2 = body['task'], 0
  373. except (KeyError, TypeError):
  374. try:
  375. type_, is_proto2 = headers['task'], 1
  376. except (KeyError, TypeError):
  377. return on_unknown_message(body, message)
  378. if is_proto2:
  379. body = proto2_to_proto1(
  380. self.app, type_, body, message, headers)
  381. try:
  382. strategies[type_](message, body,
  383. message.ack_log_error,
  384. message.reject_log_error,
  385. callbacks)
  386. except KeyError as exc:
  387. on_unknown_task(body, message, exc)
  388. except InvalidTaskError as exc:
  389. on_invalid_task(body, message, exc)
  390. return on_task_received
  391. def __repr__(self):
  392. return '<Consumer: {self.hostname} ({state})>'.format(
  393. self=self, state=self.blueprint.human_state(),
  394. )
  395. def proto2_to_proto1(app, type_, body, message, headers):
  396. args, kwargs, embed = body
  397. embedded = _extract_proto2_embed(**embed)
  398. chained = embedded.pop('chain')
  399. new_body = dict(
  400. _extract_proto2_headers(type_, **headers),
  401. args=args,
  402. kwargs=kwargs,
  403. **embedded)
  404. if chained:
  405. new_body['callbacks'].append(chain(chained, app=app))
  406. return new_body
  407. def _extract_proto2_headers(type_, id, retries, eta, expires,
  408. group, timelimit, **_):
  409. return {
  410. 'id': id,
  411. 'task': type_,
  412. 'retries': retries,
  413. 'eta': eta,
  414. 'expires': expires,
  415. 'utc': True,
  416. 'taskset': group,
  417. 'timelimit': timelimit,
  418. }
  419. def _extract_proto2_embed(callbacks, errbacks, chain, chord, **_):
  420. return {
  421. 'callbacks': callbacks or [],
  422. 'errbacks': errbacks,
  423. 'chain': chain,
  424. 'chord': chord,
  425. }
  426. class Connection(bootsteps.StartStopStep):
  427. def __init__(self, c, **kwargs):
  428. c.connection = None
  429. def start(self, c):
  430. c.connection = c.connect()
  431. info('Connected to %s', c.connection.as_uri())
  432. def shutdown(self, c):
  433. # We must set self.connection to None here, so
  434. # that the green pidbox thread exits.
  435. connection, c.connection = c.connection, None
  436. if connection:
  437. ignore_errors(connection, connection.close)
  438. def info(self, c, params='N/A'):
  439. if c.connection:
  440. params = c.connection.info()
  441. params.pop('password', None) # don't send password.
  442. return {'broker': params}
  443. class Events(bootsteps.StartStopStep):
  444. requires = (Connection, )
  445. def __init__(self, c, send_events=None, **kwargs):
  446. self.send_events = True
  447. self.groups = None if send_events else ['worker']
  448. c.event_dispatcher = None
  449. def start(self, c):
  450. # flush events sent while connection was down.
  451. prev = self._close(c)
  452. dis = c.event_dispatcher = c.app.events.Dispatcher(
  453. c.connect(), hostname=c.hostname,
  454. enabled=self.send_events, groups=self.groups,
  455. )
  456. if prev:
  457. dis.extend_buffer(prev)
  458. dis.flush()
  459. def stop(self, c):
  460. pass
  461. def _close(self, c):
  462. if c.event_dispatcher:
  463. dispatcher = c.event_dispatcher
  464. # remember changes from remote control commands:
  465. self.groups = dispatcher.groups
  466. # close custom connection
  467. if dispatcher.connection:
  468. ignore_errors(c, dispatcher.connection.close)
  469. ignore_errors(c, dispatcher.close)
  470. c.event_dispatcher = None
  471. return dispatcher
  472. def shutdown(self, c):
  473. self._close(c)
  474. class Heart(bootsteps.StartStopStep):
  475. requires = (Events, )
  476. def __init__(self, c, without_heartbeat=False, heartbeat_interval=None,
  477. **kwargs):
  478. self.enabled = not without_heartbeat
  479. self.heartbeat_interval = heartbeat_interval
  480. c.heart = None
  481. def start(self, c):
  482. c.heart = heartbeat.Heart(
  483. c.timer, c.event_dispatcher, self.heartbeat_interval,
  484. )
  485. c.heart.start()
  486. def stop(self, c):
  487. c.heart = c.heart and c.heart.stop()
  488. shutdown = stop
  489. class Mingle(bootsteps.StartStopStep):
  490. label = 'Mingle'
  491. requires = (Events, )
  492. compatible_transports = set(['amqp', 'redis'])
  493. def __init__(self, c, without_mingle=False, **kwargs):
  494. self.enabled = not without_mingle and self.compatible_transport(c.app)
  495. def compatible_transport(self, app):
  496. with app.connection() as conn:
  497. return conn.transport.driver_type in self.compatible_transports
  498. def start(self, c):
  499. info('mingle: searching for neighbors')
  500. I = c.app.control.inspect(timeout=1.0, connection=c.connection)
  501. replies = I.hello(c.hostname, revoked._data) or {}
  502. replies.pop(c.hostname, None)
  503. if replies:
  504. info('mingle: sync with %s nodes',
  505. len([reply for reply, value in items(replies) if value]))
  506. for reply in values(replies):
  507. if reply:
  508. try:
  509. other_clock, other_revoked = MINGLE_GET_FIELDS(reply)
  510. except KeyError: # reply from pre-3.1 worker
  511. pass
  512. else:
  513. c.app.clock.adjust(other_clock)
  514. revoked.update(other_revoked)
  515. info('mingle: sync complete')
  516. else:
  517. info('mingle: all alone')
  518. class Tasks(bootsteps.StartStopStep):
  519. requires = (Mingle, )
  520. def __init__(self, c, **kwargs):
  521. c.task_consumer = c.qos = None
  522. def start(self, c):
  523. c.update_strategies()
  524. # - RabbitMQ 3.3 completely redefines how basic_qos works..
  525. # This will detect if the new qos smenatics is in effect,
  526. # and if so make sure the 'apply_global' flag is set on qos updates.
  527. qos_global = not c.connection.qos_semantics_matches_spec
  528. # set initial prefetch count
  529. c.connection.default_channel.basic_qos(
  530. 0, c.initial_prefetch_count, qos_global,
  531. )
  532. c.task_consumer = c.app.amqp.TaskConsumer(
  533. c.connection, on_decode_error=c.on_decode_error,
  534. )
  535. def set_prefetch_count(prefetch_count):
  536. return c.task_consumer.qos(
  537. prefetch_count=prefetch_count,
  538. apply_global=qos_global,
  539. )
  540. c.qos = QoS(set_prefetch_count, c.initial_prefetch_count)
  541. def stop(self, c):
  542. if c.task_consumer:
  543. debug('Canceling task consumer...')
  544. ignore_errors(c, c.task_consumer.cancel)
  545. def shutdown(self, c):
  546. if c.task_consumer:
  547. self.stop(c)
  548. debug('Closing consumer channel...')
  549. ignore_errors(c, c.task_consumer.close)
  550. c.task_consumer = None
  551. def info(self, c):
  552. return {'prefetch_count': c.qos.value if c.qos else 'N/A'}
  553. class Agent(bootsteps.StartStopStep):
  554. conditional = True
  555. requires = (Connection, )
  556. def __init__(self, c, **kwargs):
  557. self.agent_cls = self.enabled = c.app.conf.CELERYD_AGENT
  558. def create(self, c):
  559. agent = c.agent = self.instantiate(self.agent_cls, c.connection)
  560. return agent
  561. class Control(bootsteps.StartStopStep):
  562. requires = (Tasks, )
  563. def __init__(self, c, **kwargs):
  564. self.is_green = c.pool is not None and c.pool.is_green
  565. self.box = (pidbox.gPidbox if self.is_green else pidbox.Pidbox)(c)
  566. self.start = self.box.start
  567. self.stop = self.box.stop
  568. self.shutdown = self.box.shutdown
  569. def include_if(self, c):
  570. return c.app.conf.CELERY_ENABLE_REMOTE_CONTROL
  571. class Gossip(bootsteps.ConsumerStep):
  572. label = 'Gossip'
  573. requires = (Mingle, )
  574. _cons_stamp_fields = itemgetter(
  575. 'id', 'clock', 'hostname', 'pid', 'topic', 'action', 'cver',
  576. )
  577. compatible_transports = set(['amqp', 'redis'])
  578. def __init__(self, c, without_gossip=False, interval=5.0, **kwargs):
  579. self.enabled = not without_gossip and self.compatible_transport(c.app)
  580. self.app = c.app
  581. c.gossip = self
  582. self.Receiver = c.app.events.Receiver
  583. self.hostname = c.hostname
  584. self.full_hostname = '.'.join([self.hostname, str(c.pid)])
  585. self.on = Bunch(
  586. node_join=set(),
  587. node_leave=set(),
  588. node_lost=set(),
  589. )
  590. self.timer = c.timer
  591. if self.enabled:
  592. self.state = c.app.events.State(
  593. on_node_join=self.on_node_join,
  594. on_node_leave=self.on_node_leave,
  595. max_tasks_in_memory=1,
  596. )
  597. if c.hub:
  598. c._mutex = DummyLock()
  599. self.update_state = self.state.event
  600. self.interval = interval
  601. self._tref = None
  602. self.consensus_requests = defaultdict(list)
  603. self.consensus_replies = {}
  604. self.event_handlers = {
  605. 'worker.elect': self.on_elect,
  606. 'worker.elect.ack': self.on_elect_ack,
  607. }
  608. self.clock = c.app.clock
  609. self.election_handlers = {
  610. 'task': self.call_task
  611. }
  612. def compatible_transport(self, app):
  613. with app.connection() as conn:
  614. return conn.transport.driver_type in self.compatible_transports
  615. def election(self, id, topic, action=None):
  616. self.consensus_replies[id] = []
  617. self.dispatcher.send(
  618. 'worker-elect',
  619. id=id, topic=topic, action=action, cver=1,
  620. )
  621. def call_task(self, task):
  622. try:
  623. signature(task, app=self.app).apply_async()
  624. except Exception as exc:
  625. error('Could not call task: %r', exc, exc_info=1)
  626. def on_elect(self, event):
  627. try:
  628. (id_, clock, hostname, pid,
  629. topic, action, _) = self._cons_stamp_fields(event)
  630. except KeyError as exc:
  631. return error('election request missing field %s', exc, exc_info=1)
  632. heappush(
  633. self.consensus_requests[id_],
  634. (clock, '%s.%s' % (hostname, pid), topic, action),
  635. )
  636. self.dispatcher.send('worker-elect-ack', id=id_)
  637. def start(self, c):
  638. super(Gossip, self).start(c)
  639. self.dispatcher = c.event_dispatcher
  640. def on_elect_ack(self, event):
  641. id = event['id']
  642. try:
  643. replies = self.consensus_replies[id]
  644. except KeyError:
  645. return # not for us
  646. alive_workers = self.state.alive_workers()
  647. replies.append(event['hostname'])
  648. if len(replies) >= len(alive_workers):
  649. _, leader, topic, action = self.clock.sort_heap(
  650. self.consensus_requests[id],
  651. )
  652. if leader == self.full_hostname:
  653. info('I won the election %r', id)
  654. try:
  655. handler = self.election_handlers[topic]
  656. except KeyError:
  657. error('Unknown election topic %r', topic, exc_info=1)
  658. else:
  659. handler(action)
  660. else:
  661. info('node %s elected for %r', leader, id)
  662. self.consensus_requests.pop(id, None)
  663. self.consensus_replies.pop(id, None)
  664. def on_node_join(self, worker):
  665. debug('%s joined the party', worker.hostname)
  666. self._call_handlers(self.on.node_join, worker)
  667. def on_node_leave(self, worker):
  668. debug('%s left', worker.hostname)
  669. self._call_handlers(self.on.node_leave, worker)
  670. def on_node_lost(self, worker):
  671. info('missed heartbeat from %s', worker.hostname)
  672. self._call_handlers(self.on.node_lost, worker)
  673. def _call_handlers(self, handlers, *args, **kwargs):
  674. for handler in handlers:
  675. try:
  676. handler(*args, **kwargs)
  677. except Exception as exc:
  678. error('Ignored error from handler %r: %r',
  679. handler, exc, exc_info=1)
  680. def register_timer(self):
  681. if self._tref is not None:
  682. self._tref.cancel()
  683. self._tref = self.timer.call_repeatedly(self.interval, self.periodic)
  684. def periodic(self):
  685. workers = self.state.workers
  686. dirty = set()
  687. for worker in values(workers):
  688. if not worker.alive:
  689. dirty.add(worker)
  690. self.on_node_lost(worker)
  691. for worker in dirty:
  692. workers.pop(worker.hostname, None)
  693. def get_consumers(self, channel):
  694. self.register_timer()
  695. ev = self.Receiver(channel, routing_key='worker.#')
  696. return [kombu.Consumer(
  697. channel,
  698. queues=[ev.queue],
  699. on_message=partial(self.on_message, ev.event_from_message),
  700. no_ack=True
  701. )]
  702. def on_message(self, prepare, message):
  703. _type = message.delivery_info['routing_key']
  704. # For redis when `fanout_patterns=False` (See Issue #1882)
  705. if _type.split('.', 1)[0] == 'task':
  706. return
  707. try:
  708. handler = self.event_handlers[_type]
  709. except KeyError:
  710. pass
  711. else:
  712. return handler(message.payload)
  713. hostname = (message.headers.get('hostname') or
  714. message.payload['hostname'])
  715. if hostname != self.hostname:
  716. type, event = prepare(message.payload)
  717. self.update_state(event)
  718. else:
  719. self.clock.forward()
  720. class Evloop(bootsteps.StartStopStep):
  721. label = 'event loop'
  722. last = True
  723. def start(self, c):
  724. self.patch_all(c)
  725. c.loop(*c.loop_args())
  726. def patch_all(self, c):
  727. c.qos._mutex = DummyLock()