Skip to content

Commit 28427b8

Browse files
authored
Fix requires_feature decorator breaking named arguments (#21)
1 parent c44a126 commit 28427b8

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

letpot/deviceclient.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from datetime import time
1111
from functools import wraps
1212
from hashlib import md5, sha256
13-
from typing import Any, Callable, Concatenate
13+
from typing import Any, Callable, ParamSpec, TypeVar, cast
1414

1515
import aiomqtt
1616

@@ -32,32 +32,26 @@
3232

3333
_LOGGER = logging.getLogger(__name__)
3434

35+
T = TypeVar("T", bound="LetPotDeviceClient")
36+
_R = TypeVar("_R")
37+
P = ParamSpec("P")
3538

36-
def _create_ssl_context() -> ssl.SSLContext:
37-
"""Create a SSL context for the MQTT connection, avoids a blocking call later."""
38-
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
39-
context.load_default_certs()
40-
return context
4139

42-
43-
_SSL_CONTEXT = _create_ssl_context()
44-
45-
46-
def requires_feature[T: "LetPotDeviceClient", _R, **P](
40+
def requires_feature(
4741
*required_feature: DeviceFeature,
4842
) -> Callable[
49-
[Callable[Concatenate[T, str, P], Coroutine[Any, Any, _R]]],
50-
Callable[Concatenate[T, str, P], Coroutine[Any, Any, _R]],
43+
[Callable[P, Coroutine[Any, Any, _R]]],
44+
Callable[P, Coroutine[Any, Any, _R]],
5145
]:
5246
"""Decorate the function to require device type support for a specific feature (inferred from serial)."""
5347

5448
def decorator(
55-
func: Callable[Concatenate[T, str, P], Coroutine[Any, Any, _R]],
56-
) -> Callable[Concatenate[T, str, P], Coroutine[Any, Any, _R]]:
49+
func: Callable[P, Coroutine[Any, Any, _R]],
50+
) -> Callable[P, Coroutine[Any, Any, _R]]:
5751
@wraps(func)
58-
async def wrapper(
59-
self: T, serial: str, *args: P.args, **kwargs: P.kwargs
60-
) -> _R:
52+
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> _R:
53+
self = cast(LetPotDeviceClient, args[0])
54+
serial = cast(str, args[1])
6155
exception_message = f"Device missing required feature: {required_feature}"
6256
try:
6357
supported_features = self._converter(serial).supported_features()
@@ -67,13 +61,23 @@ async def wrapper(
6761
raise LetPotFeatureException(exception_message)
6862
except LetPotException:
6963
raise LetPotFeatureException(exception_message)
70-
return await func(self, serial, *args, **kwargs)
64+
return await func(*args, **kwargs)
7165

7266
return wrapper
7367

7468
return decorator
7569

7670

71+
def _create_ssl_context() -> ssl.SSLContext:
72+
"""Create a SSL context for the MQTT connection, avoids a blocking call later."""
73+
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
74+
context.load_default_certs()
75+
return context
76+
77+
78+
_SSL_CONTEXT = _create_ssl_context()
79+
80+
7781
class LetPotDeviceClient:
7882
"""Client for connecting to LetPot device."""
7983

0 commit comments

Comments
 (0)