Funktionierender Prototyp des Serious Games zur Vermittlung von Wissen zu Software-Engineering-Arbeitsmodellen.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_amp.py 108KB

1 year ago

  1. # Copyright (c) 2005 Divmod, Inc.
  2. # Copyright (c) Twisted Matrix Laboratories.
  3. # See LICENSE for details.
  4. """
  5. Tests for L{twisted.protocols.amp}.
  6. """
  7. import datetime
  8. import decimal
  9. from typing import Dict, Type
  10. from unittest import skipIf
  11. from zope.interface import implementer
  12. from zope.interface.verify import verifyClass, verifyObject
  13. from twisted.internet import address, defer, error, interfaces, protocol, reactor
  14. from twisted.protocols import amp
  15. from twisted.python import filepath
  16. from twisted.python.failure import Failure
  17. from twisted.test import iosim
  18. from twisted.test.proto_helpers import StringTransport
  19. from twisted.trial.unittest import TestCase
  20. try:
  21. from twisted.internet import ssl as _ssl
  22. except ImportError:
  23. ssl = None
  24. else:
  25. if not _ssl.supported:
  26. ssl = None
  27. else:
  28. ssl = _ssl
  29. if ssl is None:
  30. skipSSL = True
  31. else:
  32. skipSSL = False
  33. if not interfaces.IReactorSSL.providedBy(reactor):
  34. reactorLacksSSL = True
  35. else:
  36. reactorLacksSSL = False
  37. tz = amp._FixedOffsetTZInfo.fromSignHoursMinutes
  38. class TestProto(protocol.Protocol):
  39. """
  40. A trivial protocol for use in testing where a L{Protocol} is expected.
  41. @ivar instanceId: the id of this instance
  42. @ivar onConnLost: deferred that will fired when the connection is lost
  43. @ivar dataToSend: data to send on the protocol
  44. """
  45. instanceCount = 0
  46. def __init__(self, onConnLost, dataToSend):
  47. assert isinstance(dataToSend, bytes), repr(dataToSend)
  48. self.onConnLost = onConnLost
  49. self.dataToSend = dataToSend
  50. self.instanceId = TestProto.instanceCount
  51. TestProto.instanceCount = TestProto.instanceCount + 1
  52. def connectionMade(self):
  53. self.data = []
  54. self.transport.write(self.dataToSend)
  55. def dataReceived(self, bytes):
  56. self.data.append(bytes)
  57. def connectionLost(self, reason):
  58. self.onConnLost.callback(self.data)
  59. def __repr__(self) -> str:
  60. """
  61. Custom repr for testing to avoid coupling amp tests with repr from
  62. L{Protocol}
  63. Returns a string which contains a unique identifier that can be looked
  64. up using the instanceId property::
  65. <TestProto #3>
  66. """
  67. return "<TestProto #%d>" % (self.instanceId,)
  68. class SimpleSymmetricProtocol(amp.AMP):
  69. def sendHello(self, text):
  70. return self.callRemoteString(b"hello", hello=text)
  71. def amp_HELLO(self, box):
  72. return amp.Box(hello=box[b"hello"])
  73. class UnfriendlyGreeting(Exception):
  74. """Greeting was insufficiently kind."""
  75. class DeathThreat(Exception):
  76. """Greeting was insufficiently kind."""
  77. class UnknownProtocol(Exception):
  78. """Asked to switch to the wrong protocol."""
  79. class TransportPeer(amp.Argument):
  80. # this serves as some informal documentation for how to get variables from
  81. # the protocol or your environment and pass them to methods as arguments.
  82. def retrieve(self, d, name, proto):
  83. return b""
  84. def fromStringProto(self, notAString, proto):
  85. return proto.transport.getPeer()
  86. def toBox(self, name, strings, objects, proto):
  87. return
  88. class Hello(amp.Command):
  89. commandName = b"hello"
  90. arguments = [
  91. (b"hello", amp.String()),
  92. (b"optional", amp.Boolean(optional=True)),
  93. (b"print", amp.Unicode(optional=True)),
  94. (b"from", TransportPeer(optional=True)),
  95. (b"mixedCase", amp.String(optional=True)),
  96. (b"dash-arg", amp.String(optional=True)),
  97. (b"underscore_arg", amp.String(optional=True)),
  98. ]
  99. response = [(b"hello", amp.String()), (b"print", amp.Unicode(optional=True))]
  100. errors: Dict[Type[Exception], bytes] = {UnfriendlyGreeting: b"UNFRIENDLY"}
  101. fatalErrors: Dict[Type[Exception], bytes] = {DeathThreat: b"DEAD"}
  102. class NoAnswerHello(Hello):
  103. commandName = Hello.commandName
  104. requiresAnswer = False
  105. class FutureHello(amp.Command):
  106. commandName = b"hello"
  107. arguments = [
  108. (b"hello", amp.String()),
  109. (b"optional", amp.Boolean(optional=True)),
  110. (b"print", amp.Unicode(optional=True)),
  111. (b"from", TransportPeer(optional=True)),
  112. (b"bonus", amp.String(optional=True)), # addt'l arguments
  113. # should generally be
  114. # added at the end, and
  115. # be optional...
  116. ]
  117. response = [(b"hello", amp.String()), (b"print", amp.Unicode(optional=True))]
  118. errors = {UnfriendlyGreeting: b"UNFRIENDLY"}
  119. class WTF(amp.Command):
  120. """
  121. An example of an invalid command.
  122. """
  123. class BrokenReturn(amp.Command):
  124. """An example of a perfectly good command, but the handler is going to return
  125. None...
  126. """
  127. commandName = b"broken_return"
  128. class Goodbye(amp.Command):
  129. # commandName left blank on purpose: this tests implicit command names.
  130. response = [(b"goodbye", amp.String())]
  131. responseType = amp.QuitBox
  132. class WaitForever(amp.Command):
  133. commandName = b"wait_forever"
  134. class GetList(amp.Command):
  135. commandName = b"getlist"
  136. arguments = [(b"length", amp.Integer())]
  137. response = [(b"body", amp.AmpList([(b"x", amp.Integer())]))]
  138. class DontRejectMe(amp.Command):
  139. commandName = b"dontrejectme"
  140. arguments = [
  141. (b"magicWord", amp.Unicode()),
  142. (b"list", amp.AmpList([(b"name", amp.Unicode())], optional=True)),
  143. ]
  144. response = [(b"response", amp.Unicode())]
  145. class SecuredPing(amp.Command):
  146. # XXX TODO: actually make this refuse to send over an insecure connection
  147. response = [(b"pinged", amp.Boolean())]
  148. class TestSwitchProto(amp.ProtocolSwitchCommand):
  149. commandName = b"Switch-Proto"
  150. arguments = [
  151. (b"name", amp.String()),
  152. ]
  153. errors = {UnknownProtocol: b"UNKNOWN"}
  154. class SingleUseFactory(protocol.ClientFactory):
  155. def __init__(self, proto):
  156. self.proto = proto
  157. self.proto.factory = self
  158. def buildProtocol(self, addr):
  159. p, self.proto = self.proto, None
  160. return p
  161. reasonFailed = None
  162. def clientConnectionFailed(self, connector, reason):
  163. self.reasonFailed = reason
  164. return
  165. THING_I_DONT_UNDERSTAND = b"gwebol nargo"
  166. class ThingIDontUnderstandError(Exception):
  167. pass
  168. class FactoryNotifier(amp.AMP):
  169. factory = None
  170. def connectionMade(self):
  171. if self.factory is not None:
  172. self.factory.theProto = self
  173. if hasattr(self.factory, "onMade"):
  174. self.factory.onMade.callback(None)
  175. def emitpong(self):
  176. from twisted.internet.interfaces import ISSLTransport
  177. if not ISSLTransport.providedBy(self.transport):
  178. raise DeathThreat("only send secure pings over secure channels")
  179. return {"pinged": True}
  180. SecuredPing.responder(emitpong)
  181. class SimpleSymmetricCommandProtocol(FactoryNotifier):
  182. maybeLater = None
  183. def __init__(self, onConnLost=None):
  184. amp.AMP.__init__(self)
  185. self.onConnLost = onConnLost
  186. def sendHello(self, text):
  187. return self.callRemote(Hello, hello=text)
  188. def sendUnicodeHello(self, text, translation):
  189. return self.callRemote(Hello, hello=text, Print=translation)
  190. greeted = False
  191. def cmdHello(
  192. self,
  193. hello,
  194. From,
  195. optional=None,
  196. Print=None,
  197. mixedCase=None,
  198. dash_arg=None,
  199. underscore_arg=None,
  200. ):
  201. assert From == self.transport.getPeer()
  202. if hello == THING_I_DONT_UNDERSTAND:
  203. raise ThingIDontUnderstandError()
  204. if hello.startswith(b"fuck"):
  205. raise UnfriendlyGreeting("Don't be a dick.")
  206. if hello == b"die":
  207. raise DeathThreat("aieeeeeeeee")
  208. result = dict(hello=hello)
  209. if Print is not None:
  210. result.update(dict(Print=Print))
  211. self.greeted = True
  212. return result
  213. Hello.responder(cmdHello)
  214. def cmdGetlist(self, length):
  215. return {"body": [dict(x=1)] * length}
  216. GetList.responder(cmdGetlist)
  217. def okiwont(self, magicWord, list=None):
  218. if list is None:
  219. response = "list omitted"
  220. else:
  221. response = "%s accepted" % (list[0]["name"])
  222. return dict(response=response)
  223. DontRejectMe.responder(okiwont)
  224. def waitforit(self):
  225. self.waiting = defer.Deferred()
  226. return self.waiting
  227. WaitForever.responder(waitforit)
  228. def saybye(self):
  229. return dict(goodbye=b"everyone")
  230. Goodbye.responder(saybye)
  231. def switchToTestProtocol(self, fail=False):
  232. if fail:
  233. name = b"no-proto"
  234. else:
  235. name = b"test-proto"
  236. p = TestProto(self.onConnLost, SWITCH_CLIENT_DATA)
  237. return self.callRemote(
  238. TestSwitchProto, SingleUseFactory(p), name=name
  239. ).addCallback(lambda ign: p)
  240. def switchit(self, name):
  241. if name == b"test-proto":
  242. return TestProto(self.onConnLost, SWITCH_SERVER_DATA)
  243. raise UnknownProtocol(name)
  244. TestSwitchProto.responder(switchit)
  245. def donothing(self):
  246. return None
  247. BrokenReturn.responder(donothing)
  248. class DeferredSymmetricCommandProtocol(SimpleSymmetricCommandProtocol):
  249. def switchit(self, name):
  250. if name == b"test-proto":
  251. self.maybeLaterProto = TestProto(self.onConnLost, SWITCH_SERVER_DATA)
  252. self.maybeLater = defer.Deferred()
  253. return self.maybeLater
  254. TestSwitchProto.responder(switchit)
  255. class BadNoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
  256. def badResponder(
  257. self,
  258. hello,
  259. From,
  260. optional=None,
  261. Print=None,
  262. mixedCase=None,
  263. dash_arg=None,
  264. underscore_arg=None,
  265. ):
  266. """
  267. This responder does nothing and forgets to return a dictionary.
  268. """
  269. NoAnswerHello.responder(badResponder)
  270. class NoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
  271. def goodNoAnswerResponder(
  272. self,
  273. hello,
  274. From,
  275. optional=None,
  276. Print=None,
  277. mixedCase=None,
  278. dash_arg=None,
  279. underscore_arg=None,
  280. ):
  281. return dict(hello=hello + b"-noanswer")
  282. NoAnswerHello.responder(goodNoAnswerResponder)
  283. def connectedServerAndClient(
  284. ServerClass=SimpleSymmetricProtocol, ClientClass=SimpleSymmetricProtocol, *a, **kw
  285. ):
  286. """Returns a 3-tuple: (client, server, pump)"""
  287. return iosim.connectedServerAndClient(ServerClass, ClientClass, *a, **kw)
  288. class TotallyDumbProtocol(protocol.Protocol):
  289. buf = b""
  290. def dataReceived(self, data):
  291. self.buf += data
  292. class LiteralAmp(amp.AMP):
  293. def __init__(self):
  294. self.boxes = []
  295. def ampBoxReceived(self, box):
  296. self.boxes.append(box)
  297. return
  298. class AmpBoxTests(TestCase):
  299. """
  300. Test a few essential properties of AMP boxes, mostly with respect to
  301. serialization correctness.
  302. """
  303. def test_serializeStr(self):
  304. """
  305. Make sure that strs serialize to strs.
  306. """
  307. a = amp.AmpBox(key=b"value")
  308. self.assertEqual(type(a.serialize()), bytes)
  309. def test_serializeUnicodeKeyRaises(self):
  310. """
  311. Verify that TypeError is raised when trying to serialize Unicode keys.
  312. """
  313. a = amp.AmpBox(**{"key": "value"})
  314. self.assertRaises(TypeError, a.serialize)
  315. def test_serializeUnicodeValueRaises(self):
  316. """
  317. Verify that TypeError is raised when trying to serialize Unicode
  318. values.
  319. """
  320. a = amp.AmpBox(key="value")
  321. self.assertRaises(TypeError, a.serialize)
  322. class ParsingTests(TestCase):
  323. def test_booleanValues(self):
  324. """
  325. Verify that the Boolean parser parses 'True' and 'False', but nothing
  326. else.
  327. """
  328. b = amp.Boolean()
  329. self.assertTrue(b.fromString(b"True"))
  330. self.assertFalse(b.fromString(b"False"))
  331. self.assertRaises(TypeError, b.fromString, b"ninja")
  332. self.assertRaises(TypeError, b.fromString, b"true")
  333. self.assertRaises(TypeError, b.fromString, b"TRUE")
  334. self.assertEqual(b.toString(True), b"True")
  335. self.assertEqual(b.toString(False), b"False")
  336. def test_pathValueRoundTrip(self):
  337. """
  338. Verify the 'Path' argument can parse and emit a file path.
  339. """
  340. fp = filepath.FilePath(self.mktemp())
  341. p = amp.Path()
  342. s = p.toString(fp)
  343. v = p.fromString(s)
  344. self.assertIsNot(fp, v) # sanity check
  345. self.assertEqual(fp, v)
  346. def test_sillyEmptyThing(self):
  347. """
  348. Test that empty boxes raise an error; they aren't supposed to be sent
  349. on purpose.
  350. """
  351. a = amp.AMP()
  352. return self.assertRaises(amp.NoEmptyBoxes, a.ampBoxReceived, amp.Box())
  353. def test_ParsingRoundTrip(self):
  354. """
  355. Verify that various kinds of data make it through the encode/parse
  356. round-trip unharmed.
  357. """
  358. c, s, p = connectedServerAndClient(
  359. ClientClass=LiteralAmp, ServerClass=LiteralAmp
  360. )
  361. SIMPLE = (b"simple", b"test")
  362. CE = (b"ceq", b": ")
  363. CR = (b"crtest", b"test\r")
  364. LF = (b"lftest", b"hello\n")
  365. NEWLINE = (b"newline", b"test\r\none\r\ntwo")
  366. NEWLINE2 = (b"newline2", b"test\r\none\r\n two")
  367. BODYTEST = (b"body", b"blah\r\n\r\ntesttest")
  368. testData = [
  369. [SIMPLE],
  370. [SIMPLE, BODYTEST],
  371. [SIMPLE, CE],
  372. [SIMPLE, CR],
  373. [SIMPLE, CE, CR, LF],
  374. [CE, CR, LF],
  375. [SIMPLE, NEWLINE, CE, NEWLINE2],
  376. [BODYTEST, SIMPLE, NEWLINE],
  377. ]
  378. for test in testData:
  379. jb = amp.Box()
  380. jb.update(dict(test))
  381. jb._sendTo(c)
  382. p.flush()
  383. self.assertEqual(s.boxes[-1], jb)
  384. class FakeLocator:
  385. """
  386. This is a fake implementation of the interface implied by
  387. L{CommandLocator}.
  388. """
  389. def __init__(self):
  390. """
  391. Remember the given keyword arguments as a set of responders.
  392. """
  393. self.commands = {}
  394. def locateResponder(self, commandName):
  395. """
  396. Look up and return a function passed as a keyword argument of the given
  397. name to the constructor.
  398. """
  399. return self.commands[commandName]
  400. class FakeSender:
  401. """
  402. This is a fake implementation of the 'box sender' interface implied by
  403. L{AMP}.
  404. """
  405. def __init__(self):
  406. """
  407. Create a fake sender and initialize the list of received boxes and
  408. unhandled errors.
  409. """
  410. self.sentBoxes = []
  411. self.unhandledErrors = []
  412. self.expectedErrors = 0
  413. def expectError(self):
  414. """
  415. Expect one error, so that the test doesn't fail.
  416. """
  417. self.expectedErrors += 1
  418. def sendBox(self, box):
  419. """
  420. Accept a box, but don't do anything.
  421. """
  422. self.sentBoxes.append(box)
  423. def unhandledError(self, failure):
  424. """
  425. Deal with failures by instantly re-raising them for easier debugging.
  426. """
  427. self.expectedErrors -= 1
  428. if self.expectedErrors < 0:
  429. failure.raiseException()
  430. else:
  431. self.unhandledErrors.append(failure)
  432. class CommandDispatchTests(TestCase):
  433. """
  434. The AMP CommandDispatcher class dispatches converts AMP boxes into commands
  435. and responses using Command.responder decorator.
  436. Note: Originally, AMP's factoring was such that many tests for this
  437. functionality are now implemented as full round-trip tests in L{AMPTests}.
  438. Future tests should be written at this level instead, to ensure API
  439. compatibility and to provide more granular, readable units of test
  440. coverage.
  441. """
  442. def setUp(self):
  443. """
  444. Create a dispatcher to use.
  445. """
  446. self.locator = FakeLocator()
  447. self.sender = FakeSender()
  448. self.dispatcher = amp.BoxDispatcher(self.locator)
  449. self.dispatcher.startReceivingBoxes(self.sender)
  450. def test_receivedAsk(self):
  451. """
  452. L{CommandDispatcher.ampBoxReceived} should locate the appropriate
  453. command in its responder lookup, based on the '_ask' key.
  454. """
  455. received = []
  456. def thunk(box):
  457. received.append(box)
  458. return amp.Box({"hello": "goodbye"})
  459. input = amp.Box(_command="hello", _ask="test-command-id", hello="world")
  460. self.locator.commands["hello"] = thunk
  461. self.dispatcher.ampBoxReceived(input)
  462. self.assertEqual(received, [input])
  463. def test_sendUnhandledError(self):
  464. """
  465. L{CommandDispatcher} should relay its unhandled errors in responding to
  466. boxes to its boxSender.
  467. """
  468. err = RuntimeError("something went wrong, oh no")
  469. self.sender.expectError()
  470. self.dispatcher.unhandledError(Failure(err))
  471. self.assertEqual(len(self.sender.unhandledErrors), 1)
  472. self.assertEqual(self.sender.unhandledErrors[0].value, err)
  473. def test_unhandledSerializationError(self):
  474. """
  475. Errors during serialization ought to be relayed to the sender's
  476. unhandledError method.
  477. """
  478. err = RuntimeError("something undefined went wrong")
  479. def thunk(result):
  480. class BrokenBox(amp.Box):
  481. def _sendTo(self, proto):
  482. raise err
  483. return BrokenBox()
  484. self.locator.commands["hello"] = thunk
  485. input = amp.Box(_command="hello", _ask="test-command-id", hello="world")
  486. self.sender.expectError()
  487. self.dispatcher.ampBoxReceived(input)
  488. self.assertEqual(len(self.sender.unhandledErrors), 1)
  489. self.assertEqual(self.sender.unhandledErrors[0].value, err)
  490. def test_callRemote(self):
  491. """
  492. L{CommandDispatcher.callRemote} should emit a properly formatted '_ask'
  493. box to its boxSender and record an outstanding L{Deferred}. When a
  494. corresponding '_answer' packet is received, the L{Deferred} should be
  495. fired, and the results translated via the given L{Command}'s response
  496. de-serialization.
  497. """
  498. D = self.dispatcher.callRemote(Hello, hello=b"world")
  499. self.assertEqual(
  500. self.sender.sentBoxes,
  501. [amp.AmpBox(_command=b"hello", _ask=b"1", hello=b"world")],
  502. )
  503. answers = []
  504. D.addCallback(answers.append)
  505. self.assertEqual(answers, [])
  506. self.dispatcher.ampBoxReceived(
  507. amp.AmpBox({b"hello": b"yay", b"print": b"ignored", b"_answer": b"1"})
  508. )
  509. self.assertEqual(answers, [dict(hello=b"yay", Print="ignored")])
  510. def _localCallbackErrorLoggingTest(self, callResult):
  511. """
  512. Verify that C{callResult} completes with a L{None} result and that an
  513. unhandled error has been logged.
  514. """
  515. finalResult = []
  516. callResult.addBoth(finalResult.append)
  517. self.assertEqual(1, len(self.sender.unhandledErrors))
  518. self.assertIsInstance(self.sender.unhandledErrors[0].value, ZeroDivisionError)
  519. self.assertEqual([None], finalResult)
  520. def test_callRemoteSuccessLocalCallbackErrorLogging(self):
  521. """
  522. If the last callback on the L{Deferred} returned by C{callRemote} (added
  523. by application code calling C{callRemote}) fails, the failure is passed
  524. to the sender's C{unhandledError} method.
  525. """
  526. self.sender.expectError()
  527. callResult = self.dispatcher.callRemote(Hello, hello=b"world")
  528. callResult.addCallback(lambda result: 1 // 0)
  529. self.dispatcher.ampBoxReceived(
  530. amp.AmpBox({b"hello": b"yay", b"print": b"ignored", b"_answer": b"1"})
  531. )
  532. self._localCallbackErrorLoggingTest(callResult)
  533. def test_callRemoteErrorLocalCallbackErrorLogging(self):
  534. """
  535. Like L{test_callRemoteSuccessLocalCallbackErrorLogging}, but for the
  536. case where the L{Deferred} returned by C{callRemote} fails.
  537. """
  538. self.sender.expectError()
  539. callResult = self.dispatcher.callRemote(Hello, hello=b"world")
  540. callResult.addErrback(lambda result: 1 // 0)
  541. self.dispatcher.ampBoxReceived(
  542. amp.AmpBox(
  543. {
  544. b"_error": b"1",
  545. b"_error_code": b"bugs",
  546. b"_error_description": b"stuff",
  547. }
  548. )
  549. )
  550. self._localCallbackErrorLoggingTest(callResult)
  551. class SimpleGreeting(amp.Command):
  552. """
  553. A very simple greeting command that uses a few basic argument types.
  554. """
  555. commandName = b"simple"
  556. arguments = [(b"greeting", amp.Unicode()), (b"cookie", amp.Integer())]
  557. response = [(b"cookieplus", amp.Integer())]
  558. class TestLocator(amp.CommandLocator):
  559. """
  560. A locator which implements a responder to the 'simple' command.
  561. """
  562. def __init__(self):
  563. self.greetings = []
  564. def greetingResponder(self, greeting, cookie):
  565. self.greetings.append((greeting, cookie))
  566. return dict(cookieplus=cookie + 3)
  567. greetingResponder = SimpleGreeting.responder(greetingResponder)
  568. class OverridingLocator(TestLocator):
  569. """
  570. A locator which overrides the responder to the 'simple' command.
  571. """
  572. def greetingResponder(self, greeting, cookie):
  573. """
  574. Return a different cookieplus than L{TestLocator.greetingResponder}.
  575. """
  576. self.greetings.append((greeting, cookie))
  577. return dict(cookieplus=cookie + 4)
  578. greetingResponder = SimpleGreeting.responder(greetingResponder)
  579. class InheritingLocator(OverridingLocator):
  580. """
  581. This locator should inherit the responder from L{OverridingLocator}.
  582. """
  583. class OverrideLocatorAMP(amp.AMP):
  584. def __init__(self):
  585. amp.AMP.__init__(self)
  586. self.customResponder = object()
  587. self.expectations = {b"custom": self.customResponder}
  588. self.greetings = []
  589. def lookupFunction(self, name):
  590. """
  591. Override the deprecated lookupFunction function.
  592. """
  593. if name in self.expectations:
  594. result = self.expectations[name]
  595. return result
  596. else:
  597. return super().lookupFunction(name)
  598. def greetingResponder(self, greeting, cookie):
  599. self.greetings.append((greeting, cookie))
  600. return dict(cookieplus=cookie + 3)
  601. greetingResponder = SimpleGreeting.responder(greetingResponder)
  602. class CommandLocatorTests(TestCase):
  603. """
  604. The CommandLocator should enable users to specify responders to commands as
  605. functions that take structured objects, annotated with metadata.
  606. """
  607. def _checkSimpleGreeting(self, locatorClass, expected):
  608. """
  609. Check that a locator of type C{locatorClass} finds a responder
  610. for command named I{simple} and that the found responder answers
  611. with the C{expected} result to a C{SimpleGreeting<"ni hao", 5>}
  612. command.
  613. """
  614. locator = locatorClass()
  615. responderCallable = locator.locateResponder(b"simple")
  616. result = responderCallable(amp.Box(greeting=b"ni hao", cookie=b"5"))
  617. def done(values):
  618. self.assertEqual(values, amp.AmpBox(cookieplus=b"%d" % (expected,)))
  619. return result.addCallback(done)
  620. def test_responderDecorator(self):
  621. """
  622. A method on a L{CommandLocator} subclass decorated with a L{Command}
  623. subclass's L{responder} decorator should be returned from
  624. locateResponder, wrapped in logic to serialize and deserialize its
  625. arguments.
  626. """
  627. return self._checkSimpleGreeting(TestLocator, 8)
  628. def test_responderOverriding(self):
  629. """
  630. L{CommandLocator} subclasses can override a responder inherited from
  631. a base class by using the L{Command.responder} decorator to register
  632. a new responder method.
  633. """
  634. return self._checkSimpleGreeting(OverridingLocator, 9)
  635. def test_responderInheritance(self):
  636. """
  637. Responder lookup follows the same rules as normal method lookup
  638. rules, particularly with respect to inheritance.
  639. """
  640. return self._checkSimpleGreeting(InheritingLocator, 9)
  641. def test_lookupFunctionDeprecatedOverride(self):
  642. """
  643. Subclasses which override locateResponder under its old name,
  644. lookupFunction, should have the override invoked instead. (This tests
  645. an AMP subclass, because in the version of the code that could invoke
  646. this deprecated code path, there was no L{CommandLocator}.)
  647. """
  648. locator = OverrideLocatorAMP()
  649. customResponderObject = self.assertWarns(
  650. PendingDeprecationWarning,
  651. "Override locateResponder, not lookupFunction.",
  652. __file__,
  653. lambda: locator.locateResponder(b"custom"),
  654. )
  655. self.assertEqual(locator.customResponder, customResponderObject)
  656. # Make sure upcalling works too
  657. normalResponderObject = self.assertWarns(
  658. PendingDeprecationWarning,
  659. "Override locateResponder, not lookupFunction.",
  660. __file__,
  661. lambda: locator.locateResponder(b"simple"),
  662. )
  663. result = normalResponderObject(amp.Box(greeting=b"ni hao", cookie=b"5"))
  664. def done(values):
  665. self.assertEqual(values, amp.AmpBox(cookieplus=b"8"))
  666. return result.addCallback(done)
  667. def test_lookupFunctionDeprecatedInvoke(self):
  668. """
  669. Invoking locateResponder under its old name, lookupFunction, should
  670. emit a deprecation warning, but do the same thing.
  671. """
  672. locator = TestLocator()
  673. responderCallable = self.assertWarns(
  674. PendingDeprecationWarning,
  675. "Call locateResponder, not lookupFunction.",
  676. __file__,
  677. lambda: locator.lookupFunction(b"simple"),
  678. )
  679. result = responderCallable(amp.Box(greeting=b"ni hao", cookie=b"5"))
  680. def done(values):
  681. self.assertEqual(values, amp.AmpBox(cookieplus=b"8"))
  682. return result.addCallback(done)
  683. SWITCH_CLIENT_DATA = b"Success!"
  684. SWITCH_SERVER_DATA = b"No, really. Success."
  685. class BinaryProtocolTests(TestCase):
  686. """
  687. Tests for L{amp.BinaryBoxProtocol}.
  688. @ivar _boxSender: After C{startReceivingBoxes} is called, the L{IBoxSender}
  689. which was passed to it.
  690. """
  691. def setUp(self):
  692. """
  693. Keep track of all boxes received by this test in its capacity as an
  694. L{IBoxReceiver} implementor.
  695. """
  696. self.boxes = []
  697. self.data = []
  698. def startReceivingBoxes(self, sender):
  699. """
  700. Implement L{IBoxReceiver.startReceivingBoxes} to just remember the
  701. value passed in.
  702. """
  703. self._boxSender = sender
  704. def ampBoxReceived(self, box):
  705. """
  706. A box was received by the protocol.
  707. """
  708. self.boxes.append(box)
  709. stopReason = None
  710. def stopReceivingBoxes(self, reason):
  711. """
  712. Record the reason that we stopped receiving boxes.
  713. """
  714. self.stopReason = reason
  715. # fake ITransport
  716. def getPeer(self):
  717. return "no peer"
  718. def getHost(self):
  719. return "no host"
  720. def write(self, data):
  721. self.assertIsInstance(data, bytes)
  722. self.data.append(data)
  723. def test_startReceivingBoxes(self):
  724. """
  725. When L{amp.BinaryBoxProtocol} is connected to a transport, it calls
  726. C{startReceivingBoxes} on its L{IBoxReceiver} with itself as the
  727. L{IBoxSender} parameter.
  728. """
  729. protocol = amp.BinaryBoxProtocol(self)
  730. protocol.makeConnection(None)
  731. self.assertIs(self._boxSender, protocol)
  732. def test_sendBoxInStartReceivingBoxes(self):
  733. """
  734. The L{IBoxReceiver} which is started when L{amp.BinaryBoxProtocol} is
  735. connected to a transport can call C{sendBox} on the L{IBoxSender}
  736. passed to it before C{startReceivingBoxes} returns and have that box
  737. sent.
  738. """
  739. class SynchronouslySendingReceiver:
  740. def startReceivingBoxes(self, sender):
  741. sender.sendBox(amp.Box({b"foo": b"bar"}))
  742. transport = StringTransport()
  743. protocol = amp.BinaryBoxProtocol(SynchronouslySendingReceiver())
  744. protocol.makeConnection(transport)
  745. self.assertEqual(transport.value(), b"\x00\x03foo\x00\x03bar\x00\x00")
  746. def test_receiveBoxStateMachine(self):
  747. """
  748. When a binary box protocol receives:
  749. * a key
  750. * a value
  751. * an empty string
  752. it should emit a box and send it to its boxReceiver.
  753. """
  754. a = amp.BinaryBoxProtocol(self)
  755. a.stringReceived(b"hello")
  756. a.stringReceived(b"world")
  757. a.stringReceived(b"")
  758. self.assertEqual(self.boxes, [amp.AmpBox(hello=b"world")])
  759. def test_firstBoxFirstKeyExcessiveLength(self):
  760. """
  761. L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
  762. the first a key it receives is larger than 255.
  763. """
  764. transport = StringTransport()
  765. protocol = amp.BinaryBoxProtocol(self)
  766. protocol.makeConnection(transport)
  767. protocol.dataReceived(b"\x01\x00")
  768. self.assertTrue(transport.disconnecting)
  769. def test_firstBoxSubsequentKeyExcessiveLength(self):
  770. """
  771. L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
  772. a subsequent key in the first box it receives is larger than 255.
  773. """
  774. transport = StringTransport()
  775. protocol = amp.BinaryBoxProtocol(self)
  776. protocol.makeConnection(transport)
  777. protocol.dataReceived(b"\x00\x01k\x00\x01v")
  778. self.assertFalse(transport.disconnecting)
  779. protocol.dataReceived(b"\x01\x00")
  780. self.assertTrue(transport.disconnecting)
  781. def test_subsequentBoxFirstKeyExcessiveLength(self):
  782. """
  783. L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
  784. the first key in a subsequent box it receives is larger than 255.
  785. """
  786. transport = StringTransport()
  787. protocol = amp.BinaryBoxProtocol(self)
  788. protocol.makeConnection(transport)
  789. protocol.dataReceived(b"\x00\x01k\x00\x01v\x00\x00")
  790. self.assertFalse(transport.disconnecting)
  791. protocol.dataReceived(b"\x01\x00")
  792. self.assertTrue(transport.disconnecting)
  793. def test_excessiveKeyFailure(self):
  794. """
  795. If L{amp.BinaryBoxProtocol} disconnects because it received a key
  796. length prefix which was too large, the L{IBoxReceiver}'s
  797. C{stopReceivingBoxes} method is called with a L{TooLong} failure.
  798. """
  799. protocol = amp.BinaryBoxProtocol(self)
  800. protocol.makeConnection(StringTransport())
  801. protocol.dataReceived(b"\x01\x00")
  802. protocol.connectionLost(
  803. Failure(error.ConnectionDone("simulated connection done"))
  804. )
  805. self.stopReason.trap(amp.TooLong)
  806. self.assertTrue(self.stopReason.value.isKey)
  807. self.assertFalse(self.stopReason.value.isLocal)
  808. self.assertIsNone(self.stopReason.value.value)
  809. self.assertIsNone(self.stopReason.value.keyName)
  810. def test_unhandledErrorWithTransport(self):
  811. """
  812. L{amp.BinaryBoxProtocol.unhandledError} logs the failure passed to it
  813. and disconnects its transport.
  814. """
  815. transport = StringTransport()
  816. protocol = amp.BinaryBoxProtocol(self)
  817. protocol.makeConnection(transport)
  818. protocol.unhandledError(Failure(RuntimeError("Fake error")))
  819. self.assertEqual(1, len(self.flushLoggedErrors(RuntimeError)))
  820. self.assertTrue(transport.disconnecting)
  821. def test_unhandledErrorWithoutTransport(self):
  822. """
  823. L{amp.BinaryBoxProtocol.unhandledError} completes without error when
  824. there is no associated transport.
  825. """
  826. protocol = amp.BinaryBoxProtocol(self)
  827. protocol.makeConnection(StringTransport())
  828. protocol.connectionLost(Failure(Exception("Simulated")))
  829. protocol.unhandledError(Failure(RuntimeError("Fake error")))
  830. self.assertEqual(1, len(self.flushLoggedErrors(RuntimeError)))
  831. def test_receiveBoxData(self):
  832. """
  833. When a binary box protocol receives the serialized form of an AMP box,
  834. it should emit a similar box to its boxReceiver.
  835. """
  836. a = amp.BinaryBoxProtocol(self)
  837. a.dataReceived(
  838. amp.Box(
  839. {b"testKey": b"valueTest", b"anotherKey": b"anotherValue"}
  840. ).serialize()
  841. )
  842. self.assertEqual(
  843. self.boxes,
  844. [amp.Box({b"testKey": b"valueTest", b"anotherKey": b"anotherValue"})],
  845. )
  846. def test_receiveLongerBoxData(self):
  847. """
  848. An L{amp.BinaryBoxProtocol} can receive serialized AMP boxes with
  849. values of up to (2 ** 16 - 1) bytes.
  850. """
  851. length = 2 ** 16 - 1
  852. value = b"x" * length
  853. transport = StringTransport()
  854. protocol = amp.BinaryBoxProtocol(self)
  855. protocol.makeConnection(transport)
  856. protocol.dataReceived(amp.Box({"k": value}).serialize())
  857. self.assertEqual(self.boxes, [amp.Box({"k": value})])
  858. self.assertFalse(transport.disconnecting)
  859. def test_sendBox(self):
  860. """
  861. When a binary box protocol sends a box, it should emit the serialized
  862. bytes of that box to its transport.
  863. """
  864. a = amp.BinaryBoxProtocol(self)
  865. a.makeConnection(self)
  866. aBox = amp.Box({b"testKey": b"valueTest", b"someData": b"hello"})
  867. a.makeConnection(self)
  868. a.sendBox(aBox)
  869. self.assertEqual(b"".join(self.data), aBox.serialize())
  870. def test_connectionLostStopSendingBoxes(self):
  871. """
  872. When a binary box protocol loses its connection, it should notify its
  873. box receiver that it has stopped receiving boxes.
  874. """
  875. a = amp.BinaryBoxProtocol(self)
  876. a.makeConnection(self)
  877. connectionFailure = Failure(RuntimeError())
  878. a.connectionLost(connectionFailure)
  879. self.assertIs(self.stopReason, connectionFailure)
  880. def test_protocolSwitch(self):
  881. """
  882. L{BinaryBoxProtocol} has the capacity to switch to a different protocol
  883. on a box boundary. When a protocol is in the process of switching, it
  884. cannot receive traffic.
  885. """
  886. otherProto = TestProto(None, b"outgoing data")
  887. test = self
  888. class SwitchyReceiver:
  889. switched = False
  890. def startReceivingBoxes(self, sender):
  891. pass
  892. def ampBoxReceived(self, box):
  893. test.assertFalse(self.switched, "Should only receive one box!")
  894. self.switched = True
  895. a._lockForSwitch()
  896. a._switchTo(otherProto)
  897. a = amp.BinaryBoxProtocol(SwitchyReceiver())
  898. anyOldBox = amp.Box({b"include": b"lots", b"of": b"data"})
  899. a.makeConnection(self)
  900. # Include a 0-length box at the beginning of the next protocol's data,
  901. # to make sure that AMP doesn't eat the data or try to deliver extra
  902. # boxes either...
  903. moreThanOneBox = anyOldBox.serialize() + b"\x00\x00Hello, world!"
  904. a.dataReceived(moreThanOneBox)
  905. self.assertIs(otherProto.transport, self)
  906. self.assertEqual(b"".join(otherProto.data), b"\x00\x00Hello, world!")
  907. self.assertEqual(self.data, [b"outgoing data"])
  908. a.dataReceived(b"more data")
  909. self.assertEqual(b"".join(otherProto.data), b"\x00\x00Hello, world!more data")
  910. self.assertRaises(amp.ProtocolSwitched, a.sendBox, anyOldBox)
  911. def test_protocolSwitchEmptyBuffer(self):
  912. """
  913. After switching to a different protocol, if no extra bytes beyond
  914. the switch box were delivered, an empty string is not passed to the
  915. switched protocol's C{dataReceived} method.
  916. """
  917. a = amp.BinaryBoxProtocol(self)
  918. a.makeConnection(self)
  919. otherProto = TestProto(None, b"")
  920. a._switchTo(otherProto)
  921. self.assertEqual(otherProto.data, [])
  922. def test_protocolSwitchInvalidStates(self):
  923. """
  924. In order to make sure the protocol never gets any invalid data sent
  925. into the middle of a box, it must be locked for switching before it is
  926. switched. It can only be unlocked if the switch failed, and attempting
  927. to send a box while it is locked should raise an exception.
  928. """
  929. a = amp.BinaryBoxProtocol(self)
  930. a.makeConnection(self)
  931. sampleBox = amp.Box({b"some": b"data"})
  932. a._lockForSwitch()
  933. self.assertRaises(amp.ProtocolSwitched, a.sendBox, sampleBox)
  934. a._unlockFromSwitch()
  935. a.sendBox(sampleBox)
  936. self.assertEqual(b"".join(self.data), sampleBox.serialize())
  937. a._lockForSwitch()
  938. otherProto = TestProto(None, b"outgoing data")
  939. a._switchTo(otherProto)
  940. self.assertRaises(amp.ProtocolSwitched, a._unlockFromSwitch)
  941. def test_protocolSwitchLoseConnection(self):
  942. """
  943. When the protocol is switched, it should notify its nested protocol of
  944. disconnection.
  945. """
  946. class Loser(protocol.Protocol):
  947. reason = None
  948. def connectionLost(self, reason):
  949. self.reason = reason
  950. connectionLoser = Loser()
  951. a = amp.BinaryBoxProtocol(self)
  952. a.makeConnection(self)
  953. a._lockForSwitch()
  954. a._switchTo(connectionLoser)
  955. connectionFailure = Failure(RuntimeError())
  956. a.connectionLost(connectionFailure)
  957. self.assertEqual(connectionLoser.reason, connectionFailure)
  958. def test_protocolSwitchLoseClientConnection(self):
  959. """
  960. When the protocol is switched, it should notify its nested client
  961. protocol factory of disconnection.
  962. """
  963. class ClientLoser:
  964. reason = None
  965. def clientConnectionLost(self, connector, reason):
  966. self.reason = reason
  967. a = amp.BinaryBoxProtocol(self)
  968. connectionLoser = protocol.Protocol()
  969. clientLoser = ClientLoser()
  970. a.makeConnection(self)
  971. a._lockForSwitch()
  972. a._switchTo(connectionLoser, clientLoser)
  973. connectionFailure = Failure(RuntimeError())
  974. a.connectionLost(connectionFailure)
  975. self.assertEqual(clientLoser.reason, connectionFailure)
  976. class AMPTests(TestCase):
  977. def test_interfaceDeclarations(self):
  978. """
  979. The classes in the amp module ought to implement the interfaces that
  980. are declared for their benefit.
  981. """
  982. for interface, implementation in [
  983. (amp.IBoxSender, amp.BinaryBoxProtocol),
  984. (amp.IBoxReceiver, amp.BoxDispatcher),
  985. (amp.IResponderLocator, amp.CommandLocator),
  986. (amp.IResponderLocator, amp.SimpleStringLocator),
  987. (amp.IBoxSender, amp.AMP),
  988. (amp.IBoxReceiver, amp.AMP),
  989. (amp.IResponderLocator, amp.AMP),
  990. ]:
  991. self.assertTrue(
  992. interface.implementedBy(implementation),
  993. f"{implementation} does not implements({interface})",
  994. )
  995. def test_helloWorld(self):
  996. """
  997. Verify that a simple command can be sent and its response received with
  998. the simple low-level string-based API.
  999. """
  1000. c, s, p = connectedServerAndClient()
  1001. L = []
  1002. HELLO = b"world"
  1003. c.sendHello(HELLO).addCallback(L.append)
  1004. p.flush()
  1005. self.assertEqual(L[0][b"hello"], HELLO)
  1006. def test_wireFormatRoundTrip(self):
  1007. """
  1008. Verify that mixed-case, underscored and dashed arguments are mapped to
  1009. their python names properly.
  1010. """
  1011. c, s, p = connectedServerAndClient()
  1012. L = []
  1013. HELLO = b"world"
  1014. c.sendHello(HELLO).addCallback(L.append)
  1015. p.flush()
  1016. self.assertEqual(L[0][b"hello"], HELLO)
  1017. def test_helloWorldUnicode(self):
  1018. """
  1019. Verify that unicode arguments can be encoded and decoded.
  1020. """
  1021. c, s, p = connectedServerAndClient(
  1022. ServerClass=SimpleSymmetricCommandProtocol,
  1023. ClientClass=SimpleSymmetricCommandProtocol,
  1024. )
  1025. L = []
  1026. HELLO = b"world"
  1027. HELLO_UNICODE = "wor\u1234ld"
  1028. c.sendUnicodeHello(HELLO, HELLO_UNICODE).addCallback(L.append)
  1029. p.flush()
  1030. self.assertEqual(L[0]["hello"], HELLO)
  1031. self.assertEqual(L[0]["Print"], HELLO_UNICODE)
  1032. def test_callRemoteStringRequiresAnswerFalse(self):
  1033. """
  1034. L{BoxDispatcher.callRemoteString} returns L{None} if C{requiresAnswer}
  1035. is C{False}.
  1036. """
  1037. c, s, p = connectedServerAndClient()
  1038. ret = c.callRemoteString(b"WTF", requiresAnswer=False)
  1039. self.assertIsNone(ret)
  1040. def test_unknownCommandLow(self):
  1041. """
  1042. Verify that unknown commands using low-level APIs will be rejected with an
  1043. error, but will NOT terminate the connection.
  1044. """
  1045. c, s, p = connectedServerAndClient()
  1046. L = []
  1047. def clearAndAdd(e):
  1048. """
  1049. You can't propagate the error...
  1050. """
  1051. e.trap(amp.UnhandledCommand)
  1052. return "OK"
  1053. c.callRemoteString(b"WTF").addErrback(clearAndAdd).addCallback(L.append)
  1054. p.flush()
  1055. self.assertEqual(L.pop(), "OK")
  1056. HELLO = b"world"
  1057. c.sendHello(HELLO).addCallback(L.append)
  1058. p.flush()
  1059. self.assertEqual(L[0][b"hello"], HELLO)
  1060. def test_unknownCommandHigh(self):
  1061. """
  1062. Verify that unknown commands using high-level APIs will be rejected with an
  1063. error, but will NOT terminate the connection.
  1064. """
  1065. c, s, p = connectedServerAndClient()
  1066. L = []
  1067. def clearAndAdd(e):
  1068. """
  1069. You can't propagate the error...
  1070. """
  1071. e.trap(amp.UnhandledCommand)
  1072. return "OK"
  1073. c.callRemote(WTF).addErrback(clearAndAdd).addCallback(L.append)
  1074. p.flush()
  1075. self.assertEqual(L.pop(), "OK")
  1076. HELLO = b"world"
  1077. c.sendHello(HELLO).addCallback(L.append)
  1078. p.flush()
  1079. self.assertEqual(L[0][b"hello"], HELLO)
  1080. def test_brokenReturnValue(self):
  1081. """
  1082. It can be very confusing if you write some code which responds to a
  1083. command, but gets the return value wrong. Most commonly you end up
  1084. returning None instead of a dictionary.
  1085. Verify that if that happens, the framework logs a useful error.
  1086. """
  1087. L = []
  1088. SimpleSymmetricCommandProtocol().dispatchCommand(
  1089. amp.AmpBox(_command=BrokenReturn.commandName)
  1090. ).addErrback(L.append)
  1091. L[0].trap(amp.BadLocalReturn)
  1092. self.failUnlessIn("None", repr(L[0].value))
  1093. def test_unknownArgument(self):
  1094. """
  1095. Verify that unknown arguments are ignored, and not passed to a Python
  1096. function which can't accept them.
  1097. """
  1098. c, s, p = connectedServerAndClient(
  1099. ServerClass=SimpleSymmetricCommandProtocol,
  1100. ClientClass=SimpleSymmetricCommandProtocol,
  1101. )
  1102. L = []
  1103. HELLO = b"world"
  1104. # c.sendHello(HELLO).addCallback(L.append)
  1105. c.callRemote(
  1106. FutureHello, hello=HELLO, bonus=b"I'm not in the book!"
  1107. ).addCallback(L.append)
  1108. p.flush()
  1109. self.assertEqual(L[0]["hello"], HELLO)
  1110. def test_simpleReprs(self):
  1111. """
  1112. Verify that the various Box objects repr properly, for debugging.
  1113. """
  1114. self.assertEqual(type(repr(amp._SwitchBox("a"))), str)
  1115. self.assertEqual(type(repr(amp.QuitBox())), str)
  1116. self.assertEqual(type(repr(amp.AmpBox())), str)
  1117. self.assertIn("AmpBox", repr(amp.AmpBox()))
  1118. def test_innerProtocolInRepr(self):
  1119. """
  1120. Verify that L{AMP} objects output their innerProtocol when set.
  1121. """
  1122. otherProto = TestProto(None, b"outgoing data")
  1123. a = amp.AMP()
  1124. a.innerProtocol = otherProto
  1125. self.assertEqual(
  1126. repr(a),
  1127. "<AMP inner <TestProto #%d> at 0x%x>" % (otherProto.instanceId, id(a)),
  1128. )
  1129. def test_innerProtocolNotInRepr(self):
  1130. """
  1131. Verify that L{AMP} objects do not output 'inner' when no innerProtocol
  1132. is set.
  1133. """
  1134. a = amp.AMP()
  1135. self.assertEqual(repr(a), f"<AMP at 0x{id(a):x}>")
  1136. @skipIf(skipSSL, "SSL not available")
  1137. def test_simpleSSLRepr(self):
  1138. """
  1139. L{amp._TLSBox.__repr__} returns a string.
  1140. """
  1141. self.assertEqual(type(repr(amp._TLSBox())), str)
  1142. def test_keyTooLong(self):
  1143. """
  1144. Verify that a key that is too long will immediately raise a synchronous
  1145. exception.
  1146. """
  1147. c, s, p = connectedServerAndClient()
  1148. x = "H" * (0xFF + 1)
  1149. tl = self.assertRaises(amp.TooLong, c.callRemoteString, b"Hello", **{x: b"hi"})
  1150. self.assertTrue(tl.isKey)
  1151. self.assertTrue(tl.isLocal)
  1152. self.assertIsNone(tl.keyName)
  1153. self.assertEqual(tl.value, x.encode("ascii"))
  1154. self.assertIn(str(len(x)), repr(tl))
  1155. self.assertIn("key", repr(tl))
  1156. def test_valueTooLong(self):
  1157. """
  1158. Verify that attempting to send value longer than 64k will immediately
  1159. raise an exception.
  1160. """
  1161. c, s, p = connectedServerAndClient()
  1162. x = b"H" * (0xFFFF + 1)
  1163. tl = self.assertRaises(amp.TooLong, c.sendHello, x)
  1164. p.flush()
  1165. self.assertFalse(tl.isKey)
  1166. self.assertTrue(tl.isLocal)
  1167. self.assertEqual(tl.keyName, b"hello")
  1168. self.failUnlessIdentical(tl.value, x)
  1169. self.assertIn(str(len(x)), repr(tl))
  1170. self.assertIn("value", repr(tl))
  1171. self.assertIn("hello", repr(tl))
  1172. def test_helloWorldCommand(self):
  1173. """
  1174. Verify that a simple command can be sent and its response received with
  1175. the high-level value parsing API.
  1176. """
  1177. c, s, p = connectedServerAndClient(
  1178. ServerClass=SimpleSymmetricCommandProtocol,
  1179. ClientClass=SimpleSymmetricCommandProtocol,
  1180. )
  1181. L = []
  1182. HELLO = b"world"
  1183. c.sendHello(HELLO).addCallback(L.append)
  1184. p.flush()
  1185. self.assertEqual(L[0]["hello"], HELLO)
  1186. def test_helloErrorHandling(self):
  1187. """
  1188. Verify that if a known error type is raised and handled, it will be
  1189. properly relayed to the other end of the connection and translated into
  1190. an exception, and no error will be logged.
  1191. """
  1192. L = []
  1193. c, s, p = connectedServerAndClient(
  1194. ServerClass=SimpleSymmetricCommandProtocol,
  1195. ClientClass=SimpleSymmetricCommandProtocol,
  1196. )
  1197. HELLO = b"fuck you"
  1198. c.sendHello(HELLO).addErrback(L.append)
  1199. p.flush()
  1200. L[0].trap(UnfriendlyGreeting)
  1201. self.assertEqual(str(L[0].value), "Don't be a dick.")
  1202. def test_helloFatalErrorHandling(self):
  1203. """
  1204. Verify that if a known, fatal error type is raised and handled, it will
  1205. be properly relayed to the other end of the connection and translated
  1206. into an exception, no error will be logged, and the connection will be
  1207. terminated.
  1208. """
  1209. L = []
  1210. c, s, p = connectedServerAndClient(
  1211. ServerClass=SimpleSymmetricCommandProtocol,
  1212. ClientClass=SimpleSymmetricCommandProtocol,
  1213. )
  1214. HELLO = b"die"
  1215. c.sendHello(HELLO).addErrback(L.append)
  1216. p.flush()
  1217. L.pop().trap(DeathThreat)
  1218. c.sendHello(HELLO).addErrback(L.append)
  1219. p.flush()
  1220. L.pop().trap(error.ConnectionDone)
  1221. def test_helloNoErrorHandling(self):
  1222. """
  1223. Verify that if an unknown error type is raised, it will be relayed to
  1224. the other end of the connection and translated into an exception, it
  1225. will be logged, and then the connection will be dropped.
  1226. """
  1227. L = []
  1228. c, s, p = connectedServerAndClient(
  1229. ServerClass=SimpleSymmetricCommandProtocol,
  1230. ClientClass=SimpleSymmetricCommandProtocol,
  1231. )
  1232. HELLO = THING_I_DONT_UNDERSTAND
  1233. c.sendHello(HELLO).addErrback(L.append)
  1234. p.flush()
  1235. ure = L.pop()
  1236. ure.trap(amp.UnknownRemoteError)
  1237. c.sendHello(HELLO).addErrback(L.append)
  1238. cl = L.pop()
  1239. cl.trap(error.ConnectionDone)
  1240. # The exception should have been logged.
  1241. self.assertTrue(self.flushLoggedErrors(ThingIDontUnderstandError))
  1242. def test_lateAnswer(self):
  1243. """
  1244. Verify that a command that does not get answered until after the
  1245. connection terminates will not cause any errors.
  1246. """
  1247. c, s, p = connectedServerAndClient(
  1248. ServerClass=SimpleSymmetricCommandProtocol,
  1249. ClientClass=SimpleSymmetricCommandProtocol,
  1250. )
  1251. L = []
  1252. c.callRemote(WaitForever).addErrback(L.append)
  1253. p.flush()
  1254. self.assertEqual(L, [])
  1255. s.transport.loseConnection()
  1256. p.flush()
  1257. L.pop().trap(error.ConnectionDone)
  1258. # Just make sure that it doesn't error...
  1259. s.waiting.callback({})
  1260. return s.waiting
  1261. def test_requiresNoAnswer(self):
  1262. """
  1263. Verify that a command that requires no answer is run.
  1264. """
  1265. c, s, p = connectedServerAndClient(
  1266. ServerClass=SimpleSymmetricCommandProtocol,
  1267. ClientClass=SimpleSymmetricCommandProtocol,
  1268. )
  1269. HELLO = b"world"
  1270. c.callRemote(NoAnswerHello, hello=HELLO)
  1271. p.flush()
  1272. self.assertTrue(s.greeted)
  1273. def test_requiresNoAnswerFail(self):
  1274. """
  1275. Verify that commands sent after a failed no-answer request do not complete.
  1276. """
  1277. L = []
  1278. c, s, p = connectedServerAndClient(
  1279. ServerClass=SimpleSymmetricCommandProtocol,
  1280. ClientClass=SimpleSymmetricCommandProtocol,
  1281. )
  1282. HELLO = b"fuck you"
  1283. c.callRemote(NoAnswerHello, hello=HELLO)
  1284. p.flush()
  1285. # This should be logged locally.
  1286. self.assertTrue(self.flushLoggedErrors(amp.RemoteAmpError))
  1287. HELLO = b"world"
  1288. c.callRemote(Hello, hello=HELLO).addErrback(L.append)
  1289. p.flush()
  1290. L.pop().trap(error.ConnectionDone)
  1291. self.assertFalse(s.greeted)
  1292. def test_requiresNoAnswerAfterFail(self):
  1293. """
  1294. No-answer commands sent after the connection has been torn down do not
  1295. return a L{Deferred}.
  1296. """
  1297. c, s, p = connectedServerAndClient(
  1298. ServerClass=SimpleSymmetricCommandProtocol,
  1299. ClientClass=SimpleSymmetricCommandProtocol,
  1300. )
  1301. c.transport.loseConnection()
  1302. p.flush()
  1303. result = c.callRemote(NoAnswerHello, hello=b"ignored")
  1304. self.assertIs(result, None)
  1305. def test_noAnswerResponderBadAnswer(self):
  1306. """
  1307. Verify that responders of requiresAnswer=False commands have to return
  1308. a dictionary anyway.
  1309. (requiresAnswer is a hint from the _client_ - the server may be called
  1310. upon to answer commands in any case, if the client wants to know when
  1311. they complete.)
  1312. """
  1313. c, s, p = connectedServerAndClient(
  1314. ServerClass=BadNoAnswerCommandProtocol,
  1315. ClientClass=SimpleSymmetricCommandProtocol,
  1316. )
  1317. c.callRemote(NoAnswerHello, hello=b"hello")
  1318. p.flush()
  1319. le = self.flushLoggedErrors(amp.BadLocalReturn)
  1320. self.assertEqual(len(le), 1)
  1321. def test_noAnswerResponderAskedForAnswer(self):
  1322. """
  1323. Verify that responders with requiresAnswer=False will actually respond
  1324. if the client sets requiresAnswer=True. In other words, verify that
  1325. requiresAnswer is a hint honored only by the client.
  1326. """
  1327. c, s, p = connectedServerAndClient(
  1328. ServerClass=NoAnswerCommandProtocol,
  1329. ClientClass=SimpleSymmetricCommandProtocol,
  1330. )
  1331. L = []
  1332. c.callRemote(Hello, hello=b"Hello!").addCallback(L.append)
  1333. p.flush()
  1334. self.assertEqual(len(L), 1)
  1335. self.assertEqual(
  1336. L, [dict(hello=b"Hello!-noanswer", Print=None)]
  1337. ) # Optional response argument
  1338. def test_ampListCommand(self):
  1339. """
  1340. Test encoding of an argument that uses the AmpList encoding.
  1341. """
  1342. c, s, p = connectedServerAndClient(
  1343. ServerClass=SimpleSymmetricCommandProtocol,
  1344. ClientClass=SimpleSymmetricCommandProtocol,
  1345. )
  1346. L = []
  1347. c.callRemote(GetList, length=10).addCallback(L.append)
  1348. p.flush()
  1349. values = L.pop().get("body")
  1350. self.assertEqual(values, [{"x": 1}] * 10)
  1351. def test_optionalAmpListOmitted(self):
  1352. """
  1353. Sending a command with an omitted AmpList argument that is
  1354. designated as optional does not raise an InvalidSignature error.
  1355. """
  1356. c, s, p = connectedServerAndClient(
  1357. ServerClass=SimpleSymmetricCommandProtocol,
  1358. ClientClass=SimpleSymmetricCommandProtocol,
  1359. )
  1360. L = []
  1361. c.callRemote(DontRejectMe, magicWord="please").addCallback(L.append)
  1362. p.flush()
  1363. response = L.pop().get("response")
  1364. self.assertEqual(response, "list omitted")
  1365. def test_optionalAmpListPresent(self):
  1366. """
  1367. Sanity check that optional AmpList arguments are processed normally.
  1368. """
  1369. c, s, p = connectedServerAndClient(
  1370. ServerClass=SimpleSymmetricCommandProtocol,
  1371. ClientClass=SimpleSymmetricCommandProtocol,
  1372. )
  1373. L = []
  1374. c.callRemote(
  1375. DontRejectMe, magicWord="please", list=[{"name": "foo"}]
  1376. ).addCallback(L.append)
  1377. p.flush()
  1378. response = L.pop().get("response")
  1379. self.assertEqual(response, "foo accepted")
  1380. def test_failEarlyOnArgSending(self):
  1381. """
  1382. Verify that if we pass an invalid argument list (omitting an argument),
  1383. an exception will be raised.
  1384. """
  1385. self.assertRaises(amp.InvalidSignature, Hello)
  1386. def test_doubleProtocolSwitch(self):
  1387. """
  1388. As a debugging aid, a protocol system should raise a
  1389. L{ProtocolSwitched} exception when asked to switch a protocol that is
  1390. already switched.
  1391. """
  1392. serverDeferred = defer.Deferred()
  1393. serverProto = SimpleSymmetricCommandProtocol(serverDeferred)
  1394. clientDeferred = defer.Deferred()
  1395. clientProto = SimpleSymmetricCommandProtocol(clientDeferred)
  1396. c, s, p = connectedServerAndClient(
  1397. ServerClass=lambda: serverProto, ClientClass=lambda: clientProto
  1398. )
  1399. def switched(result):
  1400. self.assertRaises(amp.ProtocolSwitched, c.switchToTestProtocol)
  1401. self.testSucceeded = True
  1402. c.switchToTestProtocol().addCallback(switched)
  1403. p.flush()
  1404. self.assertTrue(self.testSucceeded)
  1405. def test_protocolSwitch(
  1406. self,
  1407. switcher=SimpleSymmetricCommandProtocol,
  1408. spuriousTraffic=False,
  1409. spuriousError=False,
  1410. ):
  1411. """
  1412. Verify that it is possible to switch to another protocol mid-connection and
  1413. send data to it successfully.
  1414. """
  1415. self.testSucceeded = False
  1416. serverDeferred = defer.Deferred()
  1417. serverProto = switcher(serverDeferred)
  1418. clientDeferred = defer.Deferred()
  1419. clientProto = switcher(clientDeferred)
  1420. c, s, p = connectedServerAndClient(
  1421. ServerClass=lambda: serverProto, ClientClass=lambda: clientProto
  1422. )
  1423. if spuriousTraffic:
  1424. wfdr = [] # remote
  1425. c.callRemote(WaitForever).addErrback(wfdr.append)
  1426. switchDeferred = c.switchToTestProtocol()
  1427. if spuriousTraffic:
  1428. self.assertRaises(amp.ProtocolSwitched, c.sendHello, b"world")
  1429. def cbConnsLost(info):
  1430. ((serverSuccess, serverData), (clientSuccess, clientData)) = info
  1431. self.assertTrue(serverSuccess)
  1432. self.assertTrue(clientSuccess)
  1433. self.assertEqual(b"".join(serverData), SWITCH_CLIENT_DATA)
  1434. self.assertEqual(b"".join(clientData), SWITCH_SERVER_DATA)
  1435. self.testSucceeded = True
  1436. def cbSwitch(proto):
  1437. return defer.DeferredList([serverDeferred, clientDeferred]).addCallback(
  1438. cbConnsLost
  1439. )
  1440. switchDeferred.addCallback(cbSwitch)
  1441. p.flush()
  1442. if serverProto.maybeLater is not None:
  1443. serverProto.maybeLater.callback(serverProto.maybeLaterProto)
  1444. p.flush()
  1445. if spuriousTraffic:
  1446. # switch is done here; do this here to make sure that if we're
  1447. # going to corrupt the connection, we do it before it's closed.
  1448. if spuriousError:
  1449. s.waiting.errback(
  1450. amp.RemoteAmpError(
  1451. b"SPURIOUS", "Here's some traffic in the form of an error."
  1452. )
  1453. )
  1454. else:
  1455. s.waiting.callback({})
  1456. p.flush()
  1457. c.transport.loseConnection() # close it
  1458. p.flush()
  1459. self.assertTrue(self.testSucceeded)
  1460. def test_protocolSwitchDeferred(self):
  1461. """
  1462. Verify that protocol-switching even works if the value returned from
  1463. the command that does the switch is deferred.
  1464. """
  1465. return self.test_protocolSwitch(switcher=DeferredSymmetricCommandProtocol)
  1466. def test_protocolSwitchFail(self, switcher=SimpleSymmetricCommandProtocol):
  1467. """
  1468. Verify that if we try to switch protocols and it fails, the connection
  1469. stays up and we can go back to speaking AMP.
  1470. """
  1471. self.testSucceeded = False
  1472. serverDeferred = defer.Deferred()
  1473. serverProto = switcher(serverDeferred)
  1474. clientDeferred = defer.Deferred()
  1475. clientProto = switcher(clientDeferred)
  1476. c, s, p = connectedServerAndClient(
  1477. ServerClass=lambda: serverProto, ClientClass=lambda: clientProto
  1478. )
  1479. L = []
  1480. c.switchToTestProtocol(fail=True).addErrback(L.append)
  1481. p.flush()
  1482. L.pop().trap(UnknownProtocol)
  1483. self.assertFalse(self.testSucceeded)
  1484. # It's a known error, so let's send a "hello" on the same connection;
  1485. # it should work.
  1486. c.sendHello(b"world").addCallback(L.append)
  1487. p.flush()
  1488. self.assertEqual(L.pop()["hello"], b"world")
  1489. def test_trafficAfterSwitch(self):
  1490. """
  1491. Verify that attempts to send traffic after a switch will not corrupt
  1492. the nested protocol.
  1493. """
  1494. return self.test_protocolSwitch(spuriousTraffic=True)
  1495. def test_errorAfterSwitch(self):
  1496. """
  1497. Returning an error after a protocol switch should record the underlying
  1498. error.
  1499. """
  1500. return self.test_protocolSwitch(spuriousTraffic=True, spuriousError=True)
  1501. def test_quitBoxQuits(self):
  1502. """
  1503. Verify that commands with a responseType of QuitBox will in fact
  1504. terminate the connection.
  1505. """
  1506. c, s, p = connectedServerAndClient(
  1507. ServerClass=SimpleSymmetricCommandProtocol,
  1508. ClientClass=SimpleSymmetricCommandProtocol,
  1509. )
  1510. L = []
  1511. HELLO = b"world"
  1512. GOODBYE = b"everyone"
  1513. c.sendHello(HELLO).addCallback(L.append)
  1514. p.flush()
  1515. self.assertEqual(L.pop()["hello"], HELLO)
  1516. c.callRemote(Goodbye).addCallback(L.append)
  1517. p.flush()
  1518. self.assertEqual(L.pop()["goodbye"], GOODBYE)
  1519. c.sendHello(HELLO).addErrback(L.append)
  1520. L.pop().trap(error.ConnectionDone)
  1521. def test_basicLiteralEmit(self):
  1522. """
  1523. Verify that the command dictionaries for a callRemoteN look correct
  1524. after being serialized and parsed.
  1525. """
  1526. c, s, p = connectedServerAndClient()
  1527. L = []
  1528. s.ampBoxReceived = L.append
  1529. c.callRemote(
  1530. Hello,
  1531. hello=b"hello test",
  1532. mixedCase=b"mixed case arg test",
  1533. dash_arg=b"x",
  1534. underscore_arg=b"y",
  1535. )
  1536. p.flush()
  1537. self.assertEqual(len(L), 1)
  1538. for k, v in [
  1539. (b"_command", Hello.commandName),
  1540. (b"hello", b"hello test"),
  1541. (b"mixedCase", b"mixed case arg test"),
  1542. (b"dash-arg", b"x"),
  1543. (b"underscore_arg", b"y"),
  1544. ]:
  1545. self.assertEqual(L[-1].pop(k), v)
  1546. L[-1].pop(b"_ask")
  1547. self.assertEqual(L[-1], {})
  1548. def test_basicStructuredEmit(self):
  1549. """
  1550. Verify that a call similar to basicLiteralEmit's is handled properly with
  1551. high-level quoting and passing to Python methods, and that argument
  1552. names are correctly handled.
  1553. """
  1554. L = []
  1555. class StructuredHello(amp.AMP):
  1556. def h(self, *a, **k):
  1557. L.append((a, k))
  1558. return dict(hello=b"aaa")
  1559. Hello.responder(h)
  1560. c, s, p = connectedServerAndClient(ServerClass=StructuredHello)
  1561. c.callRemote(
  1562. Hello,
  1563. hello=b"hello test",
  1564. mixedCase=b"mixed case arg test",
  1565. dash_arg=b"x",
  1566. underscore_arg=b"y",
  1567. ).addCallback(L.append)
  1568. p.flush()
  1569. self.assertEqual(len(L), 2)
  1570. self.assertEqual(
  1571. L[0],
  1572. (
  1573. (),
  1574. dict(
  1575. hello=b"hello test",
  1576. mixedCase=b"mixed case arg test",
  1577. dash_arg=b"x",
  1578. underscore_arg=b"y",
  1579. From=s.transport.getPeer(),
  1580. # XXX - should optional arguments just not be passed?
  1581. # passing None seems a little odd, looking at the way it
  1582. # turns out here... -glyph
  1583. Print=None,
  1584. optional=None,
  1585. ),
  1586. ),
  1587. )
  1588. self.assertEqual(L[1], dict(Print=None, hello=b"aaa"))
  1589. class PretendRemoteCertificateAuthority:
  1590. def checkIsPretendRemote(self):
  1591. return True
  1592. class IOSimCert:
  1593. verifyCount = 0
  1594. def options(self, *ign):
  1595. return self
  1596. def iosimVerify(self, otherCert):
  1597. """
  1598. This isn't a real certificate, and wouldn't work on a real socket, but
  1599. iosim specifies a different API so that we don't have to do any crypto
  1600. math to demonstrate that the right functions get called in the right
  1601. places.
  1602. """
  1603. assert otherCert is self
  1604. self.verifyCount += 1
  1605. return True
  1606. class OKCert(IOSimCert):
  1607. def options(self, x):
  1608. assert x.checkIsPretendRemote()
  1609. return self
  1610. class GrumpyCert(IOSimCert):
  1611. def iosimVerify(self, otherCert):
  1612. self.verifyCount += 1
  1613. return False
  1614. class DroppyCert(IOSimCert):
  1615. def __init__(self, toDrop):
  1616. self.toDrop = toDrop
  1617. def iosimVerify(self, otherCert):
  1618. self.verifyCount += 1
  1619. self.toDrop.loseConnection()
  1620. return True
  1621. class SecurableProto(FactoryNotifier):
  1622. factory = None
  1623. def verifyFactory(self):
  1624. return [PretendRemoteCertificateAuthority()]
  1625. def getTLSVars(self):
  1626. cert = self.certFactory()
  1627. verify = self.verifyFactory()
  1628. return dict(tls_localCertificate=cert, tls_verifyAuthorities=verify)
  1629. amp.StartTLS.responder(getTLSVars)
  1630. @skipIf(skipSSL, "SSL not available")
  1631. @skipIf(reactorLacksSSL, "This test case requires SSL support in the reactor")
  1632. class TLSTests(TestCase):
  1633. def test_startingTLS(self):
  1634. """
  1635. Verify that starting TLS and succeeding at handshaking sends all the
  1636. notifications to all the right places.
  1637. """
  1638. cli, svr, p = connectedServerAndClient(
  1639. ServerClass=SecurableProto, ClientClass=SecurableProto
  1640. )
  1641. okc = OKCert()
  1642. svr.certFactory = lambda: okc
  1643. cli.callRemote(
  1644. amp.StartTLS,
  1645. tls_localCertificate=okc,
  1646. tls_verifyAuthorities=[PretendRemoteCertificateAuthority()],
  1647. )
  1648. # let's buffer something to be delivered securely
  1649. L = []
  1650. cli.callRemote(SecuredPing).addCallback(L.append)
  1651. p.flush()
  1652. # once for client once for server
  1653. self.assertEqual(okc.verifyCount, 2)
  1654. L = []
  1655. cli.callRemote(SecuredPing).addCallback(L.append)
  1656. p.flush()
  1657. self.assertEqual(L[0], {"pinged": True})
  1658. def test_startTooManyTimes(self):
  1659. """
  1660. Verify that the protocol will complain if we attempt to renegotiate TLS,
  1661. which we don't support.
  1662. """
  1663. cli, svr, p = connectedServerAndClient(
  1664. ServerClass=SecurableProto, ClientClass=SecurableProto
  1665. )
  1666. okc = OKCert()
  1667. svr.certFactory = lambda: okc
  1668. cli.callRemote(
  1669. amp.StartTLS,
  1670. tls_localCertificate=okc,
  1671. tls_verifyAuthorities=[PretendRemoteCertificateAuthority()],
  1672. )
  1673. p.flush()
  1674. cli.noPeerCertificate = True # this is totally fake
  1675. self.assertRaises(
  1676. amp.OnlyOneTLS,
  1677. cli.callRemote,
  1678. amp.StartTLS,
  1679. tls_localCertificate=okc,
  1680. tls_verifyAuthorities=[PretendRemoteCertificateAuthority()],
  1681. )
  1682. def test_negotiationFailed(self):
  1683. """
  1684. Verify that starting TLS and failing on both sides at handshaking sends
  1685. notifications to all the right places and terminates the connection.
  1686. """
  1687. badCert = GrumpyCert()
  1688. cli, svr, p = connectedServerAndClient(
  1689. ServerClass=SecurableProto, ClientClass=SecurableProto
  1690. )
  1691. svr.certFactory = lambda: badCert
  1692. cli.callRemote(amp.StartTLS, tls_localCertificate=badCert)
  1693. p.flush()
  1694. # once for client once for server - but both fail
  1695. self.assertEqual(badCert.verifyCount, 2)
  1696. d = cli.callRemote(SecuredPing)
  1697. p.flush()
  1698. self.assertFailure(d, iosim.NativeOpenSSLError)
  1699. def test_negotiationFailedByClosing(self):
  1700. """
  1701. Verify that starting TLS and failing by way of a lost connection
  1702. notices that it is probably an SSL problem.
  1703. """
  1704. cli, svr, p = connectedServerAndClient(
  1705. ServerClass=SecurableProto, ClientClass=SecurableProto
  1706. )
  1707. droppyCert = DroppyCert(svr.transport)
  1708. svr.certFactory = lambda: droppyCert
  1709. cli.callRemote(amp.StartTLS, tls_localCertificate=droppyCert)
  1710. p.flush()
  1711. self.assertEqual(droppyCert.verifyCount, 2)
  1712. d = cli.callRemote(SecuredPing)
  1713. p.flush()
  1714. # it might be a good idea to move this exception somewhere more
  1715. # reasonable.
  1716. self.assertFailure(d, error.PeerVerifyError)
  1717. class TLSNotAvailableTests(TestCase):
  1718. """
  1719. Tests what happened when ssl is not available in current installation.
  1720. """
  1721. def setUp(self):
  1722. """
  1723. Disable ssl in amp.
  1724. """
  1725. self.ssl = amp.ssl
  1726. amp.ssl = None
  1727. def tearDown(self):
  1728. """
  1729. Restore ssl module.
  1730. """
  1731. amp.ssl = self.ssl
  1732. def test_callRemoteError(self):
  1733. """
  1734. Check that callRemote raises an exception when called with a
  1735. L{amp.StartTLS}.
  1736. """
  1737. cli, svr, p = connectedServerAndClient(
  1738. ServerClass=SecurableProto, ClientClass=SecurableProto
  1739. )
  1740. okc = OKCert()
  1741. svr.certFactory = lambda: okc
  1742. return self.assertFailure(
  1743. cli.callRemote(
  1744. amp.StartTLS,
  1745. tls_localCertificate=okc,
  1746. tls_verifyAuthorities=[PretendRemoteCertificateAuthority()],
  1747. ),
  1748. RuntimeError,
  1749. )
  1750. def test_messageReceivedError(self):
  1751. """
  1752. When a client with SSL enabled talks to a server without SSL, it
  1753. should return a meaningful error.
  1754. """
  1755. svr = SecurableProto()
  1756. okc = OKCert()
  1757. svr.certFactory = lambda: okc
  1758. box = amp.Box()
  1759. box[b"_command"] = b"StartTLS"
  1760. box[b"_ask"] = b"1"
  1761. boxes = []
  1762. svr.sendBox = boxes.append
  1763. svr.makeConnection(StringTransport())
  1764. svr.ampBoxReceived(box)
  1765. self.assertEqual(
  1766. boxes,
  1767. [
  1768. {
  1769. b"_error_code": b"TLS_ERROR",
  1770. b"_error": b"1",
  1771. b"_error_description": b"TLS not available",
  1772. }
  1773. ],
  1774. )
  1775. class InheritedError(Exception):
  1776. """
  1777. This error is used to check inheritance.
  1778. """
  1779. class OtherInheritedError(Exception):
  1780. """
  1781. This is a distinct error for checking inheritance.
  1782. """
  1783. class BaseCommand(amp.Command):
  1784. """
  1785. This provides a command that will be subclassed.
  1786. """
  1787. errors: Dict[Type[Exception], bytes] = {InheritedError: b"INHERITED_ERROR"}
  1788. class InheritedCommand(BaseCommand):
  1789. """
  1790. This is a command which subclasses another command but does not override
  1791. anything.
  1792. """
  1793. class AddErrorsCommand(BaseCommand):
  1794. """
  1795. This is a command which subclasses another command but adds errors to the
  1796. list.
  1797. """
  1798. arguments = [(b"other", amp.Boolean())]
  1799. errors: Dict[Type[Exception], bytes] = {
  1800. OtherInheritedError: b"OTHER_INHERITED_ERROR"
  1801. }
  1802. class NormalCommandProtocol(amp.AMP):
  1803. """
  1804. This is a protocol which responds to L{BaseCommand}, and is used to test
  1805. that inheritance does not interfere with the normal handling of errors.
  1806. """
  1807. def resp(self):
  1808. raise InheritedError()
  1809. BaseCommand.responder(resp)
  1810. class InheritedCommandProtocol(amp.AMP):
  1811. """
  1812. This is a protocol which responds to L{InheritedCommand}, and is used to
  1813. test that inherited commands inherit their bases' errors if they do not
  1814. respond to any of their own.
  1815. """
  1816. def resp(self):
  1817. raise InheritedError()
  1818. InheritedCommand.responder(resp)
  1819. class AddedCommandProtocol(amp.AMP):
  1820. """
  1821. This is a protocol which responds to L{AddErrorsCommand}, and is used to
  1822. test that inherited commands can add their own new types of errors, but
  1823. still respond in the same way to their parents types of errors.
  1824. """
  1825. def resp(self, other):
  1826. if other:
  1827. raise OtherInheritedError()
  1828. else:
  1829. raise InheritedError()
  1830. AddErrorsCommand.responder(resp)
  1831. class CommandInheritanceTests(TestCase):
  1832. """
  1833. These tests verify that commands inherit error conditions properly.
  1834. """
  1835. def errorCheck(self, err, proto, cmd, **kw):
  1836. """
  1837. Check that the appropriate kind of error is raised when a given command
  1838. is sent to a given protocol.
  1839. """
  1840. c, s, p = connectedServerAndClient(ServerClass=proto, ClientClass=proto)
  1841. d = c.callRemote(cmd, **kw)
  1842. d2 = self.failUnlessFailure(d, err)
  1843. p.flush()
  1844. return d2
  1845. def test_basicErrorPropagation(self):
  1846. """
  1847. Verify that errors specified in a superclass are respected normally
  1848. even if it has subclasses.
  1849. """
  1850. return self.errorCheck(InheritedError, NormalCommandProtocol, BaseCommand)
  1851. def test_inheritedErrorPropagation(self):
  1852. """
  1853. Verify that errors specified in a superclass command are propagated to
  1854. its subclasses.
  1855. """
  1856. return self.errorCheck(
  1857. InheritedError, InheritedCommandProtocol, InheritedCommand
  1858. )
  1859. def test_inheritedErrorAddition(self):
  1860. """
  1861. Verify that new errors specified in a subclass of an existing command
  1862. are honored even if the superclass defines some errors.
  1863. """
  1864. return self.errorCheck(
  1865. OtherInheritedError, AddedCommandProtocol, AddErrorsCommand, other=True
  1866. )
  1867. def test_additionWithOriginalError(self):
  1868. """
  1869. Verify that errors specified in a command's superclass are respected
  1870. even if that command defines new errors itself.
  1871. """
  1872. return self.errorCheck(
  1873. InheritedError, AddedCommandProtocol, AddErrorsCommand, other=False
  1874. )
def _loseAndPass(err, proto):
    """
    Errback used by L{LiveFireBase.tearDown}.  It accepts only
    L{error.ConnectionLost} or L{error.ConnectionDone} failures, restores the
    protocol's own C{connectionLost}, and calls it with the failure.

    @param err: the failure delivered to the errback.
    @param proto: the protocol whose C{connectionLost} was temporarily
        replaced by an instance attribute (see C{tearDown}).
    """
    # be specific, pass on the error to the client.
    # Any other failure type is left to propagate by trap().
    err.trap(error.ConnectionLost, error.ConnectionDone)
    # Remove the instance-level override so the class's connectionLost is
    # the one looked up and invoked below.
    del proto.connectionLost
    proto.connectionLost(err)
  1880. class LiveFireBase:
  1881. """
  1882. Utility for connected reactor-using tests.
  1883. """
  1884. def setUp(self):
  1885. """
  1886. Create an amp server and connect a client to it.
  1887. """
  1888. from twisted.internet import reactor
  1889. self.serverFactory = protocol.ServerFactory()
  1890. self.serverFactory.protocol = self.serverProto
  1891. self.clientFactory = protocol.ClientFactory()
  1892. self.clientFactory.protocol = self.clientProto
  1893. self.clientFactory.onMade = defer.Deferred()
  1894. self.serverFactory.onMade = defer.Deferred()
  1895. self.serverPort = reactor.listenTCP(0, self.serverFactory)
  1896. self.addCleanup(self.serverPort.stopListening)
  1897. self.clientConn = reactor.connectTCP(
  1898. "127.0.0.1", self.serverPort.getHost().port, self.clientFactory
  1899. )
  1900. self.addCleanup(self.clientConn.disconnect)
  1901. def getProtos(rlst):
  1902. self.cli = self.clientFactory.theProto
  1903. self.svr = self.serverFactory.theProto
  1904. dl = defer.DeferredList([self.clientFactory.onMade, self.serverFactory.onMade])
  1905. return dl.addCallback(getProtos)
  1906. def tearDown(self):
  1907. """
  1908. Cleanup client and server connections, and check the error got at
  1909. C{connectionLost}.
  1910. """
  1911. L = []
  1912. for conn in self.cli, self.svr:
  1913. if conn.transport is not None:
  1914. # depend on amp's function connection-dropping behavior
  1915. d = defer.Deferred().addErrback(_loseAndPass, conn)
  1916. conn.connectionLost = d.errback
  1917. conn.transport.loseConnection()
  1918. L.append(d)
  1919. return defer.gatherResults(L).addErrback(lambda first: first.value.subFailure)
  1920. def show(x):
  1921. import sys
  1922. sys.stdout.write(x + "\n")
  1923. sys.stdout.flush()
  1924. def tempSelfSigned():
  1925. from twisted.internet import ssl
  1926. sharedDN = ssl.DN(CN="shared")
  1927. key = ssl.KeyPair.generate()
  1928. cr = key.certificateRequest(sharedDN)
  1929. sscrd = key.signCertificateRequest(sharedDN, cr, lambda dn: True, 1234567)
  1930. cert = key.newCertificate(sscrd)
  1931. return cert
# Build one self-signed certificate at import time for the live-fire TLS
# tests below.  When no usable ssl module was found, tempcert is left
# undefined; the test cases that use it are skipped via skipSSL.
if ssl is not None:
    tempcert = tempSelfSigned()
  1934. @skipIf(skipSSL, "SSL not available")
  1935. @skipIf(reactorLacksSSL, "This test case requires SSL support in the reactor")
  1936. class LiveFireTLSTests(LiveFireBase, TestCase):
  1937. clientProto = SecurableProto
  1938. serverProto = SecurableProto
  1939. def test_liveFireCustomTLS(self):
  1940. """
  1941. Using real, live TLS, actually negotiate a connection.
  1942. This also looks at the 'peerCertificate' attribute's correctness, since
  1943. that's actually loaded using OpenSSL calls, but the main purpose is to
  1944. make sure that we didn't miss anything obvious in iosim about TLS
  1945. negotiations.
  1946. """
  1947. cert = tempcert
  1948. self.svr.verifyFactory = lambda: [cert]
  1949. self.svr.certFactory = lambda: cert
  1950. # only needed on the server, we specify the client below.
  1951. def secured(rslt):
  1952. x = cert.digest()
  1953. def pinged(rslt2):
  1954. # Interesting. OpenSSL won't even _tell_ us about the peer
  1955. # cert until we negotiate. we should be able to do this in
  1956. # 'secured' instead, but it looks like we can't. I think this
  1957. # is a bug somewhere far deeper than here.
  1958. self.assertEqual(x, self.cli.hostCertificate.digest())
  1959. self.assertEqual(x, self.cli.peerCertificate.digest())
  1960. self.assertEqual(x, self.svr.hostCertificate.digest())
  1961. self.assertEqual(x, self.svr.peerCertificate.digest())
  1962. return self.cli.callRemote(SecuredPing).addCallback(pinged)
  1963. return self.cli.callRemote(
  1964. amp.StartTLS, tls_localCertificate=cert, tls_verifyAuthorities=[cert]
  1965. ).addCallback(secured)
  1966. class SlightlySmartTLS(SimpleSymmetricCommandProtocol):
  1967. """
  1968. Specific implementation of server side protocol with different
  1969. management of TLS.
  1970. """
  1971. def getTLSVars(self):
  1972. """
  1973. @return: the global C{tempcert} certificate as local certificate.
  1974. """
  1975. return dict(tls_localCertificate=tempcert)
  1976. amp.StartTLS.responder(getTLSVars)
  1977. @skipIf(skipSSL, "SSL not available")
  1978. @skipIf(reactorLacksSSL, "This test case requires SSL support in the reactor")
  1979. class PlainVanillaLiveFireTests(LiveFireBase, TestCase):
  1980. clientProto = SimpleSymmetricCommandProtocol
  1981. serverProto = SimpleSymmetricCommandProtocol
  1982. def test_liveFireDefaultTLS(self):
  1983. """
  1984. Verify that out of the box, we can start TLS to at least encrypt the
  1985. connection, even if we don't have any certificates to use.
  1986. """
  1987. def secured(result):
  1988. return self.cli.callRemote(SecuredPing)
  1989. return self.cli.callRemote(amp.StartTLS).addCallback(secured)
  1990. @skipIf(skipSSL, "SSL not available")
  1991. @skipIf(reactorLacksSSL, "This test case requires SSL support in the reactor")
  1992. class WithServerTLSVerificationTests(LiveFireBase, TestCase):
  1993. clientProto = SimpleSymmetricCommandProtocol
  1994. serverProto = SlightlySmartTLS
  1995. def test_anonymousVerifyingClient(self):
  1996. """
  1997. Verify that anonymous clients can verify server certificates.
  1998. """
  1999. def secured(result):
  2000. return self.cli.callRemote(SecuredPing)
  2001. return self.cli.callRemote(
  2002. amp.StartTLS, tls_verifyAuthorities=[tempcert]
  2003. ).addCallback(secured)
  2004. class ProtocolIncludingArgument(amp.Argument):
  2005. """
  2006. An L{amp.Argument} which encodes its parser and serializer
  2007. arguments *including the protocol* into its parsed and serialized
  2008. forms.
  2009. """
  2010. def fromStringProto(self, string, protocol):
  2011. """
  2012. Don't decode anything; just return all possible information.
  2013. @return: A two-tuple of the input string and the protocol.
  2014. """
  2015. return (string, protocol)
  2016. def toStringProto(self, obj, protocol):
  2017. """
  2018. Encode identifying information about L{object} and protocol
  2019. into a string for later verification.
  2020. @type obj: L{object}
  2021. @type protocol: L{amp.AMP}
  2022. """
  2023. ident = "%d:%d" % (id(obj), id(protocol))
  2024. return ident.encode("ascii")
  2025. class ProtocolIncludingCommand(amp.Command):
  2026. """
  2027. A command that has argument and response schemas which use
  2028. L{ProtocolIncludingArgument}.
  2029. """
  2030. arguments = [(b"weird", ProtocolIncludingArgument())]
  2031. response = [(b"weird", ProtocolIncludingArgument())]
  2032. class MagicSchemaCommand(amp.Command):
  2033. """
  2034. A command which overrides L{parseResponse}, L{parseArguments}, and
  2035. L{makeResponse}.
  2036. """
  2037. @classmethod
  2038. def parseResponse(self, strings, protocol):
  2039. """
  2040. Don't do any parsing, just jam the input strings and protocol
  2041. onto the C{protocol.parseResponseArguments} attribute as a
  2042. two-tuple. Return the original strings.
  2043. """
  2044. protocol.parseResponseArguments = (strings, protocol)
  2045. return strings
  2046. @classmethod
  2047. def parseArguments(cls, strings, protocol):
  2048. """
  2049. Don't do any parsing, just jam the input strings and protocol
  2050. onto the C{protocol.parseArgumentsArguments} attribute as a
  2051. two-tuple. Return the original strings.
  2052. """
  2053. protocol.parseArgumentsArguments = (strings, protocol)
  2054. return strings
  2055. @classmethod
  2056. def makeArguments(cls, objects, protocol):
  2057. """
  2058. Don't do any serializing, just jam the input strings and protocol
  2059. onto the C{protocol.makeArgumentsArguments} attribute as a
  2060. two-tuple. Return the original strings.
  2061. """
  2062. protocol.makeArgumentsArguments = (objects, protocol)
  2063. return objects
  2064. class NoNetworkProtocol(amp.AMP):
  2065. """
  2066. An L{amp.AMP} subclass which overrides private methods to avoid
  2067. testing the network. It also provides a responder for
  2068. L{MagicSchemaCommand} that does nothing, so that tests can test
  2069. aspects of the interaction of L{amp.Command}s and L{amp.AMP}.
  2070. @ivar parseArgumentsArguments: Arguments that have been passed to any
  2071. L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
  2072. this protocol.
  2073. @ivar parseResponseArguments: Responses that have been returned from a
  2074. L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
  2075. this protocol.
  2076. @ivar makeArgumentsArguments: Arguments that have been serialized by any
  2077. L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
  2078. this protocol.
  2079. """
  2080. def _sendBoxCommand(self, commandName, strings, requiresAnswer):
  2081. """
  2082. Return a Deferred which fires with the original strings.
  2083. """
  2084. return defer.succeed(strings)
  2085. MagicSchemaCommand.responder(lambda s, weird: {})
  2086. class MyBox(dict):
  2087. """
  2088. A unique dict subclass.
  2089. """
  2090. class ProtocolIncludingCommandWithDifferentCommandType(ProtocolIncludingCommand):
  2091. """
  2092. A L{ProtocolIncludingCommand} subclass whose commandType is L{MyBox}
  2093. """
  2094. commandType = MyBox # type: ignore[assignment]
  2095. class CommandTests(TestCase):
  2096. """
  2097. Tests for L{amp.Argument} and L{amp.Command}.
  2098. """
  2099. def test_argumentInterface(self):
  2100. """
  2101. L{Argument} instances provide L{amp.IArgumentType}.
  2102. """
  2103. self.assertTrue(verifyObject(amp.IArgumentType, amp.Argument()))
  2104. def test_parseResponse(self):
  2105. """
  2106. There should be a class method of Command which accepts a
  2107. mapping of argument names to serialized forms and returns a
  2108. similar mapping whose values have been parsed via the
  2109. Command's response schema.
  2110. """
  2111. protocol = object()
  2112. result = b"whatever"
  2113. strings = {b"weird": result}
  2114. self.assertEqual(
  2115. ProtocolIncludingCommand.parseResponse(strings, protocol),
  2116. {"weird": (result, protocol)},
  2117. )
  2118. def test_callRemoteCallsParseResponse(self):
  2119. """
  2120. Making a remote call on a L{amp.Command} subclass which
  2121. overrides the C{parseResponse} method should call that
  2122. C{parseResponse} method to get the response.
  2123. """
  2124. client = NoNetworkProtocol()
  2125. thingy = b"weeoo"
  2126. response = client.callRemote(MagicSchemaCommand, weird=thingy)
  2127. def gotResponse(ign):
  2128. self.assertEqual(client.parseResponseArguments, ({"weird": thingy}, client))
  2129. response.addCallback(gotResponse)
  2130. return response
  2131. def test_parseArguments(self):
  2132. """
  2133. There should be a class method of L{amp.Command} which accepts
  2134. a mapping of argument names to serialized forms and returns a
  2135. similar mapping whose values have been parsed via the
  2136. command's argument schema.
  2137. """
  2138. protocol = object()
  2139. result = b"whatever"
  2140. strings = {b"weird": result}
  2141. self.assertEqual(
  2142. ProtocolIncludingCommand.parseArguments(strings, protocol),
  2143. {"weird": (result, protocol)},
  2144. )
  2145. def test_responderCallsParseArguments(self):
  2146. """
  2147. Making a remote call on a L{amp.Command} subclass which
  2148. overrides the C{parseArguments} method should call that
  2149. C{parseArguments} method to get the arguments.
  2150. """
  2151. protocol = NoNetworkProtocol()
  2152. responder = protocol.locateResponder(MagicSchemaCommand.commandName)
  2153. argument = object()
  2154. response = responder(dict(weird=argument))
  2155. response.addCallback(
  2156. lambda ign: self.assertEqual(
  2157. protocol.parseArgumentsArguments, ({"weird": argument}, protocol)
  2158. )
  2159. )
  2160. return response
  2161. def test_makeArguments(self):
  2162. """
  2163. There should be a class method of L{amp.Command} which accepts
  2164. a mapping of argument names to objects and returns a similar
  2165. mapping whose values have been serialized via the command's
  2166. argument schema.
  2167. """
  2168. protocol = object()
  2169. argument = object()
  2170. objects = {"weird": argument}
  2171. ident = "%d:%d" % (id(argument), id(protocol))
  2172. self.assertEqual(
  2173. ProtocolIncludingCommand.makeArguments(objects, protocol),
  2174. {b"weird": ident.encode("ascii")},
  2175. )
  2176. def test_makeArgumentsUsesCommandType(self):
  2177. """
  2178. L{amp.Command.makeArguments}'s return type should be the type
  2179. of the result of L{amp.Command.commandType}.
  2180. """
  2181. protocol = object()
  2182. objects = {"weird": b"whatever"}
  2183. result = ProtocolIncludingCommandWithDifferentCommandType.makeArguments(
  2184. objects, protocol
  2185. )
  2186. self.assertIs(type(result), MyBox)
  2187. def test_callRemoteCallsMakeArguments(self):
  2188. """
  2189. Making a remote call on a L{amp.Command} subclass which
  2190. overrides the C{makeArguments} method should call that
  2191. C{makeArguments} method to get the response.
  2192. """
  2193. client = NoNetworkProtocol()
  2194. argument = object()
  2195. response = client.callRemote(MagicSchemaCommand, weird=argument)
  2196. def gotResponse(ign):
  2197. self.assertEqual(
  2198. client.makeArgumentsArguments, ({"weird": argument}, client)
  2199. )
  2200. response.addCallback(gotResponse)
  2201. return response
  2202. def test_extraArgumentsDisallowed(self):
  2203. """
  2204. L{Command.makeArguments} raises L{amp.InvalidSignature} if the objects
  2205. dictionary passed to it includes a key which does not correspond to the
  2206. Python identifier for a defined argument.
  2207. """
  2208. self.assertRaises(
  2209. amp.InvalidSignature,
  2210. Hello.makeArguments,
  2211. dict(hello="hello", bogusArgument=object()),
  2212. None,
  2213. )
  2214. def test_wireSpellingDisallowed(self):
  2215. """
  2216. If a command argument conflicts with a Python keyword, the
  2217. untransformed argument name is not allowed as a key in the dictionary
  2218. passed to L{Command.makeArguments}. If it is supplied,
  2219. L{amp.InvalidSignature} is raised.
  2220. This may be a pointless implementation restriction which may be lifted.
  2221. The current behavior is tested to verify that such arguments are not
  2222. silently dropped on the floor (the previous behavior).
  2223. """
  2224. self.assertRaises(
  2225. amp.InvalidSignature,
  2226. Hello.makeArguments,
  2227. dict(hello="required", **{"print": "print value"}),
  2228. None,
  2229. )
  2230. def test_commandNameDefaultsToClassNameAsByteString(self):
  2231. """
  2232. A L{Command} subclass without a defined C{commandName} that's
  2233. not a byte string.
  2234. """
  2235. class NewCommand(amp.Command):
  2236. """
  2237. A new command.
  2238. """
  2239. self.assertEqual(b"NewCommand", NewCommand.commandName)
  2240. def test_commandNameMustBeAByteString(self):
  2241. """
  2242. A L{Command} subclass cannot be defined with a C{commandName} that's
  2243. not a byte string.
  2244. """
  2245. error = self.assertRaises(
  2246. TypeError, type, "NewCommand", (amp.Command,), {"commandName": "FOO"}
  2247. )
  2248. self.assertRegex(
  2249. str(error), "^Command names must be byte strings, got: u?'FOO'$"
  2250. )
  2251. def test_commandArgumentsMustBeNamedWithByteStrings(self):
  2252. """
  2253. A L{Command} subclass's C{arguments} must have byte string names.
  2254. """
  2255. error = self.assertRaises(
  2256. TypeError,
  2257. type,
  2258. "NewCommand",
  2259. (amp.Command,),
  2260. {"arguments": [("foo", None)]},
  2261. )
  2262. self.assertRegex(
  2263. str(error), "^Argument names must be byte strings, got: u?'foo'$"
  2264. )
  2265. def test_commandResponseMustBeNamedWithByteStrings(self):
  2266. """
  2267. A L{Command} subclass's C{response} must have byte string names.
  2268. """
  2269. error = self.assertRaises(
  2270. TypeError, type, "NewCommand", (amp.Command,), {"response": [("foo", None)]}
  2271. )
  2272. self.assertRegex(
  2273. str(error), "^Response names must be byte strings, got: u?'foo'$"
  2274. )
  2275. def test_commandErrorsIsConvertedToDict(self):
  2276. """
  2277. A L{Command} subclass's C{errors} is coerced into a C{dict}.
  2278. """
  2279. class NewCommand(amp.Command):
  2280. errors = [(ZeroDivisionError, b"ZDE")]
  2281. self.assertEqual({ZeroDivisionError: b"ZDE"}, NewCommand.errors)
  2282. def test_commandErrorsMustUseBytesForOnWireRepresentation(self):
  2283. """
  2284. A L{Command} subclass's C{errors} must map exceptions to byte strings.
  2285. """
  2286. error = self.assertRaises(
  2287. TypeError,
  2288. type,
  2289. "NewCommand",
  2290. (amp.Command,),
  2291. {"errors": [(ZeroDivisionError, "foo")]},
  2292. )
  2293. self.assertRegex(str(error), "^Error names must be byte strings, got: u?'foo'$")
  2294. def test_commandFatalErrorsIsConvertedToDict(self):
  2295. """
  2296. A L{Command} subclass's C{fatalErrors} is coerced into a C{dict}.
  2297. """
  2298. class NewCommand(amp.Command):
  2299. fatalErrors = [(ZeroDivisionError, b"ZDE")]
  2300. self.assertEqual({ZeroDivisionError: b"ZDE"}, NewCommand.fatalErrors)
  2301. def test_commandFatalErrorsMustUseBytesForOnWireRepresentation(self):
  2302. """
  2303. A L{Command} subclass's C{fatalErrors} must map exceptions to byte
  2304. strings.
  2305. """
  2306. error = self.assertRaises(
  2307. TypeError,
  2308. type,
  2309. "NewCommand",
  2310. (amp.Command,),
  2311. {"fatalErrors": [(ZeroDivisionError, "foo")]},
  2312. )
  2313. self.assertRegex(
  2314. str(error), "^Fatal error names must be byte strings, " "got: u?'foo'$"
  2315. )
  2316. class ListOfTestsMixin:
  2317. """
  2318. Base class for testing L{ListOf}, a parameterized zero-or-more argument
  2319. type.
  2320. @ivar elementType: Subclasses should set this to an L{Argument}
  2321. instance. The tests will make a L{ListOf} using this.
  2322. @ivar strings: Subclasses should set this to a dictionary mapping some
  2323. number of keys -- as BYTE strings -- to the correct serialized form
  2324. for some example values. These should agree with what L{elementType}
  2325. produces/accepts.
  2326. @ivar objects: Subclasses should set this to a dictionary with the same
  2327. keys as C{strings} -- as NATIVE strings -- and with values which are
  2328. the lists which should serialize to the values in the C{strings}
  2329. dictionary.
  2330. """
  2331. def test_toBox(self):
  2332. """
  2333. L{ListOf.toBox} extracts the list of objects from the C{objects}
  2334. dictionary passed to it, using the C{name} key also passed to it,
  2335. serializes each of the elements in that list using the L{Argument}
  2336. instance previously passed to its initializer, combines the serialized
  2337. results, and inserts the result into the C{strings} dictionary using
  2338. the same C{name} key.
  2339. """
  2340. stringList = amp.ListOf(self.elementType)
  2341. strings = amp.AmpBox()
  2342. for key in self.objects:
  2343. stringList.toBox(key.encode("ascii"), strings, self.objects.copy(), None)
  2344. self.assertEqual(strings, self.strings)
  2345. def test_fromBox(self):
  2346. """
  2347. L{ListOf.fromBox} reverses the operation performed by L{ListOf.toBox}.
  2348. """
  2349. stringList = amp.ListOf(self.elementType)
  2350. objects = {}
  2351. for key in self.strings:
  2352. stringList.fromBox(key, self.strings.copy(), objects, None)
  2353. self.assertEqual(objects, self.objects)
  2354. class ListOfStringsTests(TestCase, ListOfTestsMixin):
  2355. """
  2356. Tests for L{ListOf} combined with L{amp.String}.
  2357. """
  2358. elementType = amp.String()
  2359. strings = {
  2360. b"empty": b"",
  2361. b"single": b"\x00\x03foo",
  2362. b"multiple": b"\x00\x03bar\x00\x03baz\x00\x04quux",
  2363. }
  2364. objects = {"empty": [], "single": [b"foo"], "multiple": [b"bar", b"baz", b"quux"]}
  2365. class ListOfIntegersTests(TestCase, ListOfTestsMixin):
  2366. """
  2367. Tests for L{ListOf} combined with L{amp.Integer}.
  2368. """
  2369. elementType = amp.Integer()
  2370. huge = (
  2371. 9999999999999999999999999999999999999999999999999999999999
  2372. * 9999999999999999999999999999999999999999999999999999999999
  2373. )
  2374. strings = {
  2375. b"empty": b"",
  2376. b"single": b"\x00\x0210",
  2377. b"multiple": b"\x00\x011\x00\x0220\x00\x03500",
  2378. b"huge": b"\x00\x74%d" % (huge,),
  2379. b"negative": b"\x00\x02-1",
  2380. }
  2381. objects = {
  2382. "empty": [],
  2383. "single": [10],
  2384. "multiple": [1, 20, 500],
  2385. "huge": [huge],
  2386. "negative": [-1],
  2387. }
  2388. class ListOfUnicodeTests(TestCase, ListOfTestsMixin):
  2389. """
  2390. Tests for L{ListOf} combined with L{amp.Unicode}.
  2391. """
  2392. elementType = amp.Unicode()
  2393. strings = {
  2394. b"empty": b"",
  2395. b"single": b"\x00\x03foo",
  2396. b"multiple": b"\x00\x03\xe2\x98\x83\x00\x05Hello\x00\x05world",
  2397. }
  2398. objects = {
  2399. "empty": [],
  2400. "single": ["foo"],
  2401. "multiple": ["\N{SNOWMAN}", "Hello", "world"],
  2402. }
  2403. class ListOfDecimalTests(TestCase, ListOfTestsMixin):
  2404. """
  2405. Tests for L{ListOf} combined with L{amp.Decimal}.
  2406. """
  2407. elementType = amp.Decimal()
  2408. strings = {
  2409. b"empty": b"",
  2410. b"single": b"\x00\x031.1",
  2411. b"extreme": b"\x00\x08Infinity\x00\x09-Infinity",
  2412. b"scientist": b"\x00\x083.141E+5\x00\x0a0.00003141\x00\x083.141E-7"
  2413. b"\x00\x09-3.141E+5\x00\x0b-0.00003141\x00\x09-3.141E-7",
  2414. b"engineer": (
  2415. b"\x00\x04"
  2416. + decimal.Decimal("0e6").to_eng_string().encode("ascii")
  2417. + b"\x00\x06"
  2418. + decimal.Decimal("1.5E-9").to_eng_string().encode("ascii")
  2419. ),
  2420. }
  2421. objects = {
  2422. "empty": [],
  2423. "single": [decimal.Decimal("1.1")],
  2424. "extreme": [
  2425. decimal.Decimal("Infinity"),
  2426. decimal.Decimal("-Infinity"),
  2427. ],
  2428. # exarkun objected to AMP supporting engineering notation because
  2429. # it was redundant, until we realised that 1E6 has less precision
  2430. # than 1000000 and is represented differently. But they compare
  2431. # and even hash equally. There were tears.
  2432. "scientist": [
  2433. decimal.Decimal("3.141E5"),
  2434. decimal.Decimal("3.141e-5"),
  2435. decimal.Decimal("3.141E-7"),
  2436. decimal.Decimal("-3.141e5"),
  2437. decimal.Decimal("-3.141E-5"),
  2438. decimal.Decimal("-3.141e-7"),
  2439. ],
  2440. "engineer": [
  2441. decimal.Decimal("0e6"),
  2442. decimal.Decimal("1.5E-9"),
  2443. ],
  2444. }
  2445. class ListOfDecimalNanTests(TestCase, ListOfTestsMixin):
  2446. """
  2447. Tests for L{ListOf} combined with L{amp.Decimal} for not-a-number values.
  2448. """
  2449. elementType = amp.Decimal()
  2450. strings = {
  2451. b"nan": b"\x00\x03NaN\x00\x04-NaN\x00\x04sNaN\x00\x05-sNaN",
  2452. }
  2453. objects = {
  2454. "nan": [
  2455. decimal.Decimal("NaN"),
  2456. decimal.Decimal("-NaN"),
  2457. decimal.Decimal("sNaN"),
  2458. decimal.Decimal("-sNaN"),
  2459. ]
  2460. }
  2461. def test_fromBox(self):
  2462. """
  2463. L{ListOf.fromBox} reverses the operation performed by L{ListOf.toBox}.
  2464. """
  2465. # Helpers. Decimal.is_{qnan,snan,signed}() are new in 2.6 (or 2.5.2,
  2466. # but who's counting).
  2467. def is_qnan(decimal):
  2468. return "NaN" in str(decimal) and "sNaN" not in str(decimal)
  2469. def is_snan(decimal):
  2470. return "sNaN" in str(decimal)
  2471. def is_signed(decimal):
  2472. return "-" in str(decimal)
  2473. # NaN values have unusual equality semantics, so this method is
  2474. # overridden to compare the resulting objects in a way which works with
  2475. # NaNs.
  2476. stringList = amp.ListOf(self.elementType)
  2477. objects = {}
  2478. for key in self.strings:
  2479. stringList.fromBox(key, self.strings.copy(), objects, None)
  2480. n = objects["nan"]
  2481. self.assertTrue(is_qnan(n[0]) and not is_signed(n[0]))
  2482. self.assertTrue(is_qnan(n[1]) and is_signed(n[1]))
  2483. self.assertTrue(is_snan(n[2]) and not is_signed(n[2]))
  2484. self.assertTrue(is_snan(n[3]) and is_signed(n[3]))
  2485. class DecimalTests(TestCase):
  2486. """
  2487. Tests for L{amp.Decimal}.
  2488. """
  2489. def test_nonDecimal(self):
  2490. """
  2491. L{amp.Decimal.toString} raises L{ValueError} if passed an object which
  2492. is not an instance of C{decimal.Decimal}.
  2493. """
  2494. argument = amp.Decimal()
  2495. self.assertRaises(ValueError, argument.toString, "1.234")
  2496. self.assertRaises(ValueError, argument.toString, 1.234)
  2497. self.assertRaises(ValueError, argument.toString, 1234)
  2498. class FloatTests(TestCase):
  2499. """
  2500. Tests for L{amp.Float}.
  2501. """
  2502. def test_nonFloat(self):
  2503. """
  2504. L{amp.Float.toString} raises L{ValueError} if passed an object which
  2505. is not a L{float}.
  2506. """
  2507. argument = amp.Float()
  2508. self.assertRaises(ValueError, argument.toString, "1.234")
  2509. self.assertRaises(ValueError, argument.toString, b"1.234")
  2510. self.assertRaises(ValueError, argument.toString, 1234)
  2511. def test_float(self):
  2512. """
  2513. L{amp.Float.toString} returns a bytestring when it is given a L{float}.
  2514. """
  2515. argument = amp.Float()
  2516. self.assertEqual(argument.toString(1.234), b"1.234")
  2517. class ListOfDateTimeTests(TestCase, ListOfTestsMixin):
  2518. """
  2519. Tests for L{ListOf} combined with L{amp.DateTime}.
  2520. """
  2521. elementType = amp.DateTime()
  2522. strings = {
  2523. b"christmas": b"\x00\x202010-12-25T00:00:00.000000-00:00"
  2524. b"\x00\x202010-12-25T00:00:00.000000-00:00",
  2525. b"christmas in eu": b"\x00\x202010-12-25T00:00:00.000000+01:00",
  2526. b"christmas in iran": b"\x00\x202010-12-25T00:00:00.000000+03:30",
  2527. b"christmas in nyc": b"\x00\x202010-12-25T00:00:00.000000-05:00",
  2528. b"previous tests": b"\x00\x202010-12-25T00:00:00.000000+03:19"
  2529. b"\x00\x202010-12-25T00:00:00.000000-06:59",
  2530. }
  2531. objects = {
  2532. "christmas": [
  2533. datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=amp.utc),
  2534. datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("+", 0, 0)),
  2535. ],
  2536. "christmas in eu": [
  2537. datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("+", 1, 0)),
  2538. ],
  2539. "christmas in iran": [
  2540. datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("+", 3, 30)),
  2541. ],
  2542. "christmas in nyc": [
  2543. datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("-", 5, 0)),
  2544. ],
  2545. "previous tests": [
  2546. datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("+", 3, 19)),
  2547. datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("-", 6, 59)),
  2548. ],
  2549. }
  2550. class ListOfOptionalTests(TestCase):
  2551. """
  2552. Tests to ensure L{ListOf} AMP arguments can be omitted from AMP commands
  2553. via the 'optional' flag.
  2554. """
  2555. def test_requiredArgumentWithNoneValueRaisesTypeError(self):
  2556. """
  2557. L{ListOf.toBox} raises C{TypeError} when passed a value of L{None}
  2558. for the argument.
  2559. """
  2560. stringList = amp.ListOf(amp.Integer())
  2561. self.assertRaises(
  2562. TypeError,
  2563. stringList.toBox,
  2564. b"omitted",
  2565. amp.AmpBox(),
  2566. {"omitted": None},
  2567. None,
  2568. )
  2569. def test_optionalArgumentWithNoneValueOmitted(self):
  2570. """
  2571. L{ListOf.toBox} silently omits serializing any argument with a
  2572. value of L{None} that is designated as optional for the protocol.
  2573. """
  2574. stringList = amp.ListOf(amp.Integer(), optional=True)
  2575. strings = amp.AmpBox()
  2576. stringList.toBox(b"omitted", strings, {b"omitted": None}, None)
  2577. self.assertEqual(strings, {})
  2578. def test_requiredArgumentWithKeyMissingRaisesKeyError(self):
  2579. """
  2580. L{ListOf.toBox} raises C{KeyError} if the argument's key is not
  2581. present in the objects dictionary.
  2582. """
  2583. stringList = amp.ListOf(amp.Integer())
  2584. self.assertRaises(
  2585. KeyError,
  2586. stringList.toBox,
  2587. b"ommited",
  2588. amp.AmpBox(),
  2589. {"someOtherKey": 0},
  2590. None,
  2591. )
  2592. def test_optionalArgumentWithKeyMissingOmitted(self):
  2593. """
  2594. L{ListOf.toBox} silently omits serializing any argument designated
  2595. as optional whose key is not present in the objects dictionary.
  2596. """
  2597. stringList = amp.ListOf(amp.Integer(), optional=True)
  2598. stringList.toBox(b"ommited", amp.AmpBox(), {b"someOtherKey": 0}, None)
  2599. def test_omittedOptionalArgumentDeserializesAsNone(self):
  2600. """
  2601. L{ListOf.fromBox} correctly reverses the operation performed by
  2602. L{ListOf.toBox} for optional arguments.
  2603. """
  2604. stringList = amp.ListOf(amp.Integer(), optional=True)
  2605. objects = {}
  2606. stringList.fromBox(b"omitted", {}, objects, None)
  2607. self.assertEqual(objects, {"omitted": None})
  2608. @implementer(interfaces.IUNIXTransport)
  2609. class UNIXStringTransport:
  2610. """
  2611. An in-memory implementation of L{interfaces.IUNIXTransport} which collects
  2612. all data given to it for later inspection.
  2613. @ivar _queue: A C{list} of the data which has been given to this transport,
  2614. eg via C{write} or C{sendFileDescriptor}. Elements are two-tuples of a
  2615. string (identifying the destination of the data) and the data itself.
  2616. """
  2617. def __init__(self, descriptorFuzz):
  2618. """
  2619. @param descriptorFuzz: An offset to apply to descriptors.
  2620. @type descriptorFuzz: C{int}
  2621. """
  2622. self._fuzz = descriptorFuzz
  2623. self._queue = []
  2624. def sendFileDescriptor(self, descriptor):
  2625. self._queue.append(("fileDescriptorReceived", descriptor + self._fuzz))
  2626. def write(self, data):
  2627. self._queue.append(("dataReceived", data))
  2628. def writeSequence(self, seq):
  2629. for data in seq:
  2630. self.write(data)
  2631. def loseConnection(self):
  2632. self._queue.append(("connectionLost", Failure(error.ConnectionLost())))
  2633. def getHost(self):
  2634. return address.UNIXAddress("/tmp/some-path")
  2635. def getPeer(self):
  2636. return address.UNIXAddress("/tmp/another-path")
  2637. # Minimal evidence that we got the signatures right
  2638. verifyClass(interfaces.ITransport, UNIXStringTransport)
  2639. verifyClass(interfaces.IUNIXTransport, UNIXStringTransport)
  2640. class DescriptorTests(TestCase):
  2641. """
  2642. Tests for L{amp.Descriptor}, an argument type for passing a file descriptor
  2643. over an AMP connection over a UNIX domain socket.
  2644. """
  2645. def setUp(self):
  2646. self.fuzz = 3
  2647. self.transport = UNIXStringTransport(descriptorFuzz=self.fuzz)
  2648. self.protocol = amp.BinaryBoxProtocol(amp.BoxDispatcher(amp.CommandLocator()))
  2649. self.protocol.makeConnection(self.transport)
  2650. def test_fromStringProto(self):
  2651. """
  2652. L{Descriptor.fromStringProto} constructs a file descriptor value by
  2653. extracting a previously received file descriptor corresponding to the
  2654. wire value of the argument from the L{_DescriptorExchanger} state of the
  2655. protocol passed to it.
  2656. This is a whitebox test which involves direct L{_DescriptorExchanger}
  2657. state inspection.
  2658. """
  2659. argument = amp.Descriptor()
  2660. self.protocol.fileDescriptorReceived(5)
  2661. self.protocol.fileDescriptorReceived(3)
  2662. self.protocol.fileDescriptorReceived(1)
  2663. self.assertEqual(5, argument.fromStringProto("0", self.protocol))
  2664. self.assertEqual(3, argument.fromStringProto("1", self.protocol))
  2665. self.assertEqual(1, argument.fromStringProto("2", self.protocol))
  2666. self.assertEqual({}, self.protocol._descriptors)
  2667. def test_toStringProto(self):
  2668. """
  2669. To send a file descriptor, L{Descriptor.toStringProto} uses the
  2670. L{IUNIXTransport.sendFileDescriptor} implementation of the transport of
  2671. the protocol passed to it to copy the file descriptor. Each subsequent
  2672. descriptor sent over a particular AMP connection is assigned the next
  2673. integer value, starting from 0. The base ten string representation of
  2674. this value is the byte encoding of the argument.
  2675. This is a whitebox test which involves direct L{_DescriptorExchanger}
  2676. state inspection and mutation.
  2677. """
  2678. argument = amp.Descriptor()
  2679. self.assertEqual(b"0", argument.toStringProto(2, self.protocol))
  2680. self.assertEqual(
  2681. ("fileDescriptorReceived", 2 + self.fuzz), self.transport._queue.pop(0)
  2682. )
  2683. self.assertEqual(b"1", argument.toStringProto(4, self.protocol))
  2684. self.assertEqual(
  2685. ("fileDescriptorReceived", 4 + self.fuzz), self.transport._queue.pop(0)
  2686. )
  2687. self.assertEqual(b"2", argument.toStringProto(6, self.protocol))
  2688. self.assertEqual(
  2689. ("fileDescriptorReceived", 6 + self.fuzz), self.transport._queue.pop(0)
  2690. )
  2691. self.assertEqual({}, self.protocol._descriptors)
  2692. def test_roundTrip(self):
  2693. """
  2694. L{amp.Descriptor.fromBox} can interpret an L{amp.AmpBox} constructed by
  2695. L{amp.Descriptor.toBox} to reconstruct a file descriptor value.
  2696. """
  2697. name = "alpha"
  2698. nameAsBytes = name.encode("ascii")
  2699. strings = {}
  2700. descriptor = 17
  2701. sendObjects = {name: descriptor}
  2702. argument = amp.Descriptor()
  2703. argument.toBox(nameAsBytes, strings, sendObjects.copy(), self.protocol)
  2704. receiver = amp.BinaryBoxProtocol(amp.BoxDispatcher(amp.CommandLocator()))
  2705. for event in self.transport._queue:
  2706. getattr(receiver, event[0])(*event[1:])
  2707. receiveObjects = {}
  2708. argument.fromBox(nameAsBytes, strings.copy(), receiveObjects, receiver)
  2709. # Make sure we got the descriptor. Adjust by fuzz to be more convincing
  2710. # of having gone through L{IUNIXTransport.sendFileDescriptor}, not just
  2711. # converted to a string and then parsed back into an integer.
  2712. self.assertEqual(descriptor + self.fuzz, receiveObjects[name])
  2713. class DateTimeTests(TestCase):
  2714. """
  2715. Tests for L{amp.DateTime}, L{amp._FixedOffsetTZInfo}, and L{amp.utc}.
  2716. """
  2717. string = b"9876-01-23T12:34:56.054321-01:23"
  2718. tzinfo = tz("-", 1, 23)
  2719. object = datetime.datetime(9876, 1, 23, 12, 34, 56, 54321, tzinfo)
  2720. def test_invalidString(self):
  2721. """
  2722. L{amp.DateTime.fromString} raises L{ValueError} when passed a string
  2723. which does not represent a timestamp in the proper format.
  2724. """
  2725. d = amp.DateTime()
  2726. self.assertRaises(ValueError, d.fromString, "abc")
  2727. def test_invalidDatetime(self):
  2728. """
  2729. L{amp.DateTime.toString} raises L{ValueError} when passed a naive
  2730. datetime (a datetime with no timezone information).
  2731. """
  2732. d = amp.DateTime()
  2733. self.assertRaises(
  2734. ValueError, d.toString, datetime.datetime(2010, 12, 25, 0, 0, 0)
  2735. )
  2736. def test_fromString(self):
  2737. """
  2738. L{amp.DateTime.fromString} returns a C{datetime.datetime} with all of
  2739. its fields populated from the string passed to it.
  2740. """
  2741. argument = amp.DateTime()
  2742. value = argument.fromString(self.string)
  2743. self.assertEqual(value, self.object)
  2744. def test_toString(self):
  2745. """
  2746. L{amp.DateTime.toString} returns a C{str} in the wire format including
  2747. all of the information from the C{datetime.datetime} passed into it,
  2748. including the timezone offset.
  2749. """
  2750. argument = amp.DateTime()
  2751. value = argument.toString(self.object)
  2752. self.assertEqual(value, self.string)
  2753. class UTCTests(TestCase):
  2754. """
  2755. Tests for L{amp.utc}.
  2756. """
  2757. def test_tzname(self):
  2758. """
  2759. L{amp.utc.tzname} returns C{"+00:00"}.
  2760. """
  2761. self.assertEqual(amp.utc.tzname(None), "+00:00")
  2762. def test_dst(self):
  2763. """
  2764. L{amp.utc.dst} returns a zero timedelta.
  2765. """
  2766. self.assertEqual(amp.utc.dst(None), datetime.timedelta(0))
  2767. def test_utcoffset(self):
  2768. """
  2769. L{amp.utc.utcoffset} returns a zero timedelta.
  2770. """
  2771. self.assertEqual(amp.utc.utcoffset(None), datetime.timedelta(0))
  2772. def test_badSign(self):
  2773. """
  2774. L{amp._FixedOffsetTZInfo.fromSignHoursMinutes} raises L{ValueError} if
  2775. passed an offset sign other than C{'+'} or C{'-'}.
  2776. """
  2777. self.assertRaises(ValueError, tz, "?", 0, 0)
  2778. class RemoteAmpErrorTests(TestCase):
  2779. """
  2780. Tests for L{amp.RemoteAmpError}.
  2781. """
  2782. def test_stringMessage(self):
  2783. """
  2784. L{amp.RemoteAmpError} renders the given C{errorCode} (C{bytes}) and
  2785. C{description} into a native string.
  2786. """
  2787. error = amp.RemoteAmpError(b"BROKEN", "Something has broken")
  2788. self.assertEqual("Code<BROKEN>: Something has broken", str(error))
  2789. def test_stringMessageReplacesNonAsciiText(self):
  2790. """
  2791. When C{errorCode} contains non-ASCII characters, L{amp.RemoteAmpError}
  2792. renders then as backslash-escape sequences.
  2793. """
  2794. error = amp.RemoteAmpError(b"BROKEN-\xff", "Something has broken")
  2795. self.assertEqual("Code<BROKEN-\\xff>: Something has broken", str(error))
  2796. def test_stringMessageWithLocalFailure(self):
  2797. """
  2798. L{amp.RemoteAmpError} renders local errors with a "(local)" marker and
  2799. a brief traceback.
  2800. """
  2801. failure = Failure(Exception("Something came loose"))
  2802. error = amp.RemoteAmpError(b"BROKEN", "Something has broken", local=failure)
  2803. self.assertRegex(
  2804. str(error),
  2805. (
  2806. "^Code<BROKEN> [(]local[)]: Something has broken\n"
  2807. "Traceback [(]failure with no frames[)]: "
  2808. "<.+Exception.>: Something came loose\n"
  2809. ),
  2810. )