Skip to content

Commit d37d8bc

Browse files
committed
base: Improve @shc.handler() to allow defining handler functions with less parameters
This commit includes tests and documentation for this new feature. We also needed to fix our AsyncMock implementation to work well with inspect.signature() in the tests. Fixes mhthies#63
1 parent 7c338b2 commit d37d8bc

File tree

4 files changed

+112
-6
lines changed

4 files changed

+112
-6
lines changed

docs/base.rst

+22-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ Putting it all together, a logic handler may look as follows::
181181
@timer.trigger
182182
@some_variable.trigger
183183
@shc.handler()
184-
async def my_logics(_value, origin):
184+
async def my_logics(_value, _origin):
185185
""" Write value of `some_variable` to KNX bus every 5 minutes & when it changes, but only for values > 42 """
186186
# We cannot use the value provided, since it is not defined when triggered by the timer
187187
value = await some_variable.read()
@@ -231,6 +231,27 @@ Putting it all together, a logic handler may look as follows::
231231
# Unfortunately, no .write() or .read() possible here.
232232

233233

234+
.. tip::
235+
236+
The :func:`shc.handler` and :func:`shc.blocking_handler` decorators take care of calling the logic handler function with the correct number of arguments:
237+
If you don't need the ``origin`` list, you can simply omit the second parameter of your wrapped logic handler function::
238+
239+
@shc.handler()
240+
async def my_value_only_handler(value):
241+
await some_variable.write(value + 3)
242+
243+
If you don't need the ``value`` either, you can also omit this parameter.
244+
Hence, the logic handler from the first example above can be rewritten as::
245+
246+
@timer.trigger
247+
@some_variable.trigger
248+
@shc.handler()
249+
async def my_logics():
250+
value = await some_variable.read()
251+
if value > 42:
252+
await some_knx_object.write(value)
253+
254+
234255
``shc.base`` Module Reference
235256
-----------------------------
236257

shc/base.py

+36-5
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import asyncio
1414
import contextvars
1515
import functools
16+
import inspect
1617
import logging
1718
import random
1819
from typing import Generic, List, Any, Tuple, Callable, Optional, Type, TypeVar, Awaitable, Union, Dict, Set
@@ -433,7 +434,12 @@ async def _from_provider(self) -> Optional[T_con]:
433434
return convert(val) if convert else val
434435

435436

436-
def handler(reset_origin=False, allow_recursion=False) -> Callable[[LogicHandler], LogicHandler]:
437+
LogicHandlerOptionalParams = Union[LogicHandler,
438+
Callable[[T], Awaitable[None]],
439+
Callable[[], Awaitable[None]]]
440+
441+
442+
def handler(reset_origin=False, allow_recursion=False) -> Callable[[LogicHandlerOptionalParams], LogicHandler]:
437443
"""
438444
Decorator for custom logic handler functions.
439445
@@ -445,6 +451,10 @@ def handler(reset_origin=False, allow_recursion=False) -> Callable[[LogicHandler
445451
* the `origin` can magically be passed when called directly by other logic handlers
446452
* the execution is skipped when called recursively (i.e. the logic handler is already contained in the `origin` list
447453
454+
It also allows to define the logic handler function with different numbers of parameters: If the function takes two
455+
parameters, the trigger *value* and the *origin* are passed. If the function takes one parameter, only the *value*
456+
is passed. If the function takes no parameters, it is called without arguments.
457+
448458
:param reset_origin: If True, the origin which is magically passed to all `write` calls, only contains the logic
449459
handler itself, not the previous `origin` list, which led to the handler's execution. This can be used to
450460
change an object's value, which triggered this logic handler. This may cause infinite recursive feedback loops,
@@ -453,7 +463,9 @@ def handler(reset_origin=False, allow_recursion=False) -> Callable[[LogicHandler
453463
passed values and/or the `origin` list itself to prevent infinite feedback loops via `write` calls or calls to
454464
other logic handlers – especiaally when used together with `reset_origin`.
455465
"""
456-
def decorator(f: LogicHandler) -> LogicHandler:
466+
def decorator(f: LogicHandlerOptionalParams) -> LogicHandler:
467+
num_args = _count_function_args(f)
468+
457469
@functools.wraps(f)
458470
async def wrapper(value, origin: Optional[List[Any]] = None) -> None:
459471
if origin is None:
@@ -467,7 +479,12 @@ async def wrapper(value, origin: Optional[List[Any]] = None) -> None:
467479
logger.info("Triggering logic handler %s() from %s", f.__name__, origin)
468480
try:
469481
token = magicOriginVar.set([wrapper] if reset_origin else (origin + [wrapper]))
470-
await f(value, origin)
482+
if num_args == 0:
483+
await f() # type: ignore
484+
elif num_args == 1:
485+
await f(value) # type: ignore
486+
else:
487+
await f(value, origin) # type: ignore
471488
magicOriginVar.reset(token)
472489
except Exception as e:
473490
logger.error("Error while executing handler %s():", f.__name__, exc_info=e)
@@ -486,9 +503,12 @@ def blocking_handler() -> Callable[[Callable[[T, List[Any]], None]], LogicHandle
486503
Like :func:`handler`, this decorator catches and logs errors and ensures that the `origin` can magically be passed
487504
when called directly by other logic handlers. However, since the wrapped function is not an asynchronous coroutine,
488505
it is not able to call :meth:`Writable.write` or another logic handler directly. Thus, this decorator does not
489-
include special measures for preparing and passing the `origin` list or avoiding recursive execution.
506+
include special measures for preparing and passing the `origin` list or avoiding recursive execution. Still, it
507+
takes care of the correct number of arguments (zero to two) for calling the function.
490508
"""
491509
def decorator(f: Callable[[T, List[Any]], None]) -> LogicHandler:
510+
num_args = _count_function_args(f)
511+
492512
@functools.wraps(f)
493513
async def wrapper(value, origin: Optional[List[Any]] = None) -> None:
494514
if origin is None:
@@ -499,8 +519,19 @@ async def wrapper(value, origin: Optional[List[Any]] = None) -> None:
499519
logger.info("Triggering blocking logic handler %s() from %s", f.__name__, origin)
500520
try:
501521
loop = asyncio.get_event_loop()
502-
await loop.run_in_executor(None, f, value, origin)
522+
args = (value, origin)[:num_args]
523+
await loop.run_in_executor(None, f, *args)
503524
except Exception as e:
504525
logger.error("Error while executing handler %s():", f.__name__, exc_info=e)
505526
return wrapper
506527
return decorator
528+
529+
530+
def _count_function_args(f: Callable) -> int:
531+
num_args = 0
532+
for param in inspect.signature(f).parameters.values():
533+
if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD):
534+
num_args += 1
535+
elif param.kind is inspect.Parameter.VAR_POSITIONAL:
536+
num_args += 2**30
537+
return num_args

test/_helper.py

+6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import concurrent.futures
1313
import functools
1414
import heapq
15+
import inspect
1516
import threading
1617
import time
1718
import unittest.mock
@@ -49,6 +50,11 @@ class does not have the assert_awaited features of the official AsyncMock.
4950
The async calls are passed to the normal call/enter/exit methods of the super class to use its usual builtin
5051
evaluation/assertion functionality (e.g. :meth:`unittest.mock.NonCallableMock.assert_called_with`).
5152
"""
53+
def __init__(self, *args, **kwargs):
54+
super().__init__(*args, **kwargs)
55+
if not hasattr(self, "__signature__"):
56+
self.__signature__ = inspect.signature(self.__call__)
57+
5258
async def __call__(self, *args, **kwargs):
5359
return super().__call__(*args, **kwargs)
5460

test/test_base.py

+48
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,30 @@ async def test_handler(value, _origin) -> None:
160160
await asyncio.sleep(0.01)
161161
self.assertIn("unexpected error in _write", "\n".join(ctx.output))
162162

163+
@async_test
164+
async def test_missing_parameters(self) -> None:
165+
mock = AsyncMock()
166+
167+
@base.handler()
168+
async def full_handler(value, origin) -> None:
169+
await mock(value, origin)
170+
171+
@base.handler()
172+
async def part_handler(value) -> None:
173+
await mock(value)
174+
175+
@base.handler()
176+
async def empty_handler() -> None:
177+
await mock()
178+
179+
await full_handler(1, [self])
180+
await part_handler(2, [self, object])
181+
await empty_handler(3, [self, unittest.mock.sentinel])
182+
await asyncio.sleep(0.01)
183+
mock.assert_has_calls([unittest.mock.call(1, [self]),
184+
unittest.mock.call(2),
185+
unittest.mock.call()])
186+
163187

164188
class TestBlockingHandler(unittest.TestCase):
165189
@async_test
@@ -198,6 +222,30 @@ def blocking_test_handler(value, _origin):
198222
await asyncio.sleep(0.01)
199223
mock.assert_called_once_with(TOTALLY_RANDOM_NUMBER, [self, a, test_handler])
200224

225+
@async_test
226+
async def test_missing_parameters(self) -> None:
227+
mock = unittest.mock.MagicMock()
228+
229+
@base.blocking_handler()
230+
def full_handler(value, origin) -> None:
231+
mock(value, origin)
232+
233+
@base.blocking_handler()
234+
def part_handler(value) -> None:
235+
mock(value)
236+
237+
@base.blocking_handler()
238+
def empty_handler() -> None:
239+
mock()
240+
241+
await full_handler(1, [self])
242+
await part_handler(2, [self, object])
243+
await empty_handler(3, [self, unittest.mock.sentinel])
244+
await asyncio.sleep(0.01)
245+
mock.assert_has_calls([unittest.mock.call(1, [self]),
246+
unittest.mock.call(2),
247+
unittest.mock.call()])
248+
201249

202250
class TestReading(unittest.TestCase):
203251
@async_test

0 commit comments

Comments
 (0)