diff --git a/duniterpy/api/bma/ws.py b/duniterpy/api/bma/ws.py index 5f2e40b9edddcee772131e228777813acaf65eb3..e4e1d021aff6e2a41661a2c0f0e8dcad8fca4671 100644 --- a/duniterpy/api/bma/ws.py +++ b/duniterpy/api/bma/ws.py @@ -18,10 +18,8 @@ # vit import logging -from aiohttp.client import _WSRequestContextManager - from duniterpy.api.bma.blockchain import BLOCK_SCHEMA -from duniterpy.api.client import Client +from duniterpy.api.client import Client, WSConnection logger = logging.getLogger("duniter/ws") @@ -42,21 +40,21 @@ WS_PEER_SCHEMA = { } -def block(client: Client) -> _WSRequestContextManager: +async def block(client: Client) -> WSConnection: """ Connect to block websocket :param client: Client to connect to the api :return: """ - return client.connect_ws(MODULE + "/block") + return await client.connect_ws(MODULE + "/block") -def peer(client: Client) -> _WSRequestContextManager: +async def peer(client: Client) -> WSConnection: """ Connect to peer websocket :param client: Client to connect to the api :return: """ - return client.connect_ws(MODULE + "/peer") + return await client.connect_ws(MODULE + "/peer") diff --git a/duniterpy/api/client.py b/duniterpy/api/client.py index 6dc1ce688708836c7ad582e52e9514a5b1897e80..4f2df1225dd4cae53d8dc40d4c3f8e7b463a7090 100644 --- a/duniterpy/api/client.py +++ b/duniterpy/api/client.py @@ -7,7 +7,7 @@ import logging from typing import Callable, Union, Any, Optional import jsonschema -from aiohttp import ClientResponse, ClientSession +from aiohttp import ClientResponse, ClientSession, ClientWebSocketResponse from aiohttp.client import _WSRequestContextManager import duniterpy.api.endpoint as endpoint from .errors import DuniterError @@ -19,6 +19,9 @@ RESPONSE_JSON = "json" RESPONSE_TEXT = "text" RESPONSE_AIOHTTP = "aiohttp" +# Connection type constants +CONNECTION_TYPE_AIOHTTP = 1 + # jsonschema validator ERROR_SCHEMA = { "type": "object", @@ -78,6 +81,88 @@ async def parse_response(response: ClientResponse, schema: dict) -> Any: raise jsonschema.ValidationError("Could not parse json : {0}".format(str(e))) +class WSConnection: + + # From the documentation of the aiohttp_library, the web socket connection + # + # await ws_connection = session.ws_connect() + # + # should return a ClientWebSocketResponse object... + # + # https://docs.aiohttp.org/en/stable/client_quickstart.html#websockets + # + # In fact, aiohttp.session.ws_connect() returns a aiohttp.client._WSRequestContextManager instance. + # It must be used in a with statement to get the ClientWebSocketResponse instance from it (__aenter__). + # At the end of the with statement, aiohttp.client._WSRequestContextManager.__aexit__ is called + # and close the ClientWebSocketResponse in it. + # + # await with ws_connection as ws: + # await ws.receive_str() + # + def __init__(self, connection: _WSRequestContextManager) -> None: + """ + Init WSConnection instance + + :param connection: Connection instance of the connection library + """ + if not isinstance(connection, _WSRequestContextManager): + raise Exception( + BaseException( + "Only aiohttp.client._WSRequestContextManager class supported" + ) + ) + + self.connection_type = CONNECTION_TYPE_AIOHTTP + self._connection = connection # type: _WSRequestContextManager + self.connection = None # type: Optional[ClientWebSocketResponse] + + async def send_str(self, data: str) -> None: + """ + Send a data string to the web socket connection + + :param data: Data string + :return: + """ + if self.connection is None: + raise Exception("Connection property is empty") + + await self.connection.send_str(data) + return None + + async def receive_str(self, timeout: Optional[float] = None) -> Optional[str]: + """ + Wait for a data string from the web socket connection + + :param timeout: Timeout in seconds + :return: + """ + if self.connection is None: + raise Exception("Connection property is empty") + + return await self.connection.receive_str(timeout=timeout) + + async def init_connection(self): + """ + Mandatory for aiohttp library to avoid the use of the with statement + + :return: + """ + self.connection = await self._connection.__aenter__() + + async def close(self) -> None: + """ + Close the web socket connection + + :return: + """ + await self._connection.__aexit__(None, None, None) + + if self.connection is None: + raise Exception("Connection property is empty") + + await self.connection.close() + + class API: """ API is a class used as an abstraction layer over the request library (AIOHTTP). @@ -170,7 +255,7 @@ class API: ) return response - def connect_ws(self, path: str) -> _WSRequestContextManager: + async def connect_ws(self, path: str) -> WSConnection: """ Connect to a websocket in order to use API parameters @@ -183,10 +268,18 @@ class API: :return: """ url = self.reverse_url(self.connection_handler.ws_scheme, path) - return self.connection_handler.session.ws_connect( - url, proxy=self.connection_handler.proxy + + connection = WSConnection( + self.connection_handler.session.ws_connect( + url, proxy=self.connection_handler.proxy, autoclose=False + ) ) + # init aiohttp connection + await connection.init_connection() + + return connection + class Client: """ @@ -301,7 +394,7 @@ class Client: return result - def connect_ws(self, path: str = "") -> _WSRequestContextManager: + async def connect_ws(self, path: str = "") -> WSConnection: """ Connect to a websocket in order to use API parameters @@ -309,7 +402,7 @@ class Client: :return: """ client = API(self.endpoint.conn_handler(self.session, self.proxy)) - return client.connect_ws(path) + return await client.connect_ws(path) async def close(self): """ diff --git a/examples/request_ws2p.py b/examples/request_ws2p.py index b77ef48c7d34c6e2eb49214ee52de1692c82fab9..00a83305e3c05d9807b44769d46c53ad8b7bb22c 100644 --- a/examples/request_ws2p.py +++ b/examples/request_ws2p.py @@ -50,238 +50,209 @@ async def main(): client = Client(WS2P_ENDPOINT) try: - # Create Web Socket connection on block path - ws_connection = client.connect_ws() + # Create a Web Socket connection + ws = await client.connect_ws() - # From the documentation ws_connection should be a ClientWebSocketResponse object... - # - # https://docs.aiohttp.org/en/stable/client_quickstart.html#websockets - # - # In reality, aiohttp.session.ws_connect() returns a aiohttp.client._WSRequestContextManager instance. - # It must be used in a with statement to get the ClientWebSocketResponse instance from it (__aenter__). - # At the end of the with statement, aiohttp.client._WSRequestContextManager.__aexit__ is called - # and close the ClientWebSocketResponse in it. + print("Connected successfully to web socket endpoint") - # Mandatory to get the "for msg in ws" to work ! - async with ws_connection as ws: - print("Connected successfully to web socket endpoint") + # START HANDSHAKE ####################################################### + print("\nSTART HANDSHAKE...") - # START HANDSHAKE ####################################################### - print("\nSTART HANDSHAKE...") + print("Send CONNECT message") + await ws.send_str(connect_message) - print("Send CONNECT message") - await ws.send_str(connect_message) + loop = True + # Iterate on each message received... + while loop: + print("ws.receive_str()") + msg = await ws.receive_str() - # Iterate on each message received... - async for msg in ws: # type: aiohttp.WSMessage + # Display incoming message from peer + print(msg) - # Display incoming message from peer - print(msg) + try: + # Validate json string with jsonschema and return a dict + data = parse_text(msg, ws2p.network.WS2P_CONNECT_MESSAGE_SCHEMA) + + except jsonschema.exceptions.ValidationError: + try: + # Validate json string with jsonschema and return a dict + data = parse_text(msg, ws2p.network.WS2P_ACK_MESSAGE_SCHEMA) - # If message type is text... - if msg.type == aiohttp.WSMsgType.TEXT: - # print(msg.data) + except jsonschema.exceptions.ValidationError: try: # Validate json string with jsonschema and return a dict - data = parse_text( - msg.data, ws2p.network.WS2P_CONNECT_MESSAGE_SCHEMA - ) + data = parse_text(msg, ws2p.network.WS2P_OK_MESSAGE_SCHEMA) except jsonschema.exceptions.ValidationError: - try: - # Validate json string with jsonschema and return a dict - data = parse_text( - msg.data, ws2p.network.WS2P_ACK_MESSAGE_SCHEMA - ) - - except jsonschema.exceptions.ValidationError: - try: - # Validate json string with jsonschema and return a dict - data = parse_text( - msg.data, ws2p.network.WS2P_OK_MESSAGE_SCHEMA - ) - - except jsonschema.exceptions.ValidationError: - continue - - print("Received a OK message") - - Ok( - CURRENCY, - remote_connect_document.pubkey, - connect_document.challenge, - data["sig"], - ) - print("Received OK message signature is valid") - - # END HANDSHAKE ####################################################### - print("END OF HANDSHAKE\n") - - # Uncomment the following command to stop listening for messages anymore - break - - # Uncomment the following commands to continue to listen incoming messages - # print("waiting for incoming messages...\n") - # continue - - print("Received a ACK message") - - # Create ACK document from ACK response to verify signature - Ack( - CURRENCY, - data["pub"], - connect_document.challenge, - data["sig"], - ) - print("Received ACK message signature is valid") - # If ACK response is ok, create OK message - ok_message = Ok( - CURRENCY, signing_key.pubkey, connect_document.challenge - ).get_signed_json(signing_key) - - # Send OK message - print("Send OK message...") - await ws.send_str(ok_message) continue - print("Received a CONNECT message") + print("Received a OK message") - remote_connect_document = Connect( - CURRENCY, data["pub"], data["challenge"], data["sig"] + Ok( + CURRENCY, + remote_connect_document.pubkey, + connect_document.challenge, + data["sig"], ) - print("Received CONNECT message signature is valid") - - ack_message = Ack( - CURRENCY, signing_key.pubkey, remote_connect_document.challenge - ).get_signed_json(signing_key) - # Send ACK message - print("Send ACK message...") - await ws.send_str(ack_message) - - elif msg.type == aiohttp.WSMsgType.CLOSED: - # Connection is closed - print("Web socket connection closed !") - elif msg.type == aiohttp.WSMsgType.ERROR: - # Connection error - print("Web socket connection error !") - - # Send ws2p request - print("Send getCurrent() request") - request_id = get_ws2p_challenge()[:8] - await ws.send_str(requests.get_current(request_id)) - - # Wait response with request id + print("Received OK message signature is valid") + + # END HANDSHAKE ####################################################### + print("END OF HANDSHAKE\n") + + # Uncomment the following command to stop listening for messages anymore + break + + # Uncomment the following commands to continue to listen incoming messages + # print("waiting for incoming messages...\n") + # continue + + print("Received a ACK message") + + # Create ACK document from ACK response to verify signature + Ack(CURRENCY, data["pub"], connect_document.challenge, data["sig"]) + print("Received ACK message signature is valid") + # If ACK response is ok, create OK message + ok_message = Ok( + CURRENCY, signing_key.pubkey, connect_document.challenge + ).get_signed_json(signing_key) + + # Send OK message + print("Send OK message...") + await ws.send_str(ok_message) + continue + + print("Received a CONNECT message") + + remote_connect_document = Connect( + CURRENCY, data["pub"], data["challenge"], data["sig"] + ) + print("Received CONNECT message signature is valid") + + ack_message = Ack( + CURRENCY, signing_key.pubkey, remote_connect_document.challenge + ).get_signed_json(signing_key) + # Send ACK message + print("Send ACK message...") + await ws.send_str(ack_message) + + # Send ws2p request + print("Send getCurrent() request") + request_id = get_ws2p_challenge()[:8] + await ws.send_str(requests.get_current(request_id)) + + # Wait response with request id + response_str = await ws.receive_str() + while "resId" not in json.loads(response_str) or ( + "resId" in json.loads(response_str) + and json.loads(response_str)["resId"] != request_id + ): response_str = await ws.receive_str() - while "resId" not in json.loads(response_str) or ( - "resId" in json.loads(response_str) - and json.loads(response_str)["resId"] != request_id - ): - response_str = await ws.receive_str() - time.sleep(1) + time.sleep(1) + try: + # Check response format + parse_text(response_str, requests.BLOCK_RESPONSE_SCHEMA) + # If valid display response + print("Response: " + response_str) + except ValidationError: + # If invalid response... try: - # Check response format - parse_text(response_str, requests.BLOCK_RESPONSE_SCHEMA) - # If valid display response - print("Response: " + response_str) - except ValidationError: - # If invalid response... - try: - # Check error response format - parse_text(response_str, requests.ERROR_RESPONSE_SCHEMA) - # If valid, display error response - print("Error response: " + response_str) - except ValidationError as exception: - # If invalid, display exception on response validation - print(exception) - - # Send ws2p request - print("Send getBlock(360000) request") - request_id = get_ws2p_challenge()[:8] - await ws.send_str(requests.get_block(request_id, 360000)) - - # Wait response with request id + # Check error response format + parse_text(response_str, requests.ERROR_RESPONSE_SCHEMA) + # If valid, display error response + print("Error response: " + response_str) + except ValidationError as exception: + # If invalid, display exception on response validation + print(exception) + + # Send ws2p request + print("Send getBlock(360000) request") + request_id = get_ws2p_challenge()[:8] + await ws.send_str(requests.get_block(request_id, 360000)) + + # Wait response with request id + response_str = await ws.receive_str() + while "resId" not in json.loads(response_str) or ( + "resId" in json.loads(response_str) + and json.loads(response_str)["resId"] != request_id + ): response_str = await ws.receive_str() - while "resId" not in json.loads(response_str) or ( - "resId" in json.loads(response_str) - and json.loads(response_str)["resId"] != request_id - ): - response_str = await ws.receive_str() - time.sleep(1) + time.sleep(1) + try: + # Check response format + parse_text(response_str, requests.BLOCK_RESPONSE_SCHEMA) + # If valid display response + print("Response: " + response_str) + except ValidationError: + # If invalid response... try: - # Check response format - parse_text(response_str, requests.BLOCK_RESPONSE_SCHEMA) - # If valid display response - print("Response: " + response_str) - except ValidationError: - # If invalid response... - try: - # Check error response format - parse_text(response_str, requests.ERROR_RESPONSE_SCHEMA) - # If valid, display error response - print("Error response: " + response_str) - except ValidationError as exception: - # If invalid, display exception on response validation - print(exception) - - # Send ws2p request - print("Send getBlocks(360000, 2) request") - request_id = get_ws2p_challenge()[:8] - await ws.send_str(requests.get_blocks(request_id, 360000, 2)) - - # Wait response with request id + # Check error response format + parse_text(response_str, requests.ERROR_RESPONSE_SCHEMA) + # If valid, display error response + print("Error response: " + response_str) + except ValidationError as exception: + # If invalid, display exception on response validation + print(exception) + + # Send ws2p request + print("Send getBlocks(360000, 2) request") + request_id = get_ws2p_challenge()[:8] + await ws.send_str(requests.get_blocks(request_id, 360000, 2)) + + # Wait response with request id + response_str = await ws.receive_str() + while "resId" not in json.loads(response_str) or ( + "resId" in json.loads(response_str) + and json.loads(response_str)["resId"] != request_id + ): response_str = await ws.receive_str() - while "resId" not in json.loads(response_str) or ( - "resId" in json.loads(response_str) - and json.loads(response_str)["resId"] != request_id - ): - response_str = await ws.receive_str() - time.sleep(1) + time.sleep(1) + try: + # Check response format + parse_text(response_str, requests.BLOCKS_RESPONSE_SCHEMA) + # If valid display response + print("Response: " + response_str) + except ValidationError: + # If invalid response... try: - # Check response format - parse_text(response_str, requests.BLOCKS_RESPONSE_SCHEMA) - # If valid display response - print("Response: " + response_str) - except ValidationError: - # If invalid response... - try: - # Check error response format - parse_text(response_str, requests.ERROR_RESPONSE_SCHEMA) - # If valid, display error response - print("Error response: " + response_str) - except ValidationError as exception: - # If invalid, display exception on response validation - print(exception) - - # Send ws2p request - print("Send getRequirementsPending(3) request") - request_id = get_ws2p_challenge()[:8] - await ws.send_str(requests.get_requirements_pending(request_id, 3)) - # Wait response with request id + # Check error response format + parse_text(response_str, requests.ERROR_RESPONSE_SCHEMA) + # If valid, display error response + print("Error response: " + response_str) + except ValidationError as exception: + # If invalid, display exception on response validation + print(exception) + + # Send ws2p request + print("Send getRequirementsPending(3) request") + request_id = get_ws2p_challenge()[:8] + await ws.send_str(requests.get_requirements_pending(request_id, 3)) + # Wait response with request id + response_str = await ws.receive_str() + while "resId" not in json.loads(response_str) or ( + "resId" in json.loads(response_str) + and json.loads(response_str)["resId"] != request_id + ): response_str = await ws.receive_str() - while "resId" not in json.loads(response_str) or ( - "resId" in json.loads(response_str) - and json.loads(response_str)["resId"] != request_id - ): - response_str = await ws.receive_str() - time.sleep(1) + time.sleep(1) + try: + # Check response format + parse_text(response_str, requests.REQUIREMENTS_RESPONSE_SCHEMA) + # If valid display response + print("Response: " + response_str) + except ValidationError: + # If invalid response... try: - # Check response format - parse_text(response_str, requests.REQUIREMENTS_RESPONSE_SCHEMA) - # If valid display response - print("Response: " + response_str) - except ValidationError: - # If invalid response... - try: - # Check error response format - parse_text(response_str, requests.ERROR_RESPONSE_SCHEMA) - # If valid, display error response - print("Error response: " + response_str) - except ValidationError as exception: - # If invalid, display exception on response validation - print(exception) - - # Close session - await client.close() + # Check error response format + parse_text(response_str, requests.ERROR_RESPONSE_SCHEMA) + # If valid, display error response + print("Error response: " + response_str) + except ValidationError as exception: + # If invalid, display exception on response validation + print(exception) + + # Close session + await client.close() except (aiohttp.WSServerHandshakeError, ValueError) as e: print("Websocket handshake {0} : {1}".format(type(e).__name__, str(e)))