From a4c6d7da871a86585554b2820fae72eb131704fe Mon Sep 17 00:00:00 2001
From: Moul <moul@moul.re>
Date: Mon, 5 Jun 2023 22:59:53 +0200
Subject: [PATCH] Allow to pass any context object to define_click_context()

network: Assign if key exists
auth: use helpers.define_click_context()
---
 silkaj/auth.py             | 14 +++++++-------
 silkaj/network.py          |  4 ++--
 tests/helpers.py           |  6 +++---
 tests/unit/test_auth.py    | 23 +++++++++--------------
 tests/unit/test_network.py |  2 +-
 5 files changed, 22 insertions(+), 27 deletions(-)

diff --git a/silkaj/auth.py b/silkaj/auth.py
index fa669224..f488274d 100644
--- a/silkaj/auth.py
+++ b/silkaj/auth.py
@@ -32,11 +32,11 @@ PUBSEC_SIGNKEY_PATTERN = "sec: ([1-9A-HJ-NP-Za-km-z]{87,90})"
 
 @pass_context
 def auth_method(ctx: Context) -> SigningKey:
-    if ctx.obj["AUTH_SEED"]:
+    if "AUTH_SEED" in ctx.obj and ctx.obj["AUTH_SEED"]:
         return auth_by_seed()
-    if ctx.obj["AUTH_FILE"]:
+    if "AUTH_FILE" in ctx.obj and ctx.obj["AUTH_FILE"]:
         return auth_by_auth_file()
-    if ctx.obj["AUTH_WIF"]:
+    if "AUTH_WIF" in ctx.obj and ctx.obj["AUTH_WIF"]:
         return auth_by_wif()
     return auth_by_scrypt()
 
@@ -44,10 +44,10 @@ def auth_method(ctx: Context) -> SigningKey:
 @pass_context
 def has_auth_method(ctx: Context) -> bool:
     return (
-        ctx.obj["AUTH_SCRYPT"]
-        or ctx.obj["AUTH_FILE"]
-        or ctx.obj["AUTH_SEED"]
-        or ctx.obj["AUTH_WIF"]
+        ("AUTH_SCRYPT" in ctx.obj and ctx.obj["AUTH_SCRYPT"])
+        or ("AUTH_FILE" in ctx.obj and ctx.obj["AUTH_FILE"])
+        or ("AUTH_SEED" in ctx.obj and ctx.obj["AUTH_SEED"])
+        or ("AUTH_WIF" in ctx.obj and ctx.obj["AUTH_WIF"])
     )
 
 
diff --git a/silkaj/network.py b/silkaj/network.py
index 83425022..93d9c6c4 100644
--- a/silkaj/network.py
+++ b/silkaj/network.py
@@ -44,8 +44,8 @@ def determine_endpoint() -> ep.Endpoint:
         from click.globals import get_current_context
 
         ctx = get_current_context()
-        endpoint = ctx.obj["ENDPOINT"]
-        gtest = ctx.obj["GTEST"]
+        endpoint = ctx.obj["ENDPOINT"] if "ENDPOINT" in ctx.obj else None
+        gtest = ctx.obj["GTEST"] if "GTEST" in ctx.obj else None
     except (ModuleNotFoundError, RuntimeError):
         endpoint, gtest = None, None
 
diff --git a/tests/helpers.py b/tests/helpers.py
index 019ced8e..8558daa0 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -18,9 +18,9 @@ import click
 from silkaj import cli
 
 
-def define_click_context(endpoint=None, gtest=False):
+def define_click_context(**kwargs):
     ctx = click.Context(cli.cli)
     ctx.obj = {}
-    ctx.obj["ENDPOINT"] = endpoint
-    ctx.obj["GTEST"] = gtest
+    for kwarg in kwargs.items():
+        ctx.obj[kwarg[0].upper()] = kwarg[1]
     click.globals.push_context(ctx)
diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py
index 74b1e532..3b5b6008 100644
--- a/tests/unit/test_auth.py
+++ b/tests/unit/test_auth.py
@@ -13,10 +13,10 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with Silkaj. If not, see <https://www.gnu.org/licenses/>.
 
-import click
 import pytest
 
 from silkaj import auth
+from tests import helpers
 from tests.patched.auth import (
     patched_auth_by_auth_file,
     patched_auth_by_scrypt,
@@ -25,24 +25,19 @@ from tests.patched.auth import (
 )
 
 
-# test auth_method
 @pytest.mark.parametrize(
-    ("seed", "file", "wif", "auth_method_called"),
+    ("context", "auth_method_called"),
     [
-        (True, False, False, "call_auth_by_seed"),
-        (False, True, False, "call_auth_by_auth_file"),
-        (False, False, True, "call_auth_by_wif"),
-        (False, False, False, "call_auth_by_scrypt"),
+        ({"auth_seed": True}, "call_auth_by_seed"),
+        ({"auth_file_path": True}, "call_auth_by_auth_file"),
+        ({"auth_wif": True}, "call_auth_by_wif"),
+        ({}, "call_auth_by_scrypt"),
     ],
 )
-def test_auth_method(seed, file, wif, auth_method_called, monkeypatch):
+def test_auth_method(context, auth_method_called, monkeypatch):
     monkeypatch.setattr(auth, "auth_by_seed", patched_auth_by_seed)
     monkeypatch.setattr(auth, "auth_by_wif", patched_auth_by_wif)
     monkeypatch.setattr(auth, "auth_by_auth_file", patched_auth_by_auth_file)
     monkeypatch.setattr(auth, "auth_by_scrypt", patched_auth_by_scrypt)
-    ctx = click.Context(
-        click.Command(""),
-        obj={"AUTH_SEED": seed, "AUTH_FILE": file, "AUTH_WIF": wif},
-    )
-    with ctx:
-        assert auth_method_called == auth.auth_method()
+    helpers.define_click_context(**context)
+    assert auth_method_called == auth.auth_method()
diff --git a/tests/unit/test_network.py b/tests/unit/test_network.py
index 531ff3f6..22b919a5 100644
--- a/tests/unit/test_network.py
+++ b/tests/unit/test_network.py
@@ -45,7 +45,7 @@ IPV6 = "2001:0db8:85a3:0000:0000:8a2e:0370:7334"
     ],
 )
 def test_determine_endpoint_custom(endpoint, host, ipv4, ipv6, port, path):
-    helpers.define_click_context(endpoint)
+    helpers.define_click_context(endpoint=endpoint)
     ep = network.determine_endpoint()
     assert ep.host == host
     assert ep.ipv4 == ipv4
-- 
GitLab