123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697 |
- # -*- test-case-name: twisted.test.test_policies -*-
- # Copyright (c) Twisted Matrix Laboratories.
- # See LICENSE for details.
-
- """
- Resource limiting policies.
-
- @seealso: See also L{twisted.protocols.htb} for rate limiting.
- """
-
-
- # system imports
- import sys
- from typing import Optional, Type
-
- from zope.interface import directlyProvides, providedBy
-
- from twisted.internet import error, interfaces
- from twisted.internet.interfaces import ILoggingContext
-
- # twisted imports
- from twisted.internet.protocol import ClientFactory, Protocol, ServerFactory
- from twisted.python import log
-
-
def _wrappedLogPrefix(wrapper, wrapped):
    """
    Build the log prefix used by a wrapper object.

    The prefix names the wrapped object first (its own C{logPrefix()} when it
    provides L{ILoggingContext}, otherwise its class name), followed by the
    wrapper's class name in parentheses.

    @rtype: C{str}
    """
    innerPrefix = (
        wrapped.logPrefix()
        if ILoggingContext.providedBy(wrapped)
        else wrapped.__class__.__name__
    )
    outerName = wrapper.__class__.__name__
    return f"{innerPrefix} ({outerName})"
-
-
class ProtocolWrapper(Protocol):
    """
    A protocol that sits between a real transport and another protocol.

    Towards the real transport it behaves as a protocol; towards the wrapped
    protocol it behaves as the transport, relaying calls in both directions.
    Any attribute not found on the wrapper is looked up on the real
    transport.

    @ivar wrappedProtocol: An L{IProtocol<twisted.internet.interfaces.IProtocol>}
        provider to which L{IProtocol<twisted.internet.interfaces.IProtocol>}
        method calls onto this L{ProtocolWrapper} will be proxied.  It is
        reset to L{None} once the connection has been lost.

    @ivar factory: The L{WrappingFactory} which created this
        L{ProtocolWrapper}; it is notified when the connection starts and
        ends.
    """

    disconnecting = 0

    def __init__(
        self, factory: "WrappingFactory", wrappedProtocol: interfaces.IProtocol
    ):
        self.factory = factory
        self.wrappedProtocol = wrappedProtocol

    def logPrefix(self):
        """
        Return a log prefix naming both the wrapped protocol and this wrapper.
        """
        return _wrappedLogPrefix(self, self.wrappedProtocol)

    def makeConnection(self, transport):
        """
        Attach this wrapper to C{transport} and connect the wrapped protocol.

        The wrapper first declares every interface C{transport} provides, so
        the wrapped protocol sees a transport with the same interfaces. It then
        connects itself, registers with its factory, and finally hands itself
        to the wrapped protocol as that protocol's transport.
        """
        transportInterfaces = providedBy(transport)
        directlyProvides(self, transportInterfaces)
        Protocol.makeConnection(self, transport)
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)

    # --- Transport side: calls made by the wrapped protocol ----------------

    def write(self, data):
        """Pass C{data} on to the real transport."""
        realTransport = self.transport
        realTransport.write(data)

    def writeSequence(self, data):
        """Pass the sequence C{data} on to the real transport."""
        realTransport = self.transport
        realTransport.writeSequence(data)

    def loseConnection(self):
        """
        Flag this wrapper as disconnecting, then close the real transport.
        """
        self.disconnecting = 1
        realTransport = self.transport
        realTransport.loseConnection()

    def getPeer(self):
        """Return the real transport's peer address."""
        realTransport = self.transport
        return realTransport.getPeer()

    def getHost(self):
        """Return the real transport's host address."""
        realTransport = self.transport
        return realTransport.getHost()

    def registerProducer(self, producer, streaming):
        """Register C{producer} with the real transport."""
        realTransport = self.transport
        realTransport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        """Unregister the producer from the real transport."""
        realTransport = self.transport
        realTransport.unregisterProducer()

    def stopConsuming(self):
        """Ask the real transport to stop consuming."""
        realTransport = self.transport
        realTransport.stopConsuming()

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; anything the
        # wrapper does not define is looked up on the real transport.
        return getattr(self.transport, name)

    # --- Protocol side: events coming from the real transport --------------

    def dataReceived(self, data):
        """Deliver received C{data} to the wrapped protocol."""
        target = self.wrappedProtocol
        target.dataReceived(data)

    def connectionLost(self, reason):
        """
        Unregister from the factory and tell the wrapped protocol the
        connection is gone.
        """
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)

        # Drop the reference so self and the wrapped protocol (which holds
        # this wrapper as its transport) no longer form a reference cycle.
        self.wrappedProtocol = None
-
-
class WrappingFactory(ClientFactory):
    """
    Factory that wraps another factory and every protocol it builds.

    Factory start/stop and client connection events are forwarded to the
    wrapped factory. Each protocol the wrapped factory builds is wrapped in
    an instance of C{protocol}, and the live wrappers are tracked in
    C{protocols}.

    @ivar wrappedFactory: The factory being wrapped.
    @ivar protocols: A dict whose keys are the currently registered protocol
        wrappers; the values are always C{1}.
    """

    protocol: Type[Protocol] = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.protocols = {}
        self.wrappedFactory = wrappedFactory

    def logPrefix(self):
        """
        Return a log prefix naming both the wrapped factory and this one.
        """
        return _wrappedLogPrefix(self, self.wrappedFactory)

    def doStart(self):
        """Start the wrapped factory, then this one."""
        inner = self.wrappedFactory
        inner.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        """Stop the wrapped factory, then this one."""
        inner = self.wrappedFactory
        inner.doStop()
        ClientFactory.doStop(self)

    def startedConnecting(self, connector):
        """Forward to the wrapped factory."""
        inner = self.wrappedFactory
        inner.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        """Forward to the wrapped factory."""
        inner = self.wrappedFactory
        inner.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        """Forward to the wrapped factory."""
        inner = self.wrappedFactory
        inner.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        """
        Have the wrapped factory build a protocol for C{addr} and return it
        wrapped in an instance of C{self.protocol}.
        """
        innerProtocol = self.wrappedFactory.buildProtocol(addr)
        return self.protocol(self, innerProtocol)

    def registerProtocol(self, p):
        """
        Called by a protocol wrapper to register itself.
        """
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """
        Called by a protocol wrapper when its connection goes away.

        @raise KeyError: if C{p} was never registered.
        """
        self.protocols.pop(p)
-
-
class ThrottlingProtocol(ProtocolWrapper):
    """
    Protocol for L{ThrottlingFactory}.

    Reports the number of bytes read and written to its factory. It lets the
    factory pause and resume reading (through the transport) and writing
    (through the producer registered via this wrapper, if any).
    """

    # Producer registered through registerProducer, or None. Defined on the
    # class so that reading it never falls through to
    # ProtocolWrapper.__getattr__. That fallback would return the transport's
    # own "producer" attribute (if it has one) instead of this wrapper's,
    # e.g. None after unregistering, which then crashed throttleWrites.
    producer = None

    # wrap API for tracking bandwidth

    def write(self, data):
        self.factory.registerWritten(len(data))
        ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        # Materialise first: counting the bytes must not exhaust a one-shot
        # iterator before the data reaches the transport.
        seq = list(seq)
        self.factory.registerWritten(sum(map(len, seq)))
        ProtocolWrapper.writeSequence(self, seq)

    def dataReceived(self, data):
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        # Reset rather than ``del``: ``del`` raised AttributeError when no
        # producer had been registered through this wrapper.
        self.producer = None
        ProtocolWrapper.unregisterProducer(self)

    def throttleReads(self):
        """Pause reading by pausing the underlying transport."""
        self.transport.pauseProducing()

    def unthrottleReads(self):
        """Resume reading by resuming the underlying transport."""
        self.transport.resumeProducing()

    def throttleWrites(self):
        """Pause the registered producer, if there is one."""
        if self.producer is not None:
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        """Resume the registered producer, if there is one."""
        if self.producer is not None:
            self.producer.resumeProducing()
-
-
class ThrottlingFactory(WrappingFactory):
    """
    Throttles bandwidth and number of connections.

    Write bandwidth will only be throttled if there is a producer
    registered.

    While at least one connection is open and the corresponding limit is
    set, the bytes read/written during the last second are compared with
    C{readLimit}/C{writeLimit} once per second. When a limit is exceeded,
    every connection is throttled for C{bytes / limit - 1} seconds.

    @ivar connectionCount: Number of connections currently counted.
    @ivar maxConnectionCount: Connections beyond this count are refused
        (L{buildProtocol} returns L{None}).
    @ivar readLimit: Max bytes to read per second, or L{None} for no limit.
    @ivar writeLimit: Max bytes to write per second, or L{None} for no limit.
    """

    protocol = ThrottlingProtocol

    def __init__(
        self,
        wrappedFactory,
        maxConnectionCount=sys.maxsize,
        readLimit=None,
        writeLimit=None,
    ):
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit  # max bytes we should read per second
        self.writeLimit = writeLimit  # max bytes we should write per second
        self.readThisSecond = 0
        self.writtenThisSecond = 0
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None

    def callLater(self, period, func):
        """
        Wrapper around
        L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
        for test purpose.
        """
        from twisted.internet import reactor

        return reactor.callLater(period, func)

    @staticmethod
    def _cancelIfPending(delayedCall):
        """
        Cancel C{delayedCall} unless it is L{None}, has already run, or has
        already been cancelled.
        """
        if delayedCall is not None and delayedCall.active():
            delayedCall.cancel()

    def registerWritten(self, length):
        """
        Called by protocol to tell us more bytes were written.
        """
        self.writtenThisSecond += length

    def registerRead(self, length):
        """
        Called by protocol to tell us more bytes were read.
        """
        self.readThisSecond += length

    def checkReadBandwidth(self):
        """
        Check whether the read limit was exceeded during the last second.
        If it was, throttle reads. Then reschedule this check one second
        later.
        """
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            # Replace, rather than orphan, an unthrottle that is still
            # pending, so unregisterProtocol can always cancel the live one.
            self._cancelIfPending(self.unthrottleReadsID)
            self.unthrottleReadsID = self.callLater(throttleTime, self.unthrottleReads)
        self.readThisSecond = 0
        self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)

    def checkWriteBandwidth(self):
        """
        Check whether the write limit was exceeded during the last second.
        If it was, throttle writes. Then reschedule this check one second
        later.
        """
        if self.writtenThisSecond > self.writeLimit:
            self.throttleWrites()
            throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
            # Replace, rather than orphan, an unthrottle that is still
            # pending, so unregisterProtocol can always cancel the live one.
            self._cancelIfPending(self.unthrottleWritesID)
            self.unthrottleWritesID = self.callLater(
                throttleTime, self.unthrottleWrites
            )
        # reset for next round
        self.writtenThisSecond = 0
        self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)

    def throttleReads(self):
        """
        Throttle reads on all protocols.
        """
        log.msg("Throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.throttleReads()

    def unthrottleReads(self):
        """
        Stop throttling reads on all protocols.
        """
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleReads()

    def throttleWrites(self):
        """
        Throttle writes on all protocols.
        """
        log.msg("Throttling writes on %s" % self)
        for p in self.protocols.keys():
            p.throttleWrites()

    def unthrottleWrites(self):
        """
        Stop throttling writes on all protocols.
        """
        self.unthrottleWritesID = None
        log.msg("Stopped throttling writes on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleWrites()

    def buildProtocol(self, addr):
        """
        Build a throttled protocol, or return L{None} when
        C{maxConnectionCount} connections are already open.

        The periodic bandwidth checks start with the first accepted
        connection.
        """
        if self.connectionCount >= self.maxConnectionCount:
            log.msg("Max connection count reached!")
            return None
        if self.connectionCount == 0:
            # Only start the checks for a connection that is actually
            # accepted; they are stopped in unregisterProtocol.
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()
        self.connectionCount += 1
        return WrappingFactory.buildProtocol(self, addr)

    def unregisterProtocol(self, p):
        """
        Forget protocol C{p}. If it was the last one, stop every pending
        timer.
        """
        WrappingFactory.unregisterProtocol(self, p)
        self.connectionCount -= 1
        if self.connectionCount == 0:
            # Clear each ID after cancelling it. A cancelled call that stayed
            # referenced here used to be cancelled a second time (raising
            # AlreadyCancelled) the next time the count dropped to zero.
            self._cancelIfPending(self.unthrottleReadsID)
            self.unthrottleReadsID = None
            self._cancelIfPending(self.checkReadBandwidthID)
            self.checkReadBandwidthID = None
            self._cancelIfPending(self.unthrottleWritesID)
            self.unthrottleWritesID = None
            self._cancelIfPending(self.checkWriteBandwidthID)
            self.checkWriteBandwidthID = None
-
-
class SpewingProtocol(ProtocolWrapper):
    """
    Protocol wrapper that logs, via L{twisted.python.log}, each chunk of data
    it receives or sends before relaying it.
    """

    def dataReceived(self, data):
        # Use a one-element tuple so %-formatting cannot misinterpret data.
        log.msg("Received: %r" % (data,))
        ProtocolWrapper.dataReceived(self, data)

    def write(self, data):
        log.msg("Sending: %r" % (data,))
        ProtocolWrapper.write(self, data)

    def writeSequence(self, data):
        # Without this override, data sent with writeSequence skipped write()
        # and was never logged. The iterable is materialised first, so logging
        # it does not exhaust a one-shot iterator before it is sent.
        chunks = list(data)
        for chunk in chunks:
            log.msg("Sending: %r" % (chunk,))
        ProtocolWrapper.writeSequence(self, chunks)
-
-
class SpewingFactory(WrappingFactory):
    """
    L{WrappingFactory} that wraps each connection in a L{SpewingProtocol},
    which logs the data it receives and sends.
    """

    protocol = SpewingProtocol
-
-
class LimitConnectionsByPeer(WrappingFactory):
    """
    Limit the number of simultaneous connections from any one peer host.

    @cvar maxConnectionsPerPeer: Connections from a host beyond this count
        are refused (L{buildProtocol} returns L{None}).
    @ivar peerConnections: Maps a peer host to its number of open
        connections; hosts are removed when their count reaches zero.
        Created by L{startFactory}.
    """

    maxConnectionsPerPeer = 5

    def startFactory(self):
        self.peerConnections = {}

    @staticmethod
    def _peerHost(address):
        """
        Return the host part of C{address}. This is its C{host} attribute
        when it has one (address objects); otherwise item 0, as for a
        C{(host, port)} tuple.
        """
        host = getattr(address, "host", None)
        if host is None:
            host = address[0]
        return host

    def buildProtocol(self, addr):
        peerHost = self._peerHost(addr)
        connectionCount = self.peerConnections.get(peerHost, 0)
        if connectionCount >= self.maxConnectionsPerPeer:
            return None
        self.peerConnections[peerHost] = connectionCount + 1
        protocol = WrappingFactory.buildProtocol(self, addr)
        # Remember the key this connection was counted under, so that
        # unregisterProtocol releases exactly this slot. Previously it
        # recomputed the key differently (getPeer()[1]).
        protocol._limitConnectionsPeerHost = peerHost
        return protocol

    def unregisterProtocol(self, p):
        # Let the base class drop p from self.protocols. This was previously
        # skipped, so every connection ever made stayed referenced.
        WrappingFactory.unregisterProtocol(self, p)
        peerHost = getattr(p, "_limitConnectionsPeerHost", None)
        if peerHost is None:
            peerHost = self._peerHost(p.getPeer())
        remaining = self.peerConnections.get(peerHost, 0) - 1
        if remaining > 0:
            self.peerConnections[peerHost] = remaining
        else:
            self.peerConnections.pop(peerHost, None)
-
-
class LimitTotalConnectionsFactory(ServerFactory):
    """
    Factory that limits the number of simultaneous connections.

    Every connection, normal or overflow, is wrapped in a L{ProtocolWrapper}
    and counted in C{connectionCount} until its wrapper unregisters.

    @type connectionCount: C{int}
    @ivar connectionCount: number of current connections.
    @type connectionLimit: C{int} or L{None}
    @cvar connectionLimit: maximum number of connections; L{None} means
        unlimited.
    @type overflowProtocol: L{Protocol} or L{None}
    @cvar overflowProtocol: Protocol to use for new connections when
        connectionLimit is exceeded. If L{None} (the default value), excess
        connections will be closed immediately.
    """

    connectionCount = 0
    connectionLimit = None
    overflowProtocol: Optional[Type[Protocol]] = None

    def buildProtocol(self, addr):
        belowLimit = (
            self.connectionLimit is None
            or self.connectionCount < self.connectionLimit
        )
        if belowLimit:
            # Room left: build the normal protocol.
            protocolClass = self.protocol
        elif self.overflowProtocol is not None:
            # Too many connections: build the overflow protocol instead.
            protocolClass = self.overflowProtocol
        else:
            # Too many connections and no overflow protocol: drop it.
            return None

        inner = protocolClass()
        inner.factory = self
        wrapper = ProtocolWrapper(self, inner)
        self.connectionCount += 1
        return wrapper

    def registerProtocol(self, p):
        # Counting happens in buildProtocol; nothing to do on registration.
        pass

    def unregisterProtocol(self, p):
        self.connectionCount -= 1
-
-
class TimeoutProtocol(ProtocolWrapper):
    """
    Protocol that automatically disconnects when the connection is idle.

    Each write, writeSequence and dataReceived pushes the timeout back by
    C{timeoutPeriod} seconds; if it expires, L{timeoutFunc} is called.

    @ivar timeoutCall: The delayed call returned by the factory's
        C{callLater}, or L{None} when no timeout is scheduled.
    @ivar timeoutPeriod: The current timeout in seconds, or L{None}.
    """

    def __init__(self, factory, wrappedProtocol, timeoutPeriod):
        """
        Constructor.

        @param factory: An L{TimeoutFactory}.
        @param wrappedProtocol: A L{Protocol} to wrap.
        @param timeoutPeriod: Number of seconds to wait for activity before
            timing out, or L{None} to start without a timeout.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.timeoutCall = None
        self.timeoutPeriod = None
        self.setTimeout(timeoutPeriod)

    def setTimeout(self, timeoutPeriod=None):
        """
        Set a timeout.

        This will cancel any existing timeouts.

        @param timeoutPeriod: The new timeout, in seconds. If L{None} (the
            default), no new timeout is scheduled, so the call only cancels
            the existing one; the previous period is not reused.
        """
        self.cancelTimeout()
        self.timeoutPeriod = timeoutPeriod
        if timeoutPeriod is not None:
            self.timeoutCall = self.factory.callLater(
                self.timeoutPeriod, self.timeoutFunc
            )

    def cancelTimeout(self):
        """
        Cancel the timeout.

        If the timeout was already cancelled, this does nothing.
        """
        self.timeoutPeriod = None
        if self.timeoutCall:
            try:
                self.timeoutCall.cancel()
            except (error.AlreadyCalled, error.AlreadyCancelled):
                pass
            self.timeoutCall = None

    def resetTimeout(self):
        """
        Reset the timeout, usually because some activity just happened.

        Does nothing when no timeout is scheduled or it has already fired.
        """
        timeoutCall = self.timeoutCall
        # timeoutFunc does not clear timeoutCall. After the timeout fires,
        # and while the connection is still closing, timeoutCall therefore
        # refers to a spent call that must not be reset.
        if timeoutCall is not None and timeoutCall.active():
            timeoutCall.reset(self.timeoutPeriod)

    def write(self, data):
        self.resetTimeout()
        ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        self.resetTimeout()
        ProtocolWrapper.writeSequence(self, seq)

    def dataReceived(self, data):
        self.resetTimeout()
        ProtocolWrapper.dataReceived(self, data)

    def connectionLost(self, reason):
        self.cancelTimeout()
        ProtocolWrapper.connectionLost(self, reason)

    def timeoutFunc(self):
        """
        This method is called when the timeout is triggered.

        By default it calls I{loseConnection}. Override this if you want
        something else to happen.
        """
        self.loseConnection()
-
-
class TimeoutFactory(WrappingFactory):
    """
    Wrapping factory whose connections are wrapped in L{TimeoutProtocol}.

    @ivar timeoutPeriod: Idle timeout, in seconds, given to every protocol
        this factory builds (default: 30 minutes).
    """

    protocol = TimeoutProtocol

    def __init__(self, wrappedFactory, timeoutPeriod=30 * 60):
        self.timeoutPeriod = timeoutPeriod
        WrappingFactory.__init__(self, wrappedFactory)

    def buildProtocol(self, addr):
        """
        Wrap the wrapped factory's protocol for C{addr} in C{self.protocol},
        passing this factory's timeout period.
        """
        inner = self.wrappedFactory.buildProtocol(addr)
        period = self.timeoutPeriod
        return self.protocol(self, inner, timeoutPeriod=period)

    def callLater(self, period, func):
        """
        Schedule C{func} after C{period} seconds using the global reactor's
        L{callLater<twisted.internet.interfaces.IReactorTime.callLater>}.

        Kept as a separate method so tests can replace it.
        """
        from twisted.internet import reactor

        delayedCall = reactor.callLater(period, func)
        return delayedCall
-
-
class TrafficLoggingProtocol(ProtocolWrapper):
    """
    Protocol wrapper that writes one line to a log file for each connection
    event and each chunk of data received or sent.
    """

    def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None, number=0):
        """
        @param factory: factory which created this protocol.
        @type factory: L{protocol.Factory}.
        @param wrappedProtocol: the underlying protocol.
        @type wrappedProtocol: C{protocol.Protocol}.
        @param logfile: file opened for writing used to write log messages.
        @type logfile: C{file}
        @param lengthLimit: maximum size of the data logged per chunk; a
            false value (e.g. L{None} or C{0}) means no limit.
        @type lengthLimit: C{int}
        @param number: identifier of the connection.
        @type number: C{int}.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.logfile = logfile
        self.lengthLimit = lengthLimit
        self._number = number

    def _log(self, line):
        self.logfile.write(line + "\n")
        self.logfile.flush()

    def _mungeData(self, data):
        """
        Truncate C{data} for logging when it is longer than C{lengthLimit}.

        The kept prefix is followed by an C{"<... elided>"} marker of the
        same type as C{data} (C{str} or bytes), so the result is at most
        C{max(lengthLimit, 12)} long.
        """
        if self.lengthLimit and len(data) > self.lengthLimit:
            marker = "<... elided>"
            if not isinstance(data, str):
                # Traffic is normally bytes; concatenating a str marker onto
                # it raised TypeError.
                marker = marker.encode("ascii")
            # Clamp at 0: with a limit below the marker length, a negative
            # slice would have kept the wrong end of the data.
            keep = max(self.lengthLimit - len(marker), 0)
            data = data[:keep] + marker
        return data

    # IProtocol
    def connectionMade(self):
        self._log("*")
        return ProtocolWrapper.connectionMade(self)

    def dataReceived(self, data):
        self._log("C %d: %r" % (self._number, self._mungeData(data)))
        return ProtocolWrapper.dataReceived(self, data)

    def connectionLost(self, reason):
        self._log("C %d: %r" % (self._number, reason))
        return ProtocolWrapper.connectionLost(self, reason)

    # ITransport
    def write(self, data):
        self._log("S %d: %r" % (self._number, self._mungeData(data)))
        return ProtocolWrapper.write(self, data)

    def writeSequence(self, iovec):
        # Materialise first so logging does not exhaust a one-shot iterator
        # before the data is handed to the transport.
        iovec = list(iovec)
        self._log("SV %d: %r" % (self._number, [self._mungeData(d) for d in iovec]))
        return ProtocolWrapper.writeSequence(self, iovec)

    def loseConnection(self):
        self._log("S %d: *" % (self._number,))
        return ProtocolWrapper.loseConnection(self)
-
-
class TrafficLoggingFactory(WrappingFactory):
    """
    Wrapping factory that logs each connection's traffic to its own file.

    Connection number N (counting from 1) is logged through a
    L{TrafficLoggingProtocol} to the file named C{logfilePrefix + "-" + N},
    opened with L{open}. L{resetCounter} restarts the numbering.
    """

    protocol = TrafficLoggingProtocol

    # Number of connections built so far; used to number the log files.
    _counter = 0

    def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
        self.logfilePrefix = logfilePrefix
        self.lengthLimit = lengthLimit
        WrappingFactory.__init__(self, wrappedFactory)

    def open(self, name):
        """
        Open the log file C{name} in text mode for writing, truncating it.

        Kept as a separate method so it can be overridden.
        """
        return open(name, "w")

    def buildProtocol(self, addr):
        self._counter += 1
        logfileName = self.logfilePrefix + "-" + str(self._counter)
        logfile = self.open(logfileName)
        inner = self.wrappedFactory.buildProtocol(addr)
        connectionNumber = self._counter
        return self.protocol(
            self,
            inner,
            logfile,
            self.lengthLimit,
            connectionNumber,
        )

    def resetCounter(self):
        """
        Reset the value of the counter used to identify connections.
        """
        self._counter = 0
-
-
class TimeoutMixin:
    """
    Mixin for protocols which wish to timeout connections.

    A protocol using this mixin has at most one pending timeout, controlled
    with L{setTimeout}. When it expires, L{timeoutConnection} is called;
    by default that closes the transport.

    @cvar timeOut: The number of seconds after which to timeout the connection.
    """

    timeOut: Optional[int] = None

    __timeoutCall = None

    def callLater(self, period, func):
        """
        Schedule C{func} after C{period} seconds using the global reactor's
        L{callLater<twisted.internet.interfaces.IReactorTime.callLater>}.

        Kept as a separate method so tests can replace it.
        """
        from twisted.internet import reactor

        delayedCall = reactor.callLater(period, func)
        return delayedCall

    def resetTimeout(self):
        """
        Restart the timeout countdown from the full C{timeOut} period.

        This is a no-op when no timeout is pending, either because it
        already fired or because it was disabled (e.g. C{setTimeout(None)}).

        Call it whenever meaningful input arrives from the peer, as evidence
        that the other end is still alive.
        """
        pending = self.__timeoutCall
        if pending is None or self.timeOut is None:
            return
        pending.reset(self.timeOut)

    def setTimeout(self, period):
        """
        Change the timeout period

        @type period: C{int} or L{None}
        @param period: The period, in seconds, to change the timeout to, or
            L{None} to disable the timeout.
        @return: the previous value of C{timeOut}.
        """
        previous, self.timeOut = self.timeOut, period
        pending = self.__timeoutCall

        if pending is None:
            # Nothing scheduled yet: start a countdown if one is wanted.
            if period is not None:
                self.__timeoutCall = self.callLater(period, self.__timedOut)
        elif period is not None:
            # Already counting down: restart with the new period.
            pending.reset(period)
        else:
            # Disabling: cancel the pending call, tolerating one that has
            # already run or been cancelled.
            try:
                pending.cancel()
            except (error.AlreadyCancelled, error.AlreadyCalled):
                pass
            self.__timeoutCall = None

        return previous

    def __timedOut(self):
        # The scheduled call has run, so forget it before notifying.
        self.__timeoutCall = None
        self.timeoutConnection()

    def timeoutConnection(self):
        """
        Called when the connection times out.

        The default closes the transport; override for other behaviour.
        """
        self.transport.loseConnection()
|