import abc
import types
from collections import defaultdict
from functools import partial
import asyncio
import logging
from typing import Union, Callable, Any, Dict
import aiohttp
from .tools import json
from .route import WebSocketRoute
class ClientException(Exception):
    """Error reported by the remote side in reply to one of our calls.

    The remote error payload is passed as the single exception argument.
    """
class PingTimeoutError(Exception):
    """Timeout error type for ping requests.

    NOTE(review): not raised anywhere in this module's visible code;
    presumably raised by transport subclasses when a ping goes unanswered.
    """
def ping(obj, **kwargs):
    """Built-in ``ping`` route: echo the keyword arguments back unchanged.

    ``obj`` receives the calling socket (function routes get it as their
    first positional argument) and is ignored.
    """
    echoed = kwargs
    return echoed
# Module-level logger named after this module.
log = logging.getLogger(name=__name__)

# Anything that may be registered as a route: a plain callable, which is
# invoked with the socket as its first positional argument, or a
# WebSocketRoute.
RouteType = Union[Callable[['WSRPCBase', Any], Any], WebSocketRoute]
class _ProxyMethod:
__slots__ = '__call', '__name'
def __init__(self, call_method, name):
self.__call = call_method
self.__name = name
def __call__(self, **kwargs):
return self.__call(self.__name, **kwargs)
def __getattr__(self, item: str):
return self.__class__(self.__call, ".".join((self.__name, item)))
class _Proxy:
    """Root object behind ``socket.proxy``: ``proxy.name`` yields a callable."""

    # Name-mangled so that remote names like ``_call`` still proxy.
    __slots__ = ('__call',)

    def __init__(self, call_method):
        self.__call = call_method

    def __getattr__(self, item: str):
        # Every unknown attribute becomes the first segment of a remote name.
        method = _ProxyMethod(self.__call, item)
        return method
class WSRPCBase:
    """ Common WSRPC abstraction

    Shared state and message handling for one end of a WSRPC connection:
    tracked background tasks and timers, per-serial locks, futures waiting
    for remote results, and the per-class route and client registries.
    Subclasses implement the transport (``_send``) and how local handlers
    are run (``_executor``).
    """

    # Registries are keyed by the concrete class, so every subclass gets its
    # own routes dict (seeded with the built-in ``ping`` route) and clients.
    _ROUTES = defaultdict(lambda: {'ping': ping})
    _CLIENTS = defaultdict(dict)
    # Delay in seconds (passed to loop.call_later) before a per-serial lock
    # is discarded after its message was handled.
    _CLEAN_LOCK_TIMEOUT = 2

    __slots__ = ('_handlers', '_loop', '_pending_tasks', '_locks', '_futures', '_serial', '_timeout')

    def __init__(self, loop: asyncio.AbstractEventLoop=None):
        self._loop = loop or asyncio.get_event_loop()
        # route name -> route instance created on first use (see ``resolver``)
        self._handlers = {}
        # tasks and timer handles that ``close()`` must cancel
        self._pending_tasks = set()
        self._serial = 0
        # timeout passed to asyncio.wait_for in ``call()``; None waits forever
        self._timeout = None
        # Locks are only created inside ``on_message``, i.e. while
        # ``self._loop`` is running, so no explicit loop is needed (the
        # ``loop=`` argument was removed from asyncio.Lock in Python 3.10).
        self._locks = defaultdict(asyncio.Lock)
        # serial -> future resolved by the matching callback/error message
        self._futures = defaultdict(self._loop.create_future)

    def _create_task(self, coro):
        """ Schedule ``coro`` on the loop and track it until it finishes """
        task = self._loop.create_task(coro)  # type: asyncio.Task
        self._pending_tasks.add(task)
        # discard() tolerates an entry that is already gone
        task.add_done_callback(self._pending_tasks.discard)
        return task

    def _call_later(self, timer, callback, *args, **kwargs):
        """ Run ``callback(*args, **kwargs)`` after ``timer`` seconds.

        If the callback returns a coroutine it is scheduled as a tracked
        task. The timer handle is tracked so ``close()`` can cancel it, and
        it is dropped from the tracking set as soon as it fires so handles
        do not pile up (one is created per handled message).
        """
        def handler():
            self._pending_tasks.discard(handle)
            # asyncio.coroutine() was removed in Python 3.11; call directly
            # and only schedule the result when it is a coroutine.
            result = callback(*args, **kwargs)
            if asyncio.iscoroutine(result):
                self._create_task(result)

        handle = self._loop.call_later(timer, handler)
        self._pending_tasks.add(handle)

    async def close(self):
        """ Cancel all pending tasks """
        async def task_waiter(task):
            # Timer handles are not awaitable; only futures/tasks are waited.
            if not (hasattr(task, '__iter__') or hasattr(task, '__aiter__')):
                return
            try:
                await task
            except asyncio.CancelledError:
                pass
            except Exception:
                log.exception("Unhandled exception when closing client connection")

        for task in tuple(self._pending_tasks):
            task.cancel()
            if hasattr(task, 'cancelled') and not task.cancelled():
                # Reap the task in the background so its outcome is consumed.
                self._loop.create_task(task_waiter(task))

    def _log_call(self, start: float, *args):
        """ Log the time elapsed since ``start`` (a ``loop.time()`` value) """
        end = self._loop.time()
        log.info(end - start)

    async def _handle_message(self, msg: aiohttp.WSMessage):
        """ Dispatch a raw websocket frame by its message type """
        if msg.type == aiohttp.WSMsgType.TEXT:
            self._create_task(self.on_message(msg))
        elif msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED):
            self._create_task(self.close())
        elif msg.type == aiohttp.WSMsgType.ERROR:
            self._create_task(self.close())
            raise aiohttp.WebSocketError(code=msg.type.value, message=msg)
        else:
            log.warning("Unhandled message %r %r", msg.type, msg.data)

    @classmethod
    def get_routes(cls) -> Dict[str, RouteType]:
        """ Return the routes registry of this concrete class """
        return cls._ROUTES[cls]

    @classmethod
    def get_clients(cls) -> Dict[str, 'WSRPCBase']:
        """ Return the clients registry of this concrete class """
        return cls._CLIENTS[cls]

    @property
    def routes(self) -> Dict[str, RouteType]:
        """ Property which contains the socket routes """
        return self.get_routes()

    @property
    def clients(self) -> Dict[str, 'WSRPCBase']:
        """ Property which contains the socket clients """
        return self.get_clients()

    @staticmethod
    def _prepare_args(args):
        """ Split a message's ``arguments`` into ``(positional, keyword)``.

        ``None`` -> no arguments; a list -> positionals; a dict -> keywords;
        any other value -> a single positional argument.
        """
        arguments = []
        kwargs = {}

        if args is None:
            return arguments, kwargs

        if isinstance(args, list):
            arguments.extend(args)
        elif isinstance(args, dict):
            kwargs.update(args)
        else:
            arguments.append(args)

        return arguments, kwargs

    async def on_message(self, message: aiohttp.WSMessage):
        """ Handle one decoded text frame.

        ``type`` selects the branch: ``call`` runs a local route and sends a
        ``callback`` (or ``error``) reply; ``callback`` resolves the future of
        one of our own ``call()`` requests; ``error`` rejects it. Messages
        that share a serial are serialized through a per-serial lock.
        """
        # deserialize message
        data = message.json(loads=json.loads)
        serial = data.get('serial', -1)
        msg_type = data.get('type', 'call')

        assert serial >= 0

        log.debug("Acquiring lock for %s serial %s", self, serial)
        async with self._locks[serial]:
            try:
                if msg_type == 'call':
                    args, kwargs = self._prepare_args(data.get('arguments', None))
                    callback = data.get('call', None)

                    if callback is None:
                        raise ValueError('Require argument "call" does\'t exist.')

                    callee = self.resolver(callback)
                    # Route methods are already bound to a route instance;
                    # plain function routes receive the socket first.
                    callee_is_route = (
                        hasattr(callee, '__self__') and
                        isinstance(callee.__self__, WebSocketRoute)
                    )
                    if not callee_is_route:
                        args = [self] + args

                    result = await self._executor(partial(callee, *args, **kwargs))
                    self._send(data=result, serial=serial, type='callback')

                elif msg_type == 'callback':
                    future = self._futures.pop(serial, None)
                    if future is None or future.done():
                        # Unknown serial, or a late reply to a call that
                        # already timed out / failed: nothing to resolve.
                        log.warning("Dropping callback for unknown or finished request #%s", serial)
                    else:
                        future.set_result(data.get('data', None))

                elif msg_type == 'error':
                    self._reject(serial, data.get('data', None))
                    log.error('Client return error: \n\t%s', data.get('data', None))

                else:
                    log.warning("Unhandled message type %r for serial %s", msg_type, serial)

            except Exception as e:
                log.exception(e)
                self._send(data=self._format_error(e), serial=serial, type='error')

            finally:
                def clean_lock():
                    log.debug("Release and delete lock for %s serial %s", self, serial)
                    if serial in self._locks:
                        self._locks.pop(serial)

                self._call_later(self._CLEAN_LOCK_TIMEOUT, clean_lock)

    @abc.abstractmethod
    def _send(self, **kwargs):
        """ Serialize ``kwargs`` and send them to the remote side.

        Implemented by subclasses. ``call()`` attaches a done-callback to a
        non-empty return value, so it may return a future for the send.
        """
        raise NotImplementedError

    @staticmethod
    def _format_error(e):
        """ Convert an exception to the payload of an ``error`` message """
        return {'type': str(type(e).__name__), 'message': str(e)}

    def _reject(self, serial, error):
        """ Fail the pending ``call()`` for ``serial`` with ClientException """
        # pop() so rejected requests do not stay in the registry
        future = self._futures.pop(serial, None)
        if future is None or future.done():
            return
        future.set_exception(ClientException(error))

    def _unresolvable(self, func_name, *args, **kwargs):
        """ Fallback for names that match no registered route """
        raise NotImplementedError('Callback function "%r" not implemented' % func_name)

    def resolver(self, func_name):
        """ Find the local callable for a remote ``call`` name.

        ``"Route.method"`` (or ``"Route"`` for its ``init``) is resolved
        through a route instance created once per socket. Other names are
        looked up directly in the routes registry.

        :raises NotImplementedError: when nothing matches ``func_name``
        """
        class_name, method = func_name.split('.') if '.' in func_name else (func_name, 'init')
        callee = self.routes.get(class_name)

        # isinstance(..., type) guards issubclass(), which raises TypeError
        # for non-class callables such as functools.partial objects.
        is_route = callee is not None and (
            isinstance(getattr(callee, '__self__', None), WebSocketRoute) or
            (isinstance(callee, type) and issubclass(callee, WebSocketRoute))
        )

        if is_route:
            if class_name not in self._handlers:
                self._handlers[class_name] = callee(self)

            return self._handlers[class_name]._resolve(method)  # noqa

        callee = self.routes.get(func_name)
        if callee is None:
            # Report the requested name, not the socket instance.
            return self._unresolvable(func_name)

        if callable(callee):
            return callee

        raise NotImplementedError('Method call of {0} is not implemented'.format(repr(callee)))

    def _get_serial(self):
        """ Return the serial for the next outgoing call (steps by 2) """
        self._serial += 2
        return self._serial

    def call(self, func: str, **kwargs):
        """ Method for call remote function

        Remote methods allows only kwargs as arguments.

        You might use functions as route or classes

        .. code-block:: python

            async def remote_function(socket: WSRPCBase, *, foo, bar):
                # call function from the client-side
                await socket.proxy.ping()
                return foo + bar

            class RemoteClass(WebSocketRoute):

                # this method executes when remote side call route name
                async def init(self):
                    # call function from the client-side
                    await self.socket.proxy.ping()

                async def make_something(self, foo, bar):
                    return foo + bar

        :returns: a future resolving to the remote result; it fails with
            ClientException on a remote error, with the send error when
            sending fails, or with a timeout error after ``self._timeout``.
        """
        serial = self._get_serial()
        result_future = self._futures[serial]

        req_type = 'call'
        send_future = self._send(serial=serial, type=req_type, call=func, arguments=kwargs)
        log.info("Sending %r request #%d \"%s(%r)\" to the client.", req_type, serial, func, kwargs)

        def propagate_exception(f):
            # Fail the raw result future: the wrapping wait_for() task does
            # not support set_exception() and re-raises whatever it gets.
            if result_future.done():
                return
            if f.cancelled():
                result_future.cancel()
            elif f.exception() is not None:
                result_future.set_exception(f.exception())

        if send_future:
            send_future.add_done_callback(propagate_exception)

        future = asyncio.ensure_future(
            asyncio.wait_for(result_future, self._timeout),
            loop=self._loop,
        )
        # Drop the registry entry however the call ends (success, timeout,
        # error or cancellation) so unanswered requests do not leak.
        future.add_done_callback(lambda _: self._futures.pop(serial, None))
        return future

    @classmethod
    def add_route(cls, route: str, handler: Union[WebSocketRoute, Callable]):
        """ Expose local function through RPC

        :param route: Name which function will be aliased for this function.
                      Remote side should call function by this name.
        :param handler: Function or Route class (classes based on
                        :class:`wsrpc_aiohttp.WebSocketRoute`).
                        For route classes the public methods will
                        be registered automatically.

        .. note::

            Route classes might be initialized only once for the each
            socket instance.

            In case the method of class will be called first,
            :func:`wsrpc_aiohttp.WebSocketRoute.init` will be called
            without params before callable method.

        """
        assert callable(handler) or isinstance(handler, WebSocketRoute)
        cls.get_routes()[route] = handler

    @classmethod
    def remove_route(cls, route: str, fail=True):
        """ Removes route by name. If `fail=True` an exception
        will be raised in case the route was not found. """
        if fail:
            cls.get_routes().pop(route)
        else:
            cls.get_routes().pop(route, None)

    def __repr__(self):
        if hasattr(self, 'id'):
            return "<RPCWebSocket: ID[{0}]>".format(self.id)
        else:
            return "<RPCWebsocket: {0} (waiting)>".format(self.__hash__())

    @abc.abstractmethod
    async def _executor(self, func):
        """ Run the zero-argument callable ``func`` and return its result.

        Implemented by subclasses; ``on_message`` awaits it for every
        incoming ``call``.
        """
        raise NotImplementedError

    @property
    def proxy(self):
        """ Special property which allow run the remote functions
        by `dot` notation

        .. code-block:: python

            # calls remote function with name ping
            await client.proxy.ping()

            # full equivalent of
            await client.call('ping')
        """
        return _Proxy(self.call)