You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

base.py 8.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. # Copyright (c) 2016 Anki, Inc.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License in the file LICENSE.txt or at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
# Empty export list: `from <this module> import *` brings in no names.
__all__ = []
  15. import threading
  16. import asyncio
  17. import concurrent.futures
  18. import functools
  19. import inspect
  20. import traceback
  21. import types
  22. class _MetaBase(type):
  23. '''Metaclass for all Cozmo package classes.
  24. Ensures that all *_factory class attributes are wrapped into a _Factory
  25. descriptor to automatically support synchronous operation.
  26. '''
  27. def __new__(mcs, name, bases, attrs, **kw):
  28. for k, v in attrs.items():
  29. if k.endswith('_factory'):
  30. # TODO: check type here too
  31. attrs[k] = _Factory(v)
  32. return super().__new__(mcs, name, bases, attrs, **kw)
  33. def __setattr__(cls, name, val):
  34. if name.endswith('_factory'):
  35. cls.__dict__[name].__set__(cls, val)
  36. else:
  37. super().__setattr__(name, val)
  38. class Base(metaclass=_MetaBase):
  39. '''Base class for Cozmo package objects.
  40. *_factory attributes are automatically wrapped into a _Factory descriptor to
  41. support synchronous operation.
  42. '''
  43. # used by SyncFatory
  44. _sync_thread_id = None
  45. _sync_abort_future = None
  46. def __init__(self, _sync_thread_id=None, _sync_abort_future=None, **kw):
  47. # machinery for SyncFactory
  48. if _sync_abort_future is not None:
  49. self._sync_thread_id = threading.get_ident()
  50. else:
  51. self._sync_thread_id = _sync_thread_id
  52. self._sync_abort_future = _sync_abort_future
  53. super().__init__(**kw)
  54. @property
  55. def loop(self):
  56. ''':class:`asyncio.BaseEventLoop`: loop instance that this object is registered with.'''
  57. return getattr(self, '_loop', None)
  58. class _Factory:
  59. '''Descriptor to wraps an object factory method.
  60. If the factory is called while the program is running in synchronous mode
  61. then the objects returned by the factory will be wrapped by a _SyncProxy
  62. object, which translates asynchronous responses to synchronous ones
  63. when made outside of the thread the top level object's event loop is running on.
  64. '''
  65. def __init__(self, factory):
  66. self._wrapped_factory = factory
  67. def __get__(self, ins, owner):
  68. sync_thread_id = getattr(ins, '_sync_thread_id', None)
  69. loop = getattr(ins, '_loop', None)
  70. if sync_thread_id:
  71. # Object instance is running in sync mode
  72. return _SyncFactory(self._wrapped_factory, loop, sync_thread_id, ins._sync_abort_future)
  73. # Pass through to the factory. Set loop here as a convenience as all
  74. # Cozmo objects require it by virtue of inheriting from event.Dispatcher
  75. return functools.partial(self._wrapped_factory, loop=loop)
  76. def __set__(self, ins, val):
  77. self._wrapped_factory = val
  78. def _SyncFactory(f, loop, thread_id, sync_abort_future):
  79. '''Instantiates a class by calling a factory function and then wrapping it with _SyncProxy'''
  80. def factory(*a, **kw):
  81. kw['_sync_thread_id'] = thread_id
  82. kw['_sync_abort_future'] = sync_abort_future
  83. if 'loop' not in kw:
  84. kw['loop'] = loop
  85. obj = f(*a, **kw)
  86. return _mkproxy(obj)
  87. return factory
  88. def _mkpt(cls, name):
  89. # create a passthru function
  90. f = getattr(cls, name)
  91. @functools.wraps(f)
  92. def pt(self, *a, **kw):
  93. wrap = self.__wrapped__
  94. f = object.__getattribute__(wrap, name)
  95. return f(*a, **kw)
  96. return pt
  97. class _SyncProxy:
  98. '''Wraps cozmo objects to provide synchronous access when required.
  99. Each method call and attribute access is passed through to the wrapped object.
  100. If the caller is operating in a different thread to the callee (for example, the
  101. caller is operating outside of the context of the event loop), then any
  102. calls to the wrapped object are dispatched to the event loop running on the
  103. loop's native thread.
  104. Returned co-routines functions and Futures are waited upon until completion.
  105. '''
  106. def __init__(self, wrapped):
  107. self.__wrapped__ = wrapped
  108. def __getattribute__(self, name):
  109. wrapped = object.__getattribute__(self, '__wrapped__')
  110. if name == '__wrapped__':
  111. return wrapped
  112. # if name points to a property, this will execute the property getter
  113. # and return the value, else returns the value according to usual
  114. # lookup rules.
  115. value = object.__getattribute__(wrapped, name)
  116. # determine whether the call is being invoked locally, from within the
  117. # event loop's native thread, or elsewhere (usually the main thread)
  118. thread_id = object.__getattribute__(wrapped, '_sync_thread_id')
  119. is_local_thread = thread_id is None or threading.get_ident() == thread_id
  120. if is_local_thread:
  121. # passthru/no-op if being called from the same thread as the object
  122. # was created from.
  123. return value
  124. if inspect.ismethod(value) and not asyncio.iscoroutinefunction(value):
  125. # Wrap the sync method into a coroutine that can be dispatched
  126. # from the same thread as the main event loop is running in
  127. f = value.__func__
  128. f = _to_coroutine(f)
  129. value = types.MethodType(f, wrapped)
  130. #value = types.MethodType(f, self)
  131. elif inspect.isfunction(value) and not asyncio.iscoroutinefunction(value):
  132. # Dispatch functions in the main event loop thread too
  133. value = _to_coroutine(value)
  134. if inspect.isawaitable(value):
  135. return _dispatch_coroutine(value, wrapped._loop, wrapped._sync_abort_future)
  136. elif asyncio.iscoroutinefunction(value):
  137. # Wrap coroutine into synchronous dispatch
  138. @functools.wraps(value)
  139. def wrap(*a, **kw):
  140. return _dispatch_coroutine(value(*a, **kw), wrapped._loop, wrapped._sync_abort_future)
  141. return wrap
  142. return value
  143. def __setattr__(self, name, value):
  144. if name == '__wrapped__':
  145. return super().__setattr__(name, value)
  146. wrapped = object.__getattribute__(self, '__wrapped__')
  147. return wrapped.__setattr__(name, value)
  148. def __repr__(self):
  149. wrapped = self.__wrapped__
  150. return "wrapped-" + object.__getattribute__(wrapped, '__repr__')()
  151. def _to_coroutine(f):
  152. @functools.wraps(f)
  153. async def wrap(*a, **kw):
  154. return f(*a, **kw)
  155. return wrap
  156. def _mkproxy(obj):
  157. '''Create a _SyncProxy for an object.'''
  158. # dynamically generate a class tailored for the wrapped object.
  159. d = {}
  160. cls = obj.__class__
  161. for name in dir(cls):
  162. if ((name.endswith('__') and name.startswith('__'))
  163. and name not in ('__class__', '__new__', '__init__', '__getattribute__', '__setattr__', '__repr__')):
  164. d[name] = _mkpt(cls, name)
  165. if hasattr(obj, '__aenter__'):
  166. d['__enter__'] = lambda self: self.__wrapper__.__aenter__()
  167. d['__exit__'] = lambda self, *a: self.__wrapper__.__aexit__(*a)
  168. cls = type("_proxy_"+obj.__class__.__name__, (_SyncProxy,), d)
  169. proxy = cls(obj)
  170. obj.__wrapper__ = proxy
  171. return proxy
  172. def _dispatch_coroutine(co, loop, abort_future):
  173. '''Execute a coroutine in a loop's thread and block till completion.
  174. Wraps a co-routine function; calling the function causes the co-routine
  175. to be dispatched in the event loop's thread and blocks until that call completes.
  176. Waits for either the coroutine or abort_future to complete.
  177. abort_future provides the main event loop with a means of triggering a
  178. clean shutdown in the case of an exception.
  179. '''
  180. fut = asyncio.run_coroutine_threadsafe(co, loop)
  181. result = concurrent.futures.wait((fut, abort_future), return_when=concurrent.futures.FIRST_COMPLETED)
  182. result = list(result.done)[0].result()
  183. if getattr(result, '__wrapped__', None) is None:
  184. # If the call retuned the wrapped contents of a _SyncProxy then return
  185. # the enclosing proxy instead to the sync caller
  186. wrapper = getattr(result, '__wrapper__', None)
  187. if wrapper is not None:
  188. result = wrapper
  189. return result