Skip to content
Snippets Groups Projects
Commit 89067c11 authored by Vincent Texier's avatar Vincent Texier
Browse files

[enh] #58 refactor client.connect_ws() with a new WSConnection class

To simplify the use of the aiohttp web socket connection, use an abstraction layer with our own WSConnection class
parent 5fb4abbe
No related branches found
No related tags found
No related merge requests found
...@@ -18,10 +18,8 @@ ...@@ -18,10 +18,8 @@
# vit # vit
import logging import logging
from aiohttp.client import _WSRequestContextManager
from duniterpy.api.bma.blockchain import BLOCK_SCHEMA from duniterpy.api.bma.blockchain import BLOCK_SCHEMA
from duniterpy.api.client import Client from duniterpy.api.client import Client, WSConnection
logger = logging.getLogger("duniter/ws") logger = logging.getLogger("duniter/ws")
...@@ -42,21 +40,21 @@ WS_PEER_SCHEMA = { ...@@ -42,21 +40,21 @@ WS_PEER_SCHEMA = {
} }
def block(client: Client) -> _WSRequestContextManager: async def block(client: Client) -> WSConnection:
""" """
Connect to block websocket Connect to block websocket
:param client: Client to connect to the api :param client: Client to connect to the api
:return: :return:
""" """
return client.connect_ws(MODULE + "/block") return await client.connect_ws(MODULE + "/block")
def peer(client: Client) -> _WSRequestContextManager: async def peer(client: Client) -> WSConnection:
""" """
Connect to peer websocket Connect to peer websocket
:param client: Client to connect to the api :param client: Client to connect to the api
:return: :return:
""" """
return client.connect_ws(MODULE + "/peer") return await client.connect_ws(MODULE + "/peer")
...@@ -7,7 +7,7 @@ import logging ...@@ -7,7 +7,7 @@ import logging
from typing import Callable, Union, Any, Optional from typing import Callable, Union, Any, Optional
import jsonschema import jsonschema
from aiohttp import ClientResponse, ClientSession from aiohttp import ClientResponse, ClientSession, ClientWebSocketResponse
from aiohttp.client import _WSRequestContextManager from aiohttp.client import _WSRequestContextManager
import duniterpy.api.endpoint as endpoint import duniterpy.api.endpoint as endpoint
from .errors import DuniterError from .errors import DuniterError
...@@ -19,6 +19,9 @@ RESPONSE_JSON = "json" ...@@ -19,6 +19,9 @@ RESPONSE_JSON = "json"
RESPONSE_TEXT = "text" RESPONSE_TEXT = "text"
RESPONSE_AIOHTTP = "aiohttp" RESPONSE_AIOHTTP = "aiohttp"
# Connection type constants
CONNECTION_TYPE_AIOHTTP = 1
# jsonschema validator # jsonschema validator
ERROR_SCHEMA = { ERROR_SCHEMA = {
"type": "object", "type": "object",
...@@ -78,6 +81,88 @@ async def parse_response(response: ClientResponse, schema: dict) -> Any: ...@@ -78,6 +81,88 @@ async def parse_response(response: ClientResponse, schema: dict) -> Any:
raise jsonschema.ValidationError("Could not parse json : {0}".format(str(e))) raise jsonschema.ValidationError("Could not parse json : {0}".format(str(e)))
class WSConnection:
# From the documentation of the aiohttp_library, the web socket connection
#
# await ws_connection = session.ws_connect()
#
# should return a ClientWebSocketResponse object...
#
# https://docs.aiohttp.org/en/stable/client_quickstart.html#websockets
#
# In fact, aiohttp.session.ws_connect() returns a aiohttp.client._WSRequestContextManager instance.
# It must be used in a with statement to get the ClientWebSocketResponse instance from it (__aenter__).
# At the end of the with statement, aiohttp.client._WSRequestContextManager.__aexit__ is called
# and close the ClientWebSocketResponse in it.
#
# await with ws_connection as ws:
# await ws.receive_str()
#
def __init__(self, connection: _WSRequestContextManager) -> None:
"""
Init WSConnection instance
:param connection: Connection instance of the connection library
"""
if not isinstance(connection, _WSRequestContextManager):
raise Exception(
BaseException(
"Only aiohttp.client._WSRequestContextManager class supported"
)
)
self.connection_type = CONNECTION_TYPE_AIOHTTP
self._connection = connection # type: _WSRequestContextManager
self.connection = None # type: Optional[ClientWebSocketResponse]
async def send_str(self, data: str) -> None:
"""
Send a data string to the web socket connection
:param data: Data string
:return:
"""
if self.connection is None:
raise Exception("Connection property is empty")
await self.connection.send_str(data)
return None
async def receive_str(self, timeout: Optional[float] = None) -> Optional[str]:
"""
Wait for a data string from the web socket connection
:param timeout: Timeout in seconds
:return:
"""
if self.connection is None:
raise Exception("Connection property is empty")
return await self.connection.receive_str(timeout=timeout)
async def init_connection(self):
"""
Mandatory for aiohttp library to avoid the use of the with statement
:return:
"""
self.connection = await self._connection.__aenter__()
async def close(self) -> None:
"""
Close the web socket connection
:return:
"""
await self._connection.__aexit__(None, None, None)
if self.connection is None:
raise Exception("Connection property is empty")
await self.connection.close()
class API: class API:
""" """
API is a class used as an abstraction layer over the request library (AIOHTTP). API is a class used as an abstraction layer over the request library (AIOHTTP).
...@@ -170,7 +255,7 @@ class API: ...@@ -170,7 +255,7 @@ class API:
) )
return response return response
def connect_ws(self, path: str) -> _WSRequestContextManager: async def connect_ws(self, path: str) -> WSConnection:
""" """
Connect to a websocket in order to use API parameters Connect to a websocket in order to use API parameters
...@@ -183,10 +268,18 @@ class API: ...@@ -183,10 +268,18 @@ class API:
:return: :return:
""" """
url = self.reverse_url(self.connection_handler.ws_scheme, path) url = self.reverse_url(self.connection_handler.ws_scheme, path)
return self.connection_handler.session.ws_connect(
url, proxy=self.connection_handler.proxy connection = WSConnection(
self.connection_handler.session.ws_connect(
url, proxy=self.connection_handler.proxy, autoclose=False
)
) )
# init aiohttp connection
await connection.init_connection()
return connection
class Client: class Client:
""" """
...@@ -301,7 +394,7 @@ class Client: ...@@ -301,7 +394,7 @@ class Client:
return result return result
def connect_ws(self, path: str = "") -> _WSRequestContextManager: async def connect_ws(self, path: str = "") -> WSConnection:
""" """
Connect to a websocket in order to use API parameters Connect to a websocket in order to use API parameters
...@@ -309,7 +402,7 @@ class Client: ...@@ -309,7 +402,7 @@ class Client:
:return: :return:
""" """
client = API(self.endpoint.conn_handler(self.session, self.proxy)) client = API(self.endpoint.conn_handler(self.session, self.proxy))
return client.connect_ws(path) return await client.connect_ws(path)
async def close(self): async def close(self):
""" """
......
...@@ -50,20 +50,9 @@ async def main(): ...@@ -50,20 +50,9 @@ async def main():
client = Client(WS2P_ENDPOINT) client = Client(WS2P_ENDPOINT)
try: try:
# Create Web Socket connection on block path # Create a Web Socket connection
ws_connection = client.connect_ws() ws = await client.connect_ws()
# From the documentation ws_connection should be a ClientWebSocketResponse object...
#
# https://docs.aiohttp.org/en/stable/client_quickstart.html#websockets
#
# In reality, aiohttp.session.ws_connect() returns a aiohttp.client._WSRequestContextManager instance.
# It must be used in a with statement to get the ClientWebSocketResponse instance from it (__aenter__).
# At the end of the with statement, aiohttp.client._WSRequestContextManager.__aexit__ is called
# and close the ClientWebSocketResponse in it.
# Mandatory to get the "for msg in ws" to work !
async with ws_connection as ws:
print("Connected successfully to web socket endpoint") print("Connected successfully to web socket endpoint")
# START HANDSHAKE ####################################################### # START HANDSHAKE #######################################################
...@@ -72,34 +61,28 @@ async def main(): ...@@ -72,34 +61,28 @@ async def main():
print("Send CONNECT message") print("Send CONNECT message")
await ws.send_str(connect_message) await ws.send_str(connect_message)
loop = True
# Iterate on each message received... # Iterate on each message received...
async for msg in ws: # type: aiohttp.WSMessage while loop:
print("ws.receive_str()")
msg = await ws.receive_str()
# Display incoming message from peer # Display incoming message from peer
print(msg) print(msg)
# If message type is text...
if msg.type == aiohttp.WSMsgType.TEXT:
# print(msg.data)
try: try:
# Validate json string with jsonschema and return a dict # Validate json string with jsonschema and return a dict
data = parse_text( data = parse_text(msg, ws2p.network.WS2P_CONNECT_MESSAGE_SCHEMA)
msg.data, ws2p.network.WS2P_CONNECT_MESSAGE_SCHEMA
)
except jsonschema.exceptions.ValidationError: except jsonschema.exceptions.ValidationError:
try: try:
# Validate json string with jsonschema and return a dict # Validate json string with jsonschema and return a dict
data = parse_text( data = parse_text(msg, ws2p.network.WS2P_ACK_MESSAGE_SCHEMA)
msg.data, ws2p.network.WS2P_ACK_MESSAGE_SCHEMA
)
except jsonschema.exceptions.ValidationError: except jsonschema.exceptions.ValidationError:
try: try:
# Validate json string with jsonschema and return a dict # Validate json string with jsonschema and return a dict
data = parse_text( data = parse_text(msg, ws2p.network.WS2P_OK_MESSAGE_SCHEMA)
msg.data, ws2p.network.WS2P_OK_MESSAGE_SCHEMA
)
except jsonschema.exceptions.ValidationError: except jsonschema.exceptions.ValidationError:
continue continue
...@@ -127,12 +110,7 @@ async def main(): ...@@ -127,12 +110,7 @@ async def main():
print("Received a ACK message") print("Received a ACK message")
# Create ACK document from ACK response to verify signature # Create ACK document from ACK response to verify signature
Ack( Ack(CURRENCY, data["pub"], connect_document.challenge, data["sig"])
CURRENCY,
data["pub"],
connect_document.challenge,
data["sig"],
)
print("Received ACK message signature is valid") print("Received ACK message signature is valid")
# If ACK response is ok, create OK message # If ACK response is ok, create OK message
ok_message = Ok( ok_message = Ok(
...@@ -158,13 +136,6 @@ async def main(): ...@@ -158,13 +136,6 @@ async def main():
print("Send ACK message...") print("Send ACK message...")
await ws.send_str(ack_message) await ws.send_str(ack_message)
elif msg.type == aiohttp.WSMsgType.CLOSED:
# Connection is closed
print("Web socket connection closed !")
elif msg.type == aiohttp.WSMsgType.ERROR:
# Connection error
print("Web socket connection error !")
# Send ws2p request # Send ws2p request
print("Send getCurrent() request") print("Send getCurrent() request")
request_id = get_ws2p_challenge()[:8] request_id = get_ws2p_challenge()[:8]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment