diff --git a/src/sakia/tests/unit/tools/test_decorators.py b/src/sakia/tests/unit/tools/test_decorators.py index 60a5eb33c64ddce10f015a2d7ac6ef6ecb169d30..89b373908201fd87f0199e31cec4d038af28fb8a 100644 --- a/src/sakia/tests/unit/tools/test_decorators.py +++ b/src/sakia/tests/unit/tools/test_decorators.py @@ -112,3 +112,44 @@ class TestDecorators(unittest.TestCase, QuamashTest): self.assertEqual(calls["B"], 1) self.assertEqual(calls["C"], 0) self.assertEqual(calls["D"], 1) + + def test_two_runners(self): + class TaskRunner: + def __init__(self, name): + self.some_long_task(name, incrementer) + + @classmethod + def create(cls, name): + return cls(name) + + @once_at_a_time + @asyncify + async def some_long_task(self, name, callback): + await asyncio.sleep(1) + callback(name) + await asyncio.sleep(1) + callback(name) + + def cancel_long_task(self): + cancel_once_task(self, self.some_long_task) + + calls = {'A': 0, 'B': 0, 'C': 0} + + def incrementer(name): + nonlocal calls + calls[name] += 1 + + async def exec_test(): + tr1 = TaskRunner.create("A") + tr2 = TaskRunner.create("B") + tr3 = TaskRunner.create("C") + await asyncio.sleep(1.5) + tr1.some_long_task("A", incrementer) + tr2.some_long_task("B", incrementer) + tr3.some_long_task("C", incrementer) + await asyncio.sleep(1.5) + + self.lp.run_until_complete(exec_test()) + self.assertEqual(calls["A"], 2) + self.assertEqual(calls["B"], 2) + self.assertEqual(calls["C"], 2) diff --git a/src/sakia/tools/decorators.py b/src/sakia/tools/decorators.py index 7648502ebeb53d64180a81e1c5b04e40aa330538..f9dc5f2103c04299d0d010141f8ab759f13d8c5e 100644 --- a/src/sakia/tools/decorators.py +++ b/src/sakia/tools/decorators.py @@ -19,24 +19,24 @@ def once_at_a_time(fn): func_call = args[0].__tasks[fn.__name__] args[0].__tasks.pop(fn.__name__) if getattr(func_call, "_next_task", None): - func_call._next_task._start() + start_task(func_call._next_task[0], + *func_call._next_task[1], + **func_call._next_task[2]) except KeyError: logging.debug("Task {0} already removed".format(fn.__name__)) - def start_task(): - args[0].__tasks[fn.__name__] = fn(*args, **kwargs) - args[0].__tasks[fn.__name__].add_done_callback(task_done) + def start_task(f, *a, **k): + args[0].__tasks[f.__name__] = f(*a, **k) + args[0].__tasks[f.__name__].add_done_callback(task_done) if getattr(args[0], "__tasks", None) is None: setattr(args[0], "__tasks", {}) - fn._start = lambda: start_task() - if fn.__name__ in args[0].__tasks: - args[0].__tasks[fn.__name__]._next_task = fn + args[0].__tasks[fn.__name__]._next_task = (fn, args, kwargs) args[0].__tasks[fn.__name__].cancel() else: - fn._start() + start_task(fn, *args, **kwargs) return args[0].__tasks[fn.__name__] return wrapper