From a4c6d7da871a86585554b2820fae72eb131704fe Mon Sep 17 00:00:00 2001 From: Moul <moul@moul.re> Date: Mon, 5 Jun 2023 22:59:53 +0200 Subject: [PATCH] Allow to pass any context object to define_click_context() network: Assign if key exists auth: use helpers.define_click_context() --- silkaj/auth.py | 14 +++++++------- silkaj/network.py | 4 ++-- tests/helpers.py | 6 +++--- tests/unit/test_auth.py | 23 +++++++++-------------- tests/unit/test_network.py | 2 +- 5 files changed, 22 insertions(+), 27 deletions(-) diff --git a/silkaj/auth.py b/silkaj/auth.py index fa669224..f488274d 100644 --- a/silkaj/auth.py +++ b/silkaj/auth.py @@ -32,11 +32,11 @@ PUBSEC_SIGNKEY_PATTERN = "sec: ([1-9A-HJ-NP-Za-km-z]{87,90})" @pass_context def auth_method(ctx: Context) -> SigningKey: - if ctx.obj["AUTH_SEED"]: + if "AUTH_SEED" in ctx.obj and ctx.obj["AUTH_SEED"]: return auth_by_seed() - if ctx.obj["AUTH_FILE"]: + if "AUTH_FILE" in ctx.obj and ctx.obj["AUTH_FILE"]: return auth_by_auth_file() - if ctx.obj["AUTH_WIF"]: + if "AUTH_WIF" in ctx.obj and ctx.obj["AUTH_WIF"]: return auth_by_wif() return auth_by_scrypt() @@ -44,10 +44,10 @@ def auth_method(ctx: Context) -> SigningKey: @pass_context def has_auth_method(ctx: Context) -> bool: return ( - ctx.obj["AUTH_SCRYPT"] - or ctx.obj["AUTH_FILE"] - or ctx.obj["AUTH_SEED"] - or ctx.obj["AUTH_WIF"] + ("AUTH_SCRYPT" in ctx.obj and ctx.obj["AUTH_SCRYPT"]) + or ("AUTH_FILE" in ctx.obj and ctx.obj["AUTH_FILE"]) + or ("AUTH_SEED" in ctx.obj and ctx.obj["AUTH_SEED"]) + or ("AUTH_WIF" in ctx.obj and ctx.obj["AUTH_WIF"]) ) diff --git a/silkaj/network.py b/silkaj/network.py index 83425022..93d9c6c4 100644 --- a/silkaj/network.py +++ b/silkaj/network.py @@ -44,8 +44,8 @@ def determine_endpoint() -> ep.Endpoint: from click.globals import get_current_context ctx = get_current_context() - endpoint = ctx.obj["ENDPOINT"] - gtest = ctx.obj["GTEST"] + endpoint = ctx.obj["ENDPOINT"] if "ENDPOINT" in ctx.obj else None + gtest = ctx.obj["GTEST"] if "GTEST" in ctx.obj else None except (ModuleNotFoundError, RuntimeError): endpoint, gtest = None, None diff --git a/tests/helpers.py b/tests/helpers.py index 019ced8e..8558daa0 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -18,9 +18,9 @@ import click from silkaj import cli -def define_click_context(endpoint=None, gtest=False): +def define_click_context(**kwargs): ctx = click.Context(cli.cli) ctx.obj = {} - ctx.obj["ENDPOINT"] = endpoint - ctx.obj["GTEST"] = gtest + for kwarg in kwargs.items(): + ctx.obj[kwarg[0].upper()] = kwarg[1] click.globals.push_context(ctx) diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index 74b1e532..3b5b6008 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -13,10 +13,10 @@ # You should have received a copy of the GNU Affero General Public License # along with Silkaj. If not, see <https://www.gnu.org/licenses/>. -import click import pytest from silkaj import auth +from tests import helpers from tests.patched.auth import ( patched_auth_by_auth_file, patched_auth_by_scrypt, @@ -25,24 +25,19 @@ from tests.patched.auth import ( ) -# test auth_method @pytest.mark.parametrize( - ("seed", "file", "wif", "auth_method_called"), + ("context", "auth_method_called"), [ - (True, False, False, "call_auth_by_seed"), - (False, True, False, "call_auth_by_auth_file"), - (False, False, True, "call_auth_by_wif"), - (False, False, False, "call_auth_by_scrypt"), + ({"auth_seed": True}, "call_auth_by_seed"), + ({"auth_file_path": True}, "call_auth_by_auth_file"), + ({"auth_wif": True}, "call_auth_by_wif"), + ({}, "call_auth_by_scrypt"), ], ) -def test_auth_method(seed, file, wif, auth_method_called, monkeypatch): +def test_auth_method(context, auth_method_called, monkeypatch): monkeypatch.setattr(auth, "auth_by_seed", patched_auth_by_seed) monkeypatch.setattr(auth, "auth_by_wif", patched_auth_by_wif) monkeypatch.setattr(auth, "auth_by_auth_file", patched_auth_by_auth_file) monkeypatch.setattr(auth, "auth_by_scrypt", patched_auth_by_scrypt) - ctx = click.Context( - click.Command(""), - obj={"AUTH_SEED": seed, "AUTH_FILE": file, "AUTH_WIF": wif}, - ) - with ctx: - assert auth_method_called == auth.auth_method() + helpers.define_click_context(**context) + assert auth_method_called == auth.auth_method() diff --git a/tests/unit/test_network.py b/tests/unit/test_network.py index 531ff3f6..22b919a5 100644 --- a/tests/unit/test_network.py +++ b/tests/unit/test_network.py @@ -45,7 +45,7 @@ IPV6 = "2001:0db8:85a3:0000:0000:8a2e:0370:7334" ], ) def test_determine_endpoint_custom(endpoint, host, ipv4, ipv6, port, path): - helpers.define_click_context(endpoint) + helpers.define_click_context(endpoint=endpoint) ep = network.determine_endpoint() assert ep.host == host assert ep.ipv4 == ipv4 -- GitLab