# encoding: utf-8
import asyncio
import logging
import uuid
from collections import defaultdict
from functools import partial

import aiohttp
from aiohttp import WebSocketError, web
from aiohttp.abc import AbstractView

from wsrpc_aiohttp.signal import Signal
from .common import ClientException, WSRPCBase
from .tools import Lazy, awaitable, dumps

# Package-wide logger, and the logger used by this handler module
# (a child of the package logger, named "wsrpc.handler").
global_log = logging.getLogger("wsrpc")
log = global_log.getChild("handler")

class WebSocketBase(WSRPCBase, AbstractView):
    """
    Base class for aiohttp websocket handler.

    An instance is an awaitable aiohttp view: awaiting it authorizes the
    client, performs the websocket handshake, registers the connection in
    ``self.clients``, starts the keep-alive ping task and then feeds every
    incoming message to ``self._on_message`` until the socket is closed.

    Class-level signals are fired at the corresponding points of the
    connection lifecycle (open, auth success/failure, handshake failure,
    close).
    """

    __slots__ = (
        "_request",
        "socket",
        "id",
        "__pending_tasks",
        "__handlers",
        "store",
        "serial",
        "_ping",
        "protocol_version",
    )

    # Delay between pings; also the deadline for a single pong response.
    KEEPALIVE_PING_TIMEOUT = 30  # type: int
    # Maximum tolerated ping round-trip time before the client is dropped.
    CLIENT_TIMEOUT = int(KEEPALIVE_PING_TIMEOUT / 3)  # type: int
    # Size of the per-connection semaphore created in ``__init__``.
    MAX_CONCURRENT_REQUESTS = 25  # type: int
    # Passed to ``WSRPCBase.__init__`` as ``timeout``.
    REQUEST_EXECUTION_TIMEOUT = None  # type: int

    ON_AUTH_SUCCESS = Signal()
    ON_AUTH_FAIL = Signal()
    ON_CONN_OPEN = Signal()
    ON_CONN_CLOSE = Signal()
    ON_CONN_FAIL = Signal()

    def __init__(self, request):
        """
        Bind the handler to an incoming request.

        :param request: the aiohttp request being upgraded to a websocket
        """
        AbstractView.__init__(self, request)
        WSRPCBase.__init__(self, timeout=self.REQUEST_EXECUTION_TIMEOUT)

        self._ping = defaultdict(self._loop.create_future)
        # Unique connection identifier; used as the key in ``self.clients``.
        self.id = uuid.uuid4()
        self.protocol_version = None
        self.serial = 0
        self.semaphore = asyncio.Semaphore(self.MAX_CONCURRENT_REQUESTS)

    @classmethod
    def configure(
        cls,
        keepalive_timeout=KEEPALIVE_PING_TIMEOUT,
        client_timeout=CLIENT_TIMEOUT,
        max_concurrent_requests=MAX_CONCURRENT_REQUESTS,
    ):
        """
        Configures the handler class

        :param keepalive_timeout: sets timeout of client pong response
        :param client_timeout: internal lock timeout
        :param max_concurrent_requests: how many concurrent requests might
                                        be performed by each client
        """
        cls.KEEPALIVE_PING_TIMEOUT = keepalive_timeout
        cls.CLIENT_TIMEOUT = client_timeout
        cls.MAX_CONCURRENT_REQUESTS = max_concurrent_requests

    @classmethod
    def freeze(cls):
        """ Freeze all signals which are not frozen yet """
        # The ON_CALL_* signals are not declared in this class; presumably
        # they are inherited from WSRPCBase.
        for signal in (
            cls.ON_AUTH_SUCCESS,
            cls.ON_AUTH_FAIL,
            cls.ON_CONN_OPEN,
            cls.ON_CONN_CLOSE,
            cls.ON_CONN_FAIL,
            cls.ON_CALL_START,
            cls.ON_CALL_SUCCESS,
            cls.ON_CALL_FAIL,
        ):  # type: Signal
            if not signal.is_frozen:
                signal.freeze()

    def __await__(self):
        """ Awaiting the view runs the whole connection lifecycle """
        return self.__handle_request().__await__()

    async def authorize(self) -> bool:
        """
        Special method for authorize client.
        If this method return True then access allowed,
        otherwise ``403 Forbidden`` will be sent.

        This method will be called before socket connection establishment.

        By default everyone has access. You have to inherit this class
        and change this behaviour.

        .. note::
            You can validate some headers (self.request.headers) or
            check cookies (self.request.cookies).
        """
        return True

    async def __handle_request(self):
        """
        Serve one websocket connection from handshake to close.

        :returns: the ``web.WebSocketResponse`` once the message loop ends
        :raises web.HTTPForbidden: when ``authorize()`` returns a falsy value
        """
        # NOTE(review): the signal fired at each of the five dispatch
        # points below was reconstructed from context - confirm.
        self.socket = web.WebSocketResponse()

        await self.ON_CONN_OPEN.call(socket=self.socket, request=self.request)

        if not await self.authorize():
            await self.ON_AUTH_FAIL.call(
                socket=self.socket, request=self.request,
            )
            raise web.HTTPForbidden()

        await self.ON_AUTH_SUCCESS.call(
            socket=self.socket, request=self.request,
        )

        try:
            await self.socket.prepare(self.request)
        except Exception as err:
            # Handshake failed: notify subscribers, then propagate.
            await self.ON_CONN_FAIL.call(
                socket=self.socket, request=self.request, err=err,
            )
            raise

        try:
            self.clients[self.id] = self
            self._create_task(self._start_ping())

            async for msg in self.socket:
                try:
                    await self._on_message(msg)
                except WebSocketError:
                    log.error(
                        "Client connection %s closed with exception %s",
                        self.id,
                        self.socket.exception(),
                    )
                    break
            else:
                # The loop finished without a ``break``: regular close.
                log.info("Client connection %s closed", self.id)

            return self.socket
        finally:
            await self.ON_CONN_CLOSE.call(
                socket=self.socket, request=self.request,
            )
            await self.close()

    @classmethod
    def broadcast(cls, func, callback=None, return_exceptions=True, **kwargs):
        """
        Call remote function on all connected clients

        :param func: Remote route name
        :param callback: Function which receive responses; it is invoked
                         as ``callback(client, task)`` when a call is done
        :param return_exceptions: Return exceptions of client calls
                                  instead of raise a first one
        :param kwargs: keyword arguments passed to the remote function
        :returns: the ``asyncio.gather`` future over all client calls
        """
        tasks = []

        for client in cls.get_clients().values():
            task = asyncio.ensure_future(client.call(func, **kwargs))

            if callback:
                task.add_done_callback(partial(callback, client))

            tasks.append(task)

        return asyncio.gather(*tasks, return_exceptions=return_exceptions)

    async def _send(self, **kwargs):
        """
        Serialize ``kwargs`` as a JSON message and send it to the client.

        On ``aiohttp.WebSocketError`` the connection close is scheduled
        instead of propagating the error.
        """
        try:
            # ``Lazy`` wrappers postpone stringification until the record
            # is actually formatted.
            log.debug(
                "Sending message to %s serial %s: %s",
                Lazy(lambda: str(self.id)),
                Lazy(lambda: str(kwargs.get("id"))),
                Lazy(lambda: str(kwargs)),
            )
            await self.socket.send_json(kwargs, dumps=dumps)
        except aiohttp.WebSocketError:
            self._create_task(self.close())

    @staticmethod
    def _format_error(e):
        """ Convert an exception to a ``{"type", "message"}`` mapping """
        return {"type": str(type(e).__name__), "message": str(e)}

    def _reject(self, serial, error):
        """
        Fail the pending call future registered under ``serial``.

        :param serial: key of the future in ``self._futures``
        :param error: payload wrapped into ``ClientException``
        """
        future = self._futures.get(serial)
        if future:
            future.set_exception(ClientException(error))

    async def close(self, message=None):
        """
        Cancel all pending tasks and stop this socket connection

        :param message: not used by this implementation
        """
        await self.socket.close()
        await super().close()

        # Unregister the connection; a missing entry is not an error.
        self.clients.pop(self.id, None)

        for obj in self._handlers.values():
            self._loop.create_task(awaitable(obj._onclose)())

    def _log_client_list(self):
        """ Write ``repr`` of every connected client to the debug log """
        log.debug(
            "CLIENTS: %s",
            Lazy(
                lambda: "".join(["\n\t%r" % i for i in self.clients.values()])
            ),
        )

    async def _start_ping(self):
        """
        Keep-alive loop: ping the client every ``KEEPALIVE_PING_TIMEOUT``.

        The loop stops when the socket is closed, the ping is cancelled or
        fails. The connection close is scheduled when the pong does not
        arrive in ``KEEPALIVE_PING_TIMEOUT`` seconds or the round trip
        takes longer than ``CLIENT_TIMEOUT``.
        """
        while True:
            if self.socket.closed:
                return

            future = asyncio.ensure_future(
                self.call("ping", seq=self._loop.time())
            )

            def on_timeout():
                if future.done():
                    return

                # A Task cannot take ``set_exception``, so it is cancelled.
                if isinstance(future, asyncio.Task):
                    future.cancel()
                    return

                future.set_exception(TimeoutError)

            handle = self._loop.call_later(
                self.KEEPALIVE_PING_TIMEOUT, on_timeout
            )

            # Disarm the timeout timer as soon as the ping is resolved.
            future.add_done_callback(lambda f: handle.cancel())

            try:
                resp = await future

                if not resp:
                    continue

                delta = self._loop.time() - resp.get("seq", 0)

                log.debug("%r Pong recieved: %.4f", self, delta)
            except asyncio.CancelledError:
                break
            except (TimeoutError, asyncio.TimeoutError):
                log.info(
                    'Client "%r" connection should be '
                    "closed because ping timeout",
                    self,
                )
                self._loop.create_task(self.close())
                break
            except Exception:
                log.exception("Error when ping remote side.")
                break

            if delta > self.CLIENT_TIMEOUT:
                log.info(
                    'Client "%r" connection should be closed because ping '
                    "response time gather then client timeout",
                    self,
                )
                self._loop.create_task(self.close())
                break

            await asyncio.sleep(self.KEEPALIVE_PING_TIMEOUT)
class WebSocketAsync(WebSocketBase):
    """
    Handler class which executes any route as a coroutine
    """

    async def _executor(self, func):
        """
        Run ``func`` in the event loop and return its result.

        :param func: zero-argument callable; it is wrapped with
                     ``awaitable`` so the call result can be awaited
        """
        return await awaitable(func)()
class WebSocketThreaded(WebSocketBase):
    """
    Handler class which executes any route in the default
    thread-pool of current event loop
    """

    async def _executor(self, func):
        """
        Run ``func`` in the loop's default executor and return its result.

        :param func: zero-argument callable executed outside of the
                     event loop thread
        """
        # ``None`` selects the default executor of the event loop.
        return await self._loop.run_in_executor(None, func)
# Public names of the module; ``ClientException`` is re-exported
# from ``.common``.
__all__ = (
    "ClientException",
    "WebSocketAsync",
    "WebSocketBase",
    "WebSocketThreaded",
)