You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

styles.py 13KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. # -*- test-case-name: twisted.test.test_persisted -*-
  2. # Copyright (c) Twisted Matrix Laboratories.
  3. # See LICENSE for details.
  4. """
  5. Different styles of persisted objects.
  6. """
  7. from __future__ import division, absolute_import
  8. # System Imports
  9. import types
  10. import pickle
  11. try:
  12. import copy_reg
  13. except ImportError:
  14. import copyreg as copy_reg
  15. import copy
  16. import inspect
  17. from twisted.python.compat import _PY3, _PYPY
  18. # Twisted Imports
  19. from twisted.python import log
  20. from twisted.python import reflect
  21. oldModules = {}
  22. try:
  23. import cPickle
  24. except ImportError:
  25. cPickle = None
  26. if cPickle is None or cPickle.PicklingError is pickle.PicklingError:
  27. _UniversalPicklingError = pickle.PicklingError
  28. else:
  29. class _UniversalPicklingError(pickle.PicklingError,
  30. cPickle.PicklingError):
  31. """
  32. A PicklingError catchable by both L{cPickle.PicklingError} and
  33. L{pickle.PicklingError} handlers.
  34. """
  35. ## First, let's register support for some stuff that really ought to
  36. ## be registerable...
  37. def pickleMethod(method):
  38. 'support function for copy_reg to pickle method refs'
  39. if _PY3:
  40. return (unpickleMethod, (method.__name__,
  41. method.__self__,
  42. method.__self__.__class__))
  43. else:
  44. return (unpickleMethod, (method.im_func.__name__,
  45. method.im_self,
  46. method.im_class))
  47. def _methodFunction(classObject, methodName):
  48. """
  49. Retrieve the function object implementing a method name given the class
  50. it's on and a method name.
  51. @param classObject: A class to retrieve the method's function from.
  52. @type classObject: L{type} or L{types.ClassType}
  53. @param methodName: The name of the method whose function to retrieve.
  54. @type methodName: native L{str}
  55. @return: the function object corresponding to the given method name.
  56. @rtype: L{types.FunctionType}
  57. """
  58. methodObject = getattr(classObject, methodName)
  59. if _PY3:
  60. return methodObject
  61. return methodObject.im_func
  62. def unpickleMethod(im_name, im_self, im_class):
  63. """
  64. Support function for copy_reg to unpickle method refs.
  65. @param im_name: The name of the method.
  66. @type im_name: native L{str}
  67. @param im_self: The instance that the method was present on.
  68. @type im_self: L{object}
  69. @param im_class: The class where the method was declared.
  70. @type im_class: L{types.ClassType} or L{type} or L{None}
  71. """
  72. if im_self is None:
  73. return getattr(im_class, im_name)
  74. try:
  75. methodFunction = _methodFunction(im_class, im_name)
  76. except AttributeError:
  77. log.msg("Method", im_name, "not on class", im_class)
  78. assert im_self is not None, "No recourse: no instance to guess from."
  79. # Attempt a last-ditch fix before giving up. If classes have changed
  80. # around since we pickled this method, we may still be able to get it
  81. # by looking on the instance's current class.
  82. if im_self.__class__ is im_class:
  83. raise
  84. return unpickleMethod(im_name, im_self, im_self.__class__)
  85. else:
  86. if _PY3:
  87. maybeClass = ()
  88. else:
  89. maybeClass = tuple([im_class])
  90. bound = types.MethodType(methodFunction, im_self, *maybeClass)
  91. return bound
  92. copy_reg.pickle(types.MethodType, pickleMethod, unpickleMethod)
  93. def _pickleFunction(f):
  94. """
  95. Reduce, in the sense of L{pickle}'s C{object.__reduce__} special method, a
  96. function object into its constituent parts.
  97. @param f: The function to reduce.
  98. @type f: L{types.FunctionType}
  99. @return: a 2-tuple of a reference to L{_unpickleFunction} and a tuple of
  100. its arguments, a 1-tuple of the function's fully qualified name.
  101. @rtype: 2-tuple of C{callable, native string}
  102. """
  103. if f.__name__ == '<lambda>':
  104. raise _UniversalPicklingError(
  105. "Cannot pickle lambda function: {}".format(f))
  106. return (_unpickleFunction,
  107. tuple([".".join([f.__module__, f.__qualname__])]))
  108. def _unpickleFunction(fullyQualifiedName):
  109. """
  110. Convert a function name into a function by importing it.
  111. This is a synonym for L{twisted.python.reflect.namedAny}, but imported
  112. locally to avoid circular imports, and also to provide a persistent name
  113. that can be stored (and deprecated) independently of C{namedAny}.
  114. @param fullyQualifiedName: The fully qualified name of a function.
  115. @type fullyQualifiedName: native C{str}
  116. @return: A function object imported from the given location.
  117. @rtype: L{types.FunctionType}
  118. """
  119. from twisted.python.reflect import namedAny
  120. return namedAny(fullyQualifiedName)
  121. copy_reg.pickle(types.FunctionType, _pickleFunction, _unpickleFunction)
  122. def pickleModule(module):
  123. 'support function for copy_reg to pickle module refs'
  124. return unpickleModule, (module.__name__,)
  125. def unpickleModule(name):
  126. 'support function for copy_reg to unpickle module refs'
  127. if name in oldModules:
  128. log.msg("Module has moved: %s" % name)
  129. name = oldModules[name]
  130. log.msg(name)
  131. return __import__(name,{},{},'x')
  132. copy_reg.pickle(types.ModuleType,
  133. pickleModule,
  134. unpickleModule)
  135. def pickleStringO(stringo):
  136. """
  137. Reduce the given cStringO.
  138. This is only called on Python 2, because the cStringIO module only exists
  139. on Python 2.
  140. @param stringo: The string output to pickle.
  141. @type stringo: L{cStringIO.OutputType}
  142. """
  143. 'support function for copy_reg to pickle StringIO.OutputTypes'
  144. return unpickleStringO, (stringo.getvalue(), stringo.tell())
  145. def unpickleStringO(val, sek):
  146. """
  147. Convert the output of L{pickleStringO} into an appropriate type for the
  148. current python version. This may be called on Python 3 and will convert a
  149. cStringIO into an L{io.StringIO}.
  150. @param val: The content of the file.
  151. @type val: L{bytes}
  152. @param sek: The seek position of the file.
  153. @type sek: L{int}
  154. @return: a file-like object which you can write bytes to.
  155. @rtype: L{cStringIO.OutputType} on Python 2, L{io.StringIO} on Python 3.
  156. """
  157. x = _cStringIO()
  158. x.write(val)
  159. x.seek(sek)
  160. return x
  161. def pickleStringI(stringi):
  162. """
  163. Reduce the given cStringI.
  164. This is only called on Python 2, because the cStringIO module only exists
  165. on Python 2.
  166. @param stringi: The string input to pickle.
  167. @type stringi: L{cStringIO.InputType}
  168. @return: a 2-tuple of (C{unpickleStringI}, (bytes, pointer))
  169. @rtype: 2-tuple of (function, (bytes, int))
  170. """
  171. return unpickleStringI, (stringi.getvalue(), stringi.tell())
  172. def unpickleStringI(val, sek):
  173. """
  174. Convert the output of L{pickleStringI} into an appropriate type for the
  175. current Python version.
  176. This may be called on Python 3 and will convert a cStringIO into an
  177. L{io.StringIO}.
  178. @param val: The content of the file.
  179. @type val: L{bytes}
  180. @param sek: The seek position of the file.
  181. @type sek: L{int}
  182. @return: a file-like object which you can read bytes from.
  183. @rtype: L{cStringIO.OutputType} on Python 2, L{io.StringIO} on Python 3.
  184. """
  185. x = _cStringIO(val)
  186. x.seek(sek)
  187. return x
  188. try:
  189. from cStringIO import InputType, OutputType, StringIO as _cStringIO
  190. except ImportError:
  191. from io import StringIO as _cStringIO
  192. else:
  193. copy_reg.pickle(OutputType, pickleStringO, unpickleStringO)
  194. copy_reg.pickle(InputType, pickleStringI, unpickleStringI)
  195. class Ephemeral:
  196. """
  197. This type of object is never persisted; if possible, even references to it
  198. are eliminated.
  199. """
  200. def __reduce__(self):
  201. """
  202. Serialize any subclass of L{Ephemeral} in a way which replaces it with
  203. L{Ephemeral} itself.
  204. """
  205. return (Ephemeral, ())
  206. def __getstate__(self):
  207. log.msg( "WARNING: serializing ephemeral %s" % self )
  208. if not _PYPY:
  209. import gc
  210. if getattr(gc, 'get_referrers', None):
  211. for r in gc.get_referrers(self):
  212. log.msg( " referred to by %s" % (r,))
  213. return None
  214. def __setstate__(self, state):
  215. log.msg( "WARNING: unserializing ephemeral %s" % self.__class__ )
  216. self.__class__ = Ephemeral
  217. versionedsToUpgrade = {}
  218. upgraded = {}
  219. def doUpgrade():
  220. global versionedsToUpgrade, upgraded
  221. for versioned in list(versionedsToUpgrade.values()):
  222. requireUpgrade(versioned)
  223. versionedsToUpgrade = {}
  224. upgraded = {}
  225. def requireUpgrade(obj):
  226. """Require that a Versioned instance be upgraded completely first.
  227. """
  228. objID = id(obj)
  229. if objID in versionedsToUpgrade and objID not in upgraded:
  230. upgraded[objID] = 1
  231. obj.versionUpgrade()
  232. return obj
  233. def _aybabtu(c):
  234. """
  235. Get all of the parent classes of C{c}, not including C{c} itself, which are
  236. strict subclasses of L{Versioned}.
  237. @param c: a class
  238. @returns: list of classes
  239. """
  240. # begin with two classes that should *not* be included in the
  241. # final result
  242. l = [c, Versioned]
  243. for b in inspect.getmro(c):
  244. if b not in l and issubclass(b, Versioned):
  245. l.append(b)
  246. # return all except the unwanted classes
  247. return l[2:]
  248. class Versioned:
  249. """
  250. This type of object is persisted with versioning information.
  251. I have a single class attribute, the int persistenceVersion. After I am
  252. unserialized (and styles.doUpgrade() is called), self.upgradeToVersionX()
  253. will be called for each version upgrade I must undergo.
  254. For example, if I serialize an instance of a Foo(Versioned) at version 4
  255. and then unserialize it when the code is at version 9, the calls::
  256. self.upgradeToVersion5()
  257. self.upgradeToVersion6()
  258. self.upgradeToVersion7()
  259. self.upgradeToVersion8()
  260. self.upgradeToVersion9()
  261. will be made. If any of these methods are undefined, a warning message
  262. will be printed.
  263. """
  264. persistenceVersion = 0
  265. persistenceForgets = ()
  266. def __setstate__(self, state):
  267. versionedsToUpgrade[id(self)] = self
  268. self.__dict__ = state
  269. def __getstate__(self, dict=None):
  270. """Get state, adding a version number to it on its way out.
  271. """
  272. dct = copy.copy(dict or self.__dict__)
  273. bases = _aybabtu(self.__class__)
  274. bases.reverse()
  275. bases.append(self.__class__) # don't forget me!!
  276. for base in bases:
  277. if 'persistenceForgets' in base.__dict__:
  278. for slot in base.persistenceForgets:
  279. if slot in dct:
  280. del dct[slot]
  281. if 'persistenceVersion' in base.__dict__:
  282. dct['%s.persistenceVersion' % reflect.qual(base)] = base.persistenceVersion
  283. return dct
  284. def versionUpgrade(self):
  285. """(internal) Do a version upgrade.
  286. """
  287. bases = _aybabtu(self.__class__)
  288. # put the bases in order so superclasses' persistenceVersion methods
  289. # will be called first.
  290. bases.reverse()
  291. bases.append(self.__class__) # don't forget me!!
  292. # first let's look for old-skool versioned's
  293. if "persistenceVersion" in self.__dict__:
  294. # Hacky heuristic: if more than one class subclasses Versioned,
  295. # we'll assume that the higher version number wins for the older
  296. # class, so we'll consider the attribute the version of the older
  297. # class. There are obviously possibly times when this will
  298. # eventually be an incorrect assumption, but hopefully old-school
  299. # persistenceVersion stuff won't make it that far into multiple
  300. # classes inheriting from Versioned.
  301. pver = self.__dict__['persistenceVersion']
  302. del self.__dict__['persistenceVersion']
  303. highestVersion = 0
  304. highestBase = None
  305. for base in bases:
  306. if 'persistenceVersion' not in base.__dict__:
  307. continue
  308. if base.persistenceVersion > highestVersion:
  309. highestBase = base
  310. highestVersion = base.persistenceVersion
  311. if highestBase:
  312. self.__dict__['%s.persistenceVersion' % reflect.qual(highestBase)] = pver
  313. for base in bases:
  314. # ugly hack, but it's what the user expects, really
  315. if (Versioned not in base.__bases__ and
  316. 'persistenceVersion' not in base.__dict__):
  317. continue
  318. currentVers = base.persistenceVersion
  319. pverName = '%s.persistenceVersion' % reflect.qual(base)
  320. persistVers = (self.__dict__.get(pverName) or 0)
  321. if persistVers:
  322. del self.__dict__[pverName]
  323. assert persistVers <= currentVers, "Sorry, can't go backwards in time."
  324. while persistVers < currentVers:
  325. persistVers = persistVers + 1
  326. method = base.__dict__.get('upgradeToVersion%s' % persistVers, None)
  327. if method:
  328. log.msg( "Upgrading %s (of %s @ %s) to version %s" % (reflect.qual(base), reflect.qual(self.__class__), id(self), persistVers) )
  329. method(self)
  330. else:
  331. log.msg( 'Warning: cannot upgrade %s to version %s' % (base, persistVers) )