|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474 |
- # Copyright (c) Twisted Matrix Laboratories.
- # See LICENSE for details.
-
- """
- Test case for L{twisted.protocols.loopback}.
- """
-
- from __future__ import division, absolute_import
-
- from zope.interface import implementer
-
- from twisted.python.compat import intToBytes
- from twisted.trial import unittest
- from twisted.protocols import basic, loopback
- from twisted.internet import defer
- from twisted.internet.protocol import Protocol
- from twisted.internet.defer import Deferred
- from twisted.internet.interfaces import IAddress, IPushProducer, IPullProducer
- from twisted.internet import reactor, interfaces
-
-
class SimpleProtocol(basic.LineReceiver):
    """
    Line-based protocol that records what happens to it so tests can
    inspect it afterwards.

    @ivar conn: A L{defer.Deferred} fired with C{None} from
        C{connectionMade}.
    @ivar lines: Every line passed to C{lineReceived}, in arrival order.
    @ivar connLost: The C{reason} passed to each C{connectionLost} call.
    """

    def __init__(self):
        self.lines, self.connLost = [], []
        self.conn = defer.Deferred()


    def connectionMade(self):
        """
        Signal that the connection is up by firing C{self.conn}.
        """
        self.conn.callback(None)


    def lineReceived(self, line):
        """
        Remember C{line}.
        """
        self.lines.append(line)


    def connectionLost(self, reason):
        """
        Remember why the connection was lost.
        """
        self.connLost.append(reason)
-
-
-
class DoomProtocol(SimpleProtocol):
    """
    A L{SimpleProtocol} that answers each of the first three lines it
    receives with C{b"Hello <n>"} and drops the connection as soon as the
    line it has just recorded is C{b"Hello 3"}.

    @ivar i: The number of lines received so far.
    """
    i = 0

    def lineReceived(self, line):
        self.i = count = self.i + 1
        if count < 4:
            # The connection should already be closed before a fourth line
            # arrives, but never send "Hello 4" just in case it is not.
            self.sendLine(b"Hello " + intToBytes(count))
        SimpleProtocol.lineReceived(self, line)
        if self.lines[-1] == b"Hello 3":
            self.transport.loseConnection()
-
-
-
class LoopbackTestCaseMixin:
    """
    Generic loopback tests shared by every loopback flavour.

    Concrete subclasses mix this in with L{unittest.TestCase} and set
    C{loopbackFunc} to the loopback function under test, which is called
    with a server protocol followed by a client protocol.
    """

    def testRegularFunction(self):
        """
        A line sent by the server reaches the client, and once the server
        disconnects each side observes the connection loss exactly once.
        """
        server = SimpleProtocol()
        client = SimpleProtocol()

        def onServerConnected(ignored):
            server.sendLine(b"THIS IS LINE ONE!")
            server.transport.loseConnection()

        server.conn.addCallback(onServerConnected)

        def verify(ignored):
            self.assertEqual(client.lines, [b"THIS IS LINE ONE!"])
            self.assertEqual(len(server.connLost), 1)
            self.assertEqual(len(client.connLost), 1)

        finished = defer.maybeDeferred(self.loopbackFunc, server, client)
        return finished.addCallback(verify)


    def testSneakyHiddenDoom(self):
        """
        Two L{DoomProtocol}s bounce lines back and forth until one of them
        receives C{b"Hello 3"} and disconnects; both sides must have seen
        the expected lines and exactly one connection loss.
        """
        server = DoomProtocol()
        client = DoomProtocol()

        def onServerConnected(ignored):
            server.sendLine(b"DOOM LINE")

        server.conn.addCallback(onServerConnected)

        def verify(ignored):
            self.assertEqual(
                server.lines, [b'Hello 1', b'Hello 2', b'Hello 3'])
            self.assertEqual(
                client.lines,
                [b'DOOM LINE', b'Hello 1', b'Hello 2', b'Hello 3'])
            self.assertEqual(len(server.connLost), 1)
            self.assertEqual(len(client.connLost), 1)

        finished = defer.maybeDeferred(self.loopbackFunc, server, client)
        return finished.addCallback(verify)
-
-
-
class LoopbackAsyncTests(LoopbackTestCaseMixin, unittest.TestCase):
    """
    Tests for L{loopback.loopbackAsync}: the generic tests inherited from
    L{LoopbackTestCaseMixin} plus tests of the in-memory implementation's
    transports, producer support and pump policies.
    """
    loopbackFunc = staticmethod(loopback.loopbackAsync)


    def test_makeConnection(self):
        """
        Test that the client and server protocol both have makeConnection
        invoked on them by loopbackAsync.
        """
        class TestProtocol(Protocol):
            transport = None
            def makeConnection(self, transport):
                self.transport = transport

        server = TestProtocol()
        client = TestProtocol()
        loopback.loopbackAsync(server, client)
        self.assertIsNotNone(client.transport)
        self.assertIsNotNone(server.transport)


    def _hostpeertest(self, get, testServer):
        """
        Test one of the permutations of client/server host/peer.

        @param get: The name of the transport method to check, C{"getHost"}
            or C{"getPeer"}.
        @param testServer: If true, check the transport given to the server
            protocol; otherwise check the one given to the client protocol.

        @return: A L{Deferred} which fires after the checked protocol's
            transport has been inspected.
        """
        class TestProtocol(Protocol):
            def makeConnection(self, transport):
                Protocol.makeConnection(self, transport)
                self.onConnection.callback(transport)

        if testServer:
            server = TestProtocol()
            d = server.onConnection = Deferred()
            client = Protocol()
        else:
            server = Protocol()
            client = TestProtocol()
            d = client.onConnection = Deferred()

        loopback.loopbackAsync(server, client)

        def connected(transport):
            address = getattr(transport, get)()
            self.assertTrue(IAddress.providedBy(address))

        return d.addCallback(connected)


    def test_serverHost(self):
        """
        Test that the server gets a transport with a properly functioning
        implementation of L{ITransport.getHost}.
        """
        return self._hostpeertest("getHost", True)


    def test_serverPeer(self):
        """
        Like C{test_serverHost} but for L{ITransport.getPeer}
        """
        return self._hostpeertest("getPeer", True)


    def test_clientHost(self):
        """
        Test that the client gets a transport with a properly functioning
        implementation of L{ITransport.getHost}.
        """
        return self._hostpeertest("getHost", False)


    def test_clientPeer(self):
        """
        Like C{test_clientHost} but for L{ITransport.getPeer}.
        """
        return self._hostpeertest("getPeer", False)


    def _greetingtest(self, write, testServer):
        """
        Test one of the permutations of write/writeSequence client/server.

        @param write: The name of the method to test, C{"write"} or
            C{"writeSequence"}.
        @param testServer: If true, the server speaks first and the client
            listens; otherwise the client speaks first and the server
            listens.

        @return: A L{Deferred} which fires once the listening side has
            accumulated exactly C{b"bytes"}.
        """
        class GreeteeProtocol(Protocol):
            receivedBytes = b""
            def dataReceived(self, data):
                self.receivedBytes += data
                if self.receivedBytes == b"bytes":
                    self.received.callback(None)

        class GreeterProtocol(Protocol):
            def connectionMade(self):
                if write == "write":
                    self.transport.write(b"bytes")
                else:
                    self.transport.writeSequence([b"byt", b"es"])

        if testServer:
            server = GreeterProtocol()
            client = GreeteeProtocol()
            d = client.received = Deferred()
        else:
            server = GreeteeProtocol()
            d = server.received = Deferred()
            client = GreeterProtocol()

        loopback.loopbackAsync(server, client)
        return d


    def test_clientGreeting(self):
        """
        Test that on a connection where the client speaks first, the server
        receives the bytes sent by the client.
        """
        return self._greetingtest("write", False)


    def test_clientGreetingSequence(self):
        """
        Like C{test_clientGreeting}, but use C{writeSequence} instead of
        C{write} to issue the greeting.
        """
        return self._greetingtest("writeSequence", False)


    def test_serverGreeting(self):
        """
        Test that on a connection where the server speaks first, the client
        receives the bytes sent by the server.
        """
        return self._greetingtest("write", True)


    def test_serverGreetingSequence(self):
        """
        Like C{test_serverGreeting}, but use C{writeSequence} instead of
        C{write} to issue the greeting.
        """
        return self._greetingtest("writeSequence", True)


    def _producertest(self, producerClass):
        """
        Have the server register a producer against its loopback transport
        and check that the client receives everything that was produced.

        @param producerClass: A callable taking a list of byte chunks to
            produce and returning an object with a C{start(consumer)}
            method.

        @return: A L{Deferred} which fires with C{(client, server)} once the
            client has received every produced chunk, in order.
        """
        toProduce = list(map(intToBytes, range(0, 10)))
        expected = b''.join(toProduce)

        class ProducingProtocol(Protocol):
            def connectionMade(self):
                # Hand the producer its own copy: the producers in this
                # module pop chunks off the list they are given.
                self.producer = producerClass(list(toProduce))
                self.producer.start(self.transport)

        class ReceivingProtocol(Protocol):
            receivedBytes = b""
            def dataReceived(self, data):
                self.receivedBytes += data
                if self.receivedBytes == expected:
                    self.received.callback((client, server))

        server = ProducingProtocol()
        client = ReceivingProtocol()
        client.received = Deferred()

        loopback.loopbackAsync(server, client)
        return client.received


    def test_pushProducer(self):
        """
        Test a push producer registered against a loopback transport.
        """
        @implementer(IPushProducer)
        class PushProducer(object):
            # Set if the transport ever asks this streaming producer to
            # resume; the test asserts that it never does.
            resumed = False

            def __init__(self, toProduce):
                self.toProduce = toProduce

            def resumeProducing(self):
                self.resumed = True

            def start(self, consumer):
                self.consumer = consumer
                consumer.registerProducer(self, True)
                self._produceAndSchedule()

            def _produceAndSchedule(self):
                # Write one chunk per reactor iteration, then unregister
                # once everything has been written.
                if self.toProduce:
                    self.consumer.write(self.toProduce.pop(0))
                    reactor.callLater(0, self._produceAndSchedule)
                else:
                    self.consumer.unregisterProducer()
        d = self._producertest(PushProducer)

        def finished(results):
            (client, server) = results
            self.assertFalse(
                server.producer.resumed,
                "Streaming producer should not have been resumed.")
        d.addCallback(finished)
        return d


    def test_pullProducer(self):
        """
        Test a pull producer registered against a loopback transport.
        """
        @implementer(IPullProducer)
        class PullProducer(object):
            def __init__(self, toProduce):
                self.toProduce = toProduce

            def start(self, consumer):
                self.consumer = consumer
                self.consumer.registerProducer(self, False)

            def resumeProducing(self):
                # Write one chunk per pull; unregister after the last one.
                self.consumer.write(self.toProduce.pop(0))
                if not self.toProduce:
                    self.consumer.unregisterProducer()
        return self._producertest(PullProducer)


    def test_writeNotReentrant(self):
        """
        L{loopback.loopbackAsync} does not call a protocol's C{dataReceived}
        method while that protocol's transport's C{write} method is higher up
        on the stack.
        """
        class Server(Protocol):
            def dataReceived(self, data):
                self.transport.write(b"bytes")

        class Client(Protocol):
            ready = False

            def connectionMade(self):
                reactor.callLater(0, self.go)

            def go(self):
                self.transport.write(b"foo")
                # If delivery were re-entrant, dataReceived would run
                # before this line and record ready as False.
                self.ready = True

            def dataReceived(self, data):
                self.wasReady = self.ready
                self.transport.loseConnection()

        server = Server()
        client = Client()
        d = loopback.loopbackAsync(client, server)
        def cbFinished(ignored):
            self.assertTrue(client.wasReady)
        d.addCallback(cbFinished)
        return d


    def test_pumpPolicy(self):
        """
        The callable passed as the value for the C{pumpPolicy} parameter to
        L{loopbackAsync} is called with a L{_LoopbackQueue} of pending bytes
        and a protocol to which they should be delivered.
        """
        pumpCalls = []
        def dummyPolicy(queue, target):
            chunks = []
            while queue:
                chunks.append(queue.get())
            pumpCalls.append((target, chunks))

        client = Protocol()
        server = Protocol()

        finished = loopback.loopbackAsync(server, client, dummyPolicy)
        self.assertEqual(pumpCalls, [])

        client.transport.write(b"foo")
        client.transport.write(b"bar")
        server.transport.write(b"baz")
        server.transport.write(b"quux")
        server.transport.loseConnection()

        def cbComplete(ignored):
            self.assertEqual(
                pumpCalls,
                # The order here is somewhat arbitrary.  The implementation
                # happens to always deliver data to the client first.
                [(client, [b"baz", b"quux", None]),
                 (server, [b"foo", b"bar"])])
        finished.addCallback(cbComplete)
        return finished


    def test_identityPumpPolicy(self):
        """
        L{identityPumpPolicy} is a pump policy which calls the target's
        C{dataReceived} method once for each string in the queue passed to
        it.
        """
        received = []
        client = Protocol()
        client.dataReceived = received.append
        queue = loopback._LoopbackQueue()
        queue.put(b"foo")
        queue.put(b"bar")
        queue.put(None)

        loopback.identityPumpPolicy(queue, client)

        self.assertEqual(received, [b"foo", b"bar"])


    def test_collapsingPumpPolicy(self):
        """
        L{collapsingPumpPolicy} is a pump policy which calls the target's
        C{dataReceived} only once with all of the strings in the queue passed
        to it joined together.
        """
        received = []
        client = Protocol()
        client.dataReceived = received.append
        queue = loopback._LoopbackQueue()
        queue.put(b"foo")
        queue.put(b"bar")
        queue.put(None)

        loopback.collapsingPumpPolicy(queue, client)

        self.assertEqual(received, [b"foobar"])
-
-
class LoopbackTCPTests(LoopbackTestCaseMixin, unittest.TestCase):
    """
    Run the generic L{LoopbackTestCaseMixin} tests against
    L{loopback.loopbackTCP}.
    """
    loopbackFunc = staticmethod(loopback.loopbackTCP)
-
-
-
class LoopbackUNIXTests(LoopbackTestCaseMixin, unittest.TestCase):
    """
    Run the generic L{LoopbackTestCaseMixin} tests against
    L{loopback.loopbackUNIX}.  Skipped when the reactor cannot be adapted
    to L{interfaces.IReactorUNIX}.
    """
    loopbackFunc = staticmethod(loopback.loopbackUNIX)

    # Decided once, at class-definition time, for the installed reactor.
    if interfaces.IReactorUNIX(reactor, None) is None:
        skip = "Current reactor does not support UNIX sockets"
-
-
-
class LoopbackRelayTest(unittest.TestCase):
    """
    Tests for L{twisted.protocols.loopback.LoopbackRelay}.
    """

    class Receiver(Protocol):
        """
        Minimal relay target that accumulates everything handed to its
        C{dataReceived} so tests can compare it afterwards.
        """
        data = b''

        def dataReceived(self, data):
            """
            Append C{data} to the bytes received so far.
            """
            self.data += data


    def test_write(self):
        """
        Bytes passed to C{write} are held by the relay and only reach the
        target protocol once C{clearBuffer} is called.
        """
        target = self.Receiver()
        relay = loopback.LoopbackRelay(target)
        for chunk in (b'abc', b'def'):
            relay.write(chunk)
        self.assertEqual(target.data, b'')
        relay.clearBuffer()
        self.assertEqual(target.data, b'abcdef')


    def test_writeSequence(self):
        """
        Sequences passed to C{writeSequence} are held by the relay and
        delivered, concatenated in order, once C{clearBuffer} is called.
        """
        target = self.Receiver()
        relay = loopback.LoopbackRelay(target)
        batches = [
            [b'The ', b'quick ', b'brown ', b'fox '],
            [b'jumps ', b'over ', b'the lazy dog'],
        ]
        for batch in batches:
            relay.writeSequence(batch)
        self.assertEqual(target.data, b'')
        relay.clearBuffer()
        self.assertEqual(
            target.data, b'The quick brown fox jumps over the lazy dog')
|