123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108 |
- from __future__ import absolute_import
-
- import os
- import signal
- import sys
-
- from contextlib import contextmanager
- from time import time
-
- from nose import SkipTest
- from billiard.common import (
- _shutdown_cleanup,
- reset_signals,
- restart_state,
- )
-
- from .utils import Case
-
- try:
- from unittest.mock import Mock, call, patch
- except ImportError:
- from mock import Mock, call, patch # noqa
-
-
def signo(name):
    """Translate a signal name (e.g. ``'SIGTERM'``) into its number.

    The lookup is a plain attribute access on the :mod:`signal` module,
    so an unknown name raises ``AttributeError``.
    """
    number = getattr(signal, name)
    return number
-
-
@contextmanager
def termsigs(default, full):
    """Temporarily replace billiard's termination-signal name lists.

    For the duration of the ``with`` block ``common.TERMSIGS_DEFAULT`` is
    set to *default* and ``common.TERMSIGS_FULL`` to *full*. The previous
    values are put back on exit, including when the body raises.
    """
    from billiard import common
    saved = (common.TERMSIGS_DEFAULT, common.TERMSIGS_FULL)
    common.TERMSIGS_DEFAULT = default
    common.TERMSIGS_FULL = full
    try:
        yield
    finally:
        common.TERMSIGS_DEFAULT, common.TERMSIGS_FULL = saved
-
-
class test_reset_signals(Case):
    """Tests for ``billiard.common.reset_signals`` and ``_shutdown_cleanup``."""

    def setUp(self):
        # Signal handling / wait-status decoding is POSIX-only.
        if sys.platform == 'win32':
            raise SkipTest('win32: skip')

    def test_shutdown_handler(self):
        # The status handed to sys.exit must decode, via os.WTERMSIG,
        # back to the signal number the handler was invoked with.
        with patch('sys.exit') as sys_exit:
            _shutdown_cleanup(15, Mock())
            self.assertTrue(sys_exit.called)
            self.assertEqual(os.WTERMSIG(sys_exit.call_args[0][0]), 15)

    # NOTE: the ``sigs`` defaults are tuples (immutable) rather than lists
    # so no state can leak between invocations through a shared default.

    def test_does_not_reset_ignored_signal(self, sigs=('SIGTERM',)):
        # A signal currently set to SIG_IGN must be left alone.
        with self.assert_context(sigs, (), signal.SIG_IGN) as (_, SET):
            self.assertFalse(SET.called)

    def test_does_not_reset_if_current_is_None(self, sigs=('SIGTERM',)):
        # getsignal() returning None must not trigger a reset.
        with self.assert_context(sigs, (), None) as (_, SET):
            self.assertFalse(SET.called)

    def test_resets_for_SIG_DFL(self, sigs=('SIGTERM', 'SIGINT', 'SIGUSR1')):
        # Signals at their default disposition get _shutdown_cleanup.
        with self.assert_context(sigs, (), signal.SIG_DFL) as (_, SET):
            SET.assert_has_calls([
                call(signo(sig), _shutdown_cleanup) for sig in sigs
            ])

    def test_resets_for_obj(self, sigs=('SIGTERM', 'SIGINT', 'SIGUSR1')):
        # Signals with some other installed handler also get reset.
        with self.assert_context(sigs, (), object()) as (_, SET):
            SET.assert_has_calls([
                call(signo(sig), _shutdown_cleanup) for sig in sigs
            ])

    def test_handles_errors(self, sigs=('SIGTERM',)):
        # reset_signals() must not propagate errors raised by signal.signal.
        for exc in (OSError(), AttributeError(),
                    ValueError(), RuntimeError()):
            with self.assert_context(sigs, (), signal.SIG_DFL, exc) as (_, S):
                self.assertTrue(S.called)

    @contextmanager
    def assert_context(self, default, full, get_returns=None, set_effect=None):
        """Run ``reset_signals()`` with signal.getsignal/signal.signal mocked.

        :param default: value installed as ``TERMSIGS_DEFAULT``.
        :param full: value installed as ``TERMSIGS_FULL``.
        :param get_returns: what the mocked ``signal.getsignal`` returns.
        :param set_effect: ``side_effect`` for the mocked ``signal.signal``
            (e.g. an exception instance to raise).

        Asserts that ``getsignal`` was queried for every name in *default*,
        then yields the ``(GET, SET)`` mocks for further assertions.
        """
        with termsigs(default, full):
            with patch('signal.getsignal') as GET:
                with patch('signal.signal') as SET:
                    GET.return_value = get_returns
                    SET.side_effect = set_effect
                    reset_signals()
                    GET.assert_has_calls([
                        call(signo(sig)) for sig in default
                    ])
                    yield GET, SET
-
-
class test_restart_state(Case):
    """Tests for the restart-frequency limiter ``restart_state``."""

    def test_raises(self):
        s = restart_state(100, 1)  # max 100 restarts in 1 second.
        s.R = 99
        # Use one fixed timestamp for both steps. The second step then
        # always lands inside the 1 second window, however slowly the
        # test machine runs between the two calls.
        now = time()
        s.step(now)
        with self.assertRaises(s.RestartFreqExceeded):
            s.step(now)

    def test_time_passed_resets_counter(self):
        s = restart_state(100, 10)
        # Pass explicit timestamps to step() so the test does not rely on
        # the clock step() would read internally being the same as time().
        now = time()
        s.R, s.T = 100, now
        # Still inside the 10 second window: the limit is exceeded.
        with self.assertRaises(s.RestartFreqExceeded):
            s.step(now)
        s.R, s.T = 100, now
        # 20 seconds later the window has passed, so the counter restarts.
        s.step(now + 20)
        self.assertEqual(s.R, 1)
|