You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

iosim.py 17KB

5 years ago

  1. # -*- test-case-name: twisted.test.test_amp,twisted.test.test_iosim -*-
  2. # Copyright (c) Twisted Matrix Laboratories.
  3. # See LICENSE for details.
  4. """
  5. Utilities and helpers for simulating a network
  6. """
  7. from __future__ import absolute_import, division, print_function
  8. import itertools
  9. try:
  10. from OpenSSL.SSL import Error as NativeOpenSSLError
  11. except ImportError:
  12. pass
  13. from zope.interface import implementer, directlyProvides
  14. from twisted.internet.endpoints import TCP4ClientEndpoint, TCP4ServerEndpoint
  15. from twisted.internet.protocol import Factory, Protocol
  16. from twisted.internet.error import ConnectionRefusedError
  17. from twisted.python.failure import Failure
  18. from twisted.internet import error
  19. from twisted.internet import interfaces
  20. from twisted.internet.testing import MemoryReactorClock
  21. class TLSNegotiation:
  22. def __init__(self, obj, connectState):
  23. self.obj = obj
  24. self.connectState = connectState
  25. self.sent = False
  26. self.readyToSend = connectState
  27. def __repr__(self):
  28. return 'TLSNegotiation(%r)' % (self.obj,)
  29. def pretendToVerify(self, other, tpt):
  30. # Set the transport problems list here? disconnections?
  31. # hmmmmm... need some negative path tests.
  32. if not self.obj.iosimVerify(other.obj):
  33. tpt.disconnectReason = NativeOpenSSLError()
  34. tpt.loseConnection()
  35. @implementer(interfaces.IAddress)
  36. class FakeAddress(object):
  37. """
  38. The default address type for the host and peer of L{FakeTransport}
  39. connections.
  40. """
  41. @implementer(interfaces.ITransport,
  42. interfaces.ITLSTransport)
  43. class FakeTransport:
  44. """
  45. A wrapper around a file-like object to make it behave as a Transport.
  46. This doesn't actually stream the file to the attached protocol,
  47. and is thus useful mainly as a utility for debugging protocols.
  48. """
  49. _nextserial = staticmethod(lambda counter=itertools.count(): next(counter))
  50. closed = 0
  51. disconnecting = 0
  52. disconnected = 0
  53. disconnectReason = error.ConnectionDone("Connection done")
  54. producer = None
  55. streamingProducer = 0
  56. tls = None
  57. def __init__(self, protocol, isServer, hostAddress=None, peerAddress=None):
  58. """
  59. @param protocol: This transport will deliver bytes to this protocol.
  60. @type protocol: L{IProtocol} provider
  61. @param isServer: C{True} if this is the accepting side of the
  62. connection, C{False} if it is the connecting side.
  63. @type isServer: L{bool}
  64. @param hostAddress: The value to return from C{getHost}. L{None}
  65. results in a new L{FakeAddress} being created to use as the value.
  66. @type hostAddress: L{IAddress} provider or L{None}
  67. @param peerAddress: The value to return from C{getPeer}. L{None}
  68. results in a new L{FakeAddress} being created to use as the value.
  69. @type peerAddress: L{IAddress} provider or L{None}
  70. """
  71. self.protocol = protocol
  72. self.isServer = isServer
  73. self.stream = []
  74. self.serial = self._nextserial()
  75. if hostAddress is None:
  76. hostAddress = FakeAddress()
  77. self.hostAddress = hostAddress
  78. if peerAddress is None:
  79. peerAddress = FakeAddress()
  80. self.peerAddress = peerAddress
  81. def __repr__(self):
  82. return 'FakeTransport<%s,%s,%s>' % (
  83. self.isServer and 'S' or 'C', self.serial,
  84. self.protocol.__class__.__name__)
  85. def write(self, data):
  86. # If transport is closed, we should accept writes but drop the data.
  87. if self.disconnecting:
  88. return
  89. if self.tls is not None:
  90. self.tlsbuf.append(data)
  91. else:
  92. self.stream.append(data)
  93. def _checkProducer(self):
  94. # Cheating; this is called at "idle" times to allow producers to be
  95. # found and dealt with
  96. if self.producer and not self.streamingProducer:
  97. self.producer.resumeProducing()
  98. def registerProducer(self, producer, streaming):
  99. """
  100. From abstract.FileDescriptor
  101. """
  102. self.producer = producer
  103. self.streamingProducer = streaming
  104. if not streaming:
  105. producer.resumeProducing()
  106. def unregisterProducer(self):
  107. self.producer = None
  108. def stopConsuming(self):
  109. self.unregisterProducer()
  110. self.loseConnection()
  111. def writeSequence(self, iovec):
  112. self.write(b"".join(iovec))
  113. def loseConnection(self):
  114. self.disconnecting = True
  115. def abortConnection(self):
  116. """
  117. For the time being, this is the same as loseConnection; no buffered
  118. data will be lost.
  119. """
  120. self.disconnecting = True
  121. def reportDisconnect(self):
  122. if self.tls is not None:
  123. # We were in the middle of negotiating! Must have been a TLS
  124. # problem.
  125. err = NativeOpenSSLError()
  126. else:
  127. err = self.disconnectReason
  128. self.protocol.connectionLost(Failure(err))
  129. def logPrefix(self):
  130. """
  131. Identify this transport/event source to the logging system.
  132. """
  133. return "iosim"
  134. def getPeer(self):
  135. return self.peerAddress
  136. def getHost(self):
  137. return self.hostAddress
  138. def resumeProducing(self):
  139. # Never sends data anyways
  140. pass
  141. def pauseProducing(self):
  142. # Never sends data anyways
  143. pass
  144. def stopProducing(self):
  145. self.loseConnection()
  146. def startTLS(self, contextFactory, beNormal=True):
  147. # Nothing's using this feature yet, but startTLS has an undocumented
  148. # second argument which defaults to true; if set to False, servers will
  149. # behave like clients and clients will behave like servers.
  150. connectState = self.isServer ^ beNormal
  151. self.tls = TLSNegotiation(contextFactory, connectState)
  152. self.tlsbuf = []
  153. def getOutBuffer(self):
  154. """
  155. Get the pending writes from this transport, clearing them from the
  156. pending buffer.
  157. @return: the bytes written with C{transport.write}
  158. @rtype: L{bytes}
  159. """
  160. S = self.stream
  161. if S:
  162. self.stream = []
  163. return b''.join(S)
  164. elif self.tls is not None:
  165. if self.tls.readyToSend:
  166. # Only _send_ the TLS negotiation "packet" if I'm ready to.
  167. self.tls.sent = True
  168. return self.tls
  169. else:
  170. return None
  171. else:
  172. return None
  173. def bufferReceived(self, buf):
  174. if isinstance(buf, TLSNegotiation):
  175. assert self.tls is not None # By the time you're receiving a
  176. # negotiation, you have to have called
  177. # startTLS already.
  178. if self.tls.sent:
  179. self.tls.pretendToVerify(buf, self)
  180. self.tls = None # We're done with the handshake if we've gotten
  181. # this far... although maybe it failed...?
  182. # TLS started! Unbuffer...
  183. b, self.tlsbuf = self.tlsbuf, None
  184. self.writeSequence(b)
  185. directlyProvides(self, interfaces.ISSLTransport)
  186. else:
  187. # We haven't sent our own TLS negotiation: time to do that!
  188. self.tls.readyToSend = True
  189. else:
  190. self.protocol.dataReceived(buf)
  191. def makeFakeClient(clientProtocol):
  192. """
  193. Create and return a new in-memory transport hooked up to the given protocol.
  194. @param clientProtocol: The client protocol to use.
  195. @type clientProtocol: L{IProtocol} provider
  196. @return: The transport.
  197. @rtype: L{FakeTransport}
  198. """
  199. return FakeTransport(clientProtocol, isServer=False)
  200. def makeFakeServer(serverProtocol):
  201. """
  202. Create and return a new in-memory transport hooked up to the given protocol.
  203. @param serverProtocol: The server protocol to use.
  204. @type serverProtocol: L{IProtocol} provider
  205. @return: The transport.
  206. @rtype: L{FakeTransport}
  207. """
  208. return FakeTransport(serverProtocol, isServer=True)
  209. class IOPump:
  210. """
  211. Utility to pump data between clients and servers for protocol testing.
  212. Perhaps this is a utility worthy of being in protocol.py?
  213. """
  214. def __init__(self, client, server, clientIO, serverIO, debug):
  215. self.client = client
  216. self.server = server
  217. self.clientIO = clientIO
  218. self.serverIO = serverIO
  219. self.debug = debug
  220. def flush(self, debug=False):
  221. """
  222. Pump until there is no more input or output.
  223. Returns whether any data was moved.
  224. """
  225. result = False
  226. for x in range(1000):
  227. if self.pump(debug):
  228. result = True
  229. else:
  230. break
  231. else:
  232. assert 0, "Too long"
  233. return result
  234. def pump(self, debug=False):
  235. """
  236. Move data back and forth.
  237. Returns whether any data was moved.
  238. """
  239. if self.debug or debug:
  240. print('-- GLUG --')
  241. sData = self.serverIO.getOutBuffer()
  242. cData = self.clientIO.getOutBuffer()
  243. self.clientIO._checkProducer()
  244. self.serverIO._checkProducer()
  245. if self.debug or debug:
  246. print('.')
  247. # XXX slightly buggy in the face of incremental output
  248. if cData:
  249. print('C: ' + repr(cData))
  250. if sData:
  251. print('S: ' + repr(sData))
  252. if cData:
  253. self.serverIO.bufferReceived(cData)
  254. if sData:
  255. self.clientIO.bufferReceived(sData)
  256. if cData or sData:
  257. return True
  258. if (self.serverIO.disconnecting and
  259. not self.serverIO.disconnected):
  260. if self.debug or debug:
  261. print('* C')
  262. self.serverIO.disconnected = True
  263. self.clientIO.disconnecting = True
  264. self.clientIO.reportDisconnect()
  265. return True
  266. if self.clientIO.disconnecting and not self.clientIO.disconnected:
  267. if self.debug or debug:
  268. print('* S')
  269. self.clientIO.disconnected = True
  270. self.serverIO.disconnecting = True
  271. self.serverIO.reportDisconnect()
  272. return True
  273. return False
  274. def connect(serverProtocol, serverTransport, clientProtocol, clientTransport,
  275. debug=False, greet=True):
  276. """
  277. Create a new L{IOPump} connecting two protocols.
  278. @param serverProtocol: The protocol to use on the accepting side of the
  279. connection.
  280. @type serverProtocol: L{IProtocol} provider
  281. @param serverTransport: The transport to associate with C{serverProtocol}.
  282. @type serverTransport: L{FakeTransport}
  283. @param clientProtocol: The protocol to use on the initiating side of the
  284. connection.
  285. @type clientProtocol: L{IProtocol} provider
  286. @param clientTransport: The transport to associate with C{clientProtocol}.
  287. @type clientTransport: L{FakeTransport}
  288. @param debug: A flag indicating whether to log information about what the
  289. L{IOPump} is doing.
  290. @type debug: L{bool}
  291. @param greet: Should the L{IOPump} be L{flushed <IOPump.flush>} once before
  292. returning to put the protocols into their post-handshake or
  293. post-server-greeting state?
  294. @type greet: L{bool}
  295. @return: An L{IOPump} which connects C{serverProtocol} and
  296. C{clientProtocol} and delivers bytes between them when it is pumped.
  297. @rtype: L{IOPump}
  298. """
  299. serverProtocol.makeConnection(serverTransport)
  300. clientProtocol.makeConnection(clientTransport)
  301. pump = IOPump(
  302. clientProtocol, serverProtocol, clientTransport, serverTransport, debug
  303. )
  304. if greet:
  305. # Kick off server greeting, etc
  306. pump.flush()
  307. return pump
  308. def connectedServerAndClient(ServerClass, ClientClass,
  309. clientTransportFactory=makeFakeClient,
  310. serverTransportFactory=makeFakeServer,
  311. debug=False, greet=True):
  312. """
  313. Connect a given server and client class to each other.
  314. @param ServerClass: a callable that produces the server-side protocol.
  315. @type ServerClass: 0-argument callable returning L{IProtocol} provider.
  316. @param ClientClass: like C{ServerClass} but for the other side of the
  317. connection.
  318. @type ClientClass: 0-argument callable returning L{IProtocol} provider.
  319. @param clientTransportFactory: a callable that produces the transport which
  320. will be attached to the protocol returned from C{ClientClass}.
  321. @type clientTransportFactory: callable taking (L{IProtocol}) and returning
  322. L{FakeTransport}
  323. @param serverTransportFactory: a callable that produces the transport which
  324. will be attached to the protocol returned from C{ServerClass}.
  325. @type serverTransportFactory: callable taking (L{IProtocol}) and returning
  326. L{FakeTransport}
  327. @param debug: Should this dump an escaped version of all traffic on this
  328. connection to stdout for inspection?
  329. @type debug: L{bool}
  330. @param greet: Should the L{IOPump} be L{flushed <IOPump.flush>} once before
  331. returning to put the protocols into their post-handshake or
  332. post-server-greeting state?
  333. @type greet: L{bool}
  334. @return: the client protocol, the server protocol, and an L{IOPump} which,
  335. when its C{pump} and C{flush} methods are called, will move data
  336. between the created client and server protocol instances.
  337. @rtype: 3-L{tuple} of L{IProtocol}, L{IProtocol}, L{IOPump}
  338. """
  339. c = ClientClass()
  340. s = ServerClass()
  341. cio = clientTransportFactory(c)
  342. sio = serverTransportFactory(s)
  343. return c, s, connect(s, sio, c, cio, debug, greet)
  344. def _factoriesShouldConnect(clientInfo, serverInfo):
  345. """
  346. Should the client and server described by the arguments be connected to
  347. each other, i.e. do their port numbers match?
  348. @param clientInfo: the args for connectTCP
  349. @type clientInfo: L{tuple}
  350. @param serverInfo: the args for listenTCP
  351. @type serverInfo: L{tuple}
  352. @return: If they do match, return factories for the client and server that
  353. should connect; otherwise return L{None}, indicating they shouldn't be
  354. connected.
  355. @rtype: L{None} or 2-L{tuple} of (L{ClientFactory},
  356. L{IProtocolFactory})
  357. """
  358. (clientHost, clientPort, clientFactory, clientTimeout,
  359. clientBindAddress) = clientInfo
  360. (serverPort, serverFactory, serverBacklog,
  361. serverInterface) = serverInfo
  362. if serverPort == clientPort:
  363. return clientFactory, serverFactory
  364. else:
  365. return None
  366. class ConnectionCompleter(object):
  367. """
  368. A L{ConnectionCompleter} can cause synthetic TCP connections established by
  369. L{MemoryReactor.connectTCP} and L{MemoryReactor.listenTCP} to succeed or
  370. fail.
  371. """
  372. def __init__(self, memoryReactor):
  373. """
  374. Create a L{ConnectionCompleter} from a L{MemoryReactor}.
  375. @param memoryReactor: The reactor to attach to.
  376. @type memoryReactor: L{MemoryReactor}
  377. """
  378. self._reactor = memoryReactor
  379. def succeedOnce(self, debug=False):
  380. """
  381. Complete a single TCP connection established on this
  382. L{ConnectionCompleter}'s L{MemoryReactor}.
  383. @param debug: A flag; whether to dump output from the established
  384. connection to stdout.
  385. @type debug: L{bool}
  386. @return: a pump for the connection, or L{None} if no connection could
  387. be established.
  388. @rtype: L{IOPump} or L{None}
  389. """
  390. memoryReactor = self._reactor
  391. for clientIdx, clientInfo in enumerate(memoryReactor.tcpClients):
  392. for serverInfo in memoryReactor.tcpServers:
  393. factories = _factoriesShouldConnect(clientInfo, serverInfo)
  394. if factories:
  395. memoryReactor.tcpClients.remove(clientInfo)
  396. memoryReactor.connectors.pop(clientIdx)
  397. clientFactory, serverFactory = factories
  398. clientProtocol = clientFactory.buildProtocol(None)
  399. serverProtocol = serverFactory.buildProtocol(None)
  400. serverTransport = makeFakeServer(serverProtocol)
  401. clientTransport = makeFakeClient(clientProtocol)
  402. return connect(serverProtocol, serverTransport,
  403. clientProtocol, clientTransport,
  404. debug)
  405. def failOnce(self, reason=Failure(ConnectionRefusedError())):
  406. """
  407. Fail a single TCP connection established on this
  408. L{ConnectionCompleter}'s L{MemoryReactor}.
  409. @param reason: the reason to provide that the connection failed.
  410. @type reason: L{Failure}
  411. """
  412. self._reactor.tcpClients.pop(0)[2].clientConnectionFailed(
  413. self._reactor.connectors.pop(0), reason
  414. )
  415. def connectableEndpoint(debug=False):
  416. """
  417. Create an endpoint that can be fired on demand.
  418. @param debug: A flag; whether to dump output from the established
  419. connection to stdout.
  420. @type debug: L{bool}
  421. @return: A client endpoint, and an object that will cause one of the
  422. L{Deferred}s returned by that client endpoint.
  423. @rtype: 2-L{tuple} of (L{IStreamClientEndpoint}, L{ConnectionCompleter})
  424. """
  425. reactor = MemoryReactorClock()
  426. clientEndpoint = TCP4ClientEndpoint(reactor, "0.0.0.0", 4321)
  427. serverEndpoint = TCP4ServerEndpoint(reactor, 4321)
  428. serverEndpoint.listen(Factory.forProtocol(Protocol))
  429. return clientEndpoint, ConnectionCompleter(reactor)