From e26b6a116f17d40a19b478065ec25be02cd099e1 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Thu, 30 Nov 2023 12:07:00 +0000 Subject: [PATCH 1/4] Add support for TAPO/SMART KLAP and seperate transports from protocols --- devtools/dump_devinfo.py | 10 +- kasa/__init__.py | 4 +- kasa/{aesprotocol.py => aestransport.py} | 297 +++++-------------- kasa/discover.py | 27 +- kasa/iotprotocol.py | 106 +++++++ kasa/{klapprotocol.py => klaptransport.py} | 214 ++++++------- kasa/protocol.py | 50 ++++ kasa/smartprotocol.py | 223 ++++++++++++++ kasa/tapo/tapodevice.py | 6 +- kasa/tests/fixtures/HS100(UK)_1.0_1.2.6.json | 28 ++ kasa/tests/test_device_factory.py | 6 +- kasa/tests/test_klapprotocol.py | 86 +++--- 12 files changed, 669 insertions(+), 388 deletions(-) rename kasa/{aesprotocol.py => aestransport.py} (51%) create mode 100755 kasa/iotprotocol.py rename kasa/{klapprotocol.py => klaptransport.py} (78%) mode change 100755 => 100644 create mode 100644 kasa/smartprotocol.py create mode 100644 kasa/tests/fixtures/HS100(UK)_1.0_1.2.6.json diff --git a/devtools/dump_devinfo.py b/devtools/dump_devinfo.py index 777ee1050..76d750dcf 100644 --- a/devtools/dump_devinfo.py +++ b/devtools/dump_devinfo.py @@ -189,16 +189,12 @@ async def get_legacy_fixture(device): async def get_smart_fixture(device: SmartDevice): """Get fixture for new TAPO style protocol.""" items = [ - Call(module="component_nego", method="component_nego"), - Call(module="device_info", method="get_device_info"), - Call(module="device_usage", method="get_device_usage"), - Call(module="device_time", method="get_device_time"), - Call(module="energy_usage", method="get_energy_usage"), - Call(module="current_power", method="get_current_power"), Call( module="child_device_component_list", method="get_child_device_component_list", ), + Call(module="device_info", method="get_device_info"), + Call(module="device_usage", method="get_device_usage"), ] successes = [] @@ -250,7 +246,7 @@ async def get_smart_fixture(device: SmartDevice): model = final["get_device_info"]["model"] sw_version = sw_version.split(" ", maxsplit=1)[0] - return f"{model}.smart_{hw_version}_{sw_version}.json", final + return f"{model}_{hw_version}_{sw_version}.json", final if __name__ == "__main__": diff --git a/kasa/__init__.py b/kasa/__init__.py index 989e507f2..96576062e 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -21,7 +21,7 @@ SmartDeviceException, UnsupportedDeviceException, ) -from kasa.klapprotocol import TPLinkKlap +from kasa.iotprotocol import TPLinkIotProtocol from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol from kasa.smartbulb import SmartBulb, SmartBulbPreset, TurnOnBehavior, TurnOnBehaviors from kasa.smartdevice import DeviceType, SmartDevice @@ -37,7 +37,7 @@ "Discover", "TPLinkSmartHomeProtocol", "TPLinkProtocol", - "TPLinkKlap", + "TPLinkIotProtocol", "SmartBulb", "SmartBulbPreset", "TurnOnBehaviors", diff --git a/kasa/aesprotocol.py b/kasa/aestransport.py similarity index 51% rename from kasa/aesprotocol.py rename to kasa/aestransport.py index 98776ce2a..b90dccd83 100644 --- a/kasa/aesprotocol.py +++ b/kasa/aestransport.py @@ -1,4 +1,4 @@ -"""Implementation of the TP-Link AES Protocol. +"""Implementation of the TP-Link AES transport. Based on the work of https://github.com/petretiandrea/plugp100 under compatible GNU GPL3 license. @@ -9,12 +9,10 @@ import hashlib import logging import time -import uuid -from pprint import pformat as pf -from typing import Dict, Optional, Union +from typing import Optional import httpx -from cryptography.hazmat.primitives import hashes, padding, serialization +from cryptography.hazmat.primitives import padding, serialization from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes @@ -23,33 +21,24 @@ from .exceptions import AuthenticationException, SmartDeviceException from .json import dumps as json_dumps from .json import loads as json_loads -from .protocol import TPLinkProtocol +from .protocol import TPLinkTransport _LOGGER = logging.getLogger(__name__) -logging.getLogger("httpx").propagate = False -def _md5(payload: bytes) -> bytes: - digest = hashes.Hash(hashes.MD5()) # noqa: S303 - digest.update(payload) - hash = digest.finalize() - return hash - - -def _sha1(payload: bytes) -> str: +def _sha1_hex(payload: bytes) -> str: sha1_algo = hashlib.sha1() # noqa: S324 sha1_algo.update(payload) return sha1_algo.hexdigest() -class TPLinkAes(TPLinkProtocol): +class TPLinkAesTransport(TPLinkTransport): """Implementation of the AES encryption protocol. AES is the name used in device discovery for TP-Link's TAPO encryption protocol, sometimes used by newer firmware versions on kasa devices. """ - DEFAULT_PORT = 80 DEFAULT_TIMEOUT = 5 SESSION_COOKIE_NAME = "TP_SESSIONID" COMMON_HEADERS = { @@ -65,7 +54,7 @@ def __init__( credentials: Optional[Credentials] = None, timeout: Optional[int] = None, ) -> None: - super().__init__(host=host, port=self.DEFAULT_PORT) + super().__init__(host=host) self.credentials = ( credentials @@ -74,12 +63,10 @@ def __init__( ) self._local_seed: Optional[bytes] = None - self.local_auth_hash = self.generate_auth_hash(self.credentials) - self.local_auth_owner = self.generate_owner_hash(self.credentials).hex() self.kasa_setup_auth_hash = None self.blank_auth_hash = None self.handshake_lock = asyncio.Lock() - self.query_lock = asyncio.Lock() + self.handshake_done = False self.encryption_session: Optional[AesEncyptionSession] = None @@ -87,9 +74,8 @@ def __init__( self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT self.session_cookie = None - self.terminal_uuid = None - self.http_client: Optional[httpx.AsyncClient] = None - self.request_id_generator = SnowflakeId(1, 1) + + self.http_client: httpx.AsyncClient = httpx.AsyncClient() self.login_token = None _LOGGER.debug("Created AES object for %s", self.host) @@ -98,14 +84,14 @@ def hash_credentials(self, credentials, try_login_version2): """Hash the credentials.""" if try_login_version2: un = base64.b64encode( - _sha1(credentials.username.encode()).encode() + _sha1_hex(credentials.username.encode()).encode() ).decode() pw = base64.b64encode( - _sha1(credentials.password.encode()).encode() + _sha1_hex(credentials.password.encode()).encode() ).decode() else: un = base64.b64encode( - _sha1(credentials.username.encode()).encode() + _sha1_hex(credentials.username.encode()).encode() ).decode() pw = base64.b64encode(credentials.password.encode()).decode() return un, pw @@ -132,55 +118,70 @@ async def client_post(self, url, params=None, data=None, json=None, headers=None return resp.status_code, response_data - async def send_secure_passthrough(self, request): + async def send_secure_passthrough(self, request: str): """Send encrypted message as passthrough.""" url = f"http://{self.host}/app" if self.login_token: url += f"?token={self.login_token}" - raw_request = json_dumps(request) - encrypted_payload = self.encryption_session.encrypt(raw_request.encode()) + + encrypted_payload = self.encryption_session.encrypt(request.encode()) # type: ignore passthrough_request = { "method": "securePassthrough", "params": {"request": encrypted_payload.decode()}, } status_code, resp_dict = await self.client_post(url, json=passthrough_request) + _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}") if status_code == 200 and resp_dict["error_code"] == 0: - response = self.encryption_session.decrypt( + response = self.encryption_session.decrypt( # type: ignore resp_dict["result"]["response"].encode() ) + _LOGGER.debug(f"decrypted secure_passthrough response is {response}") resp_dict = json_loads(response) - if resp_dict["error_code"] != 0: - raise SmartDeviceException( - f"Could not complete send, response was {resp_dict}", - ) - if "result" in resp_dict: - return resp_dict["result"] + return resp_dict else: + self.handshake_done = False + self.login_token = None raise AuthenticationException("Could not complete send") - def get_aes_request(self, method, params=None): - """Get a request message.""" - request = { - "method": method, - "params": params, - "requestID": self.request_id_generator.generate_id(), - "request_time_milis": round(time.time() * 1000), - "terminal_uuid": self.terminal_uuid, - } - return request - - async def perform_login(self, login_v2): + async def perform_login(self, login_request, login_v2): """Login to the device.""" self.login_token = None + if isinstance(login_request, str): + login_request = json_loads(login_request) + un, pw = self.hash_credentials(self.credentials, login_v2) - params = {"password": pw, "username": un} - request = self.get_aes_request("login_device", params) + login_request["params"] = {"password": pw, "username": un} + request = json_dumps(login_request) try: - result = await self.send_secure_passthrough(request) + resp_dict = await self.send_secure_passthrough(request) except SmartDeviceException as ex: raise AuthenticationException(ex) from ex - self.login_token = result["token"] + self.login_token = resp_dict["result"]["token"] + + def needs_login(self) -> bool: + """Return true if the transport needs to do a login.""" + return self.login_token is None + + async def login(self, request: str) -> None: + """Login to the device.""" + try: + if self.needs_handshake(): + raise SmartDeviceException( + "Handshake must be complete before trying to login" + ) + await self.perform_login(request, False) + except AuthenticationException: + await self.perform_handshake() + await self.perform_login(request, True) + + def needs_handshake(self) -> bool: + """Return true if the transport needs to do a handshake.""" + return not self.handshake_done or self.handshake_session_expired() + + async def handshake(self) -> None: + """Perform the encryption handshake.""" + await self.perform_handshake() async def perform_handshake(self): """Perform the handshake.""" @@ -227,9 +228,6 @@ async def perform_handshake(self): handshake_key, key_pair ) - self.terminal_uuid = base64.b64encode(_md5(uuid.uuid4().bytes)).decode( - "UTF-8" - ) self.handshake_done = True _LOGGER.debug("Handshake with %s complete", self.host) @@ -243,96 +241,23 @@ def handshake_session_expired(self): self.session_expire_at is None or self.session_expire_at - time.time() <= 0 ) - @staticmethod - def generate_auth_hash(creds: Credentials): - """Generate an md5 auth hash for the protocol on the supplied credentials.""" - un = creds.username or "" - pw = creds.password or "" - return _md5(_md5(un.encode()) + _md5(pw.encode())) - - @staticmethod - def generate_owner_hash(creds: Credentials): - """Return the MD5 hash of the username in this object.""" - un = creds.username or "" - return _md5(un.encode()) - - async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: - """Query the device retrying for retry_count on failure.""" - async with self.query_lock: - return await self._query(request, retry_count) - - async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: - for retry in range(retry_count + 1): - try: - return await self._execute_query(request, retry) - except httpx.CloseError as sdex: - await self.close() - if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) - raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {sdex}" - ) from sdex - continue - except httpx.ConnectError as cex: - await self.close() - raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {cex}" - ) from cex - except TimeoutError as tex: - await self.close() - raise SmartDeviceException( - f"Unable to connect to the device, timed out: {self.host}: {tex}" - ) from tex - except AuthenticationException as auex: - _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host) - raise auex - except Exception as ex: - await self.close() - if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) - raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {ex}" - ) from ex - continue - - # make mypy happy, this should never be reached.. - raise SmartDeviceException("Query reached somehow to unreachable") - - async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict: - _LOGGER.debug( - "%s >> %s", - self.host, - _LOGGER.isEnabledFor(logging.DEBUG) and pf(request), - ) - - if not self.http_client: - self.http_client = httpx.AsyncClient() - - if not self.handshake_done or self.handshake_session_expired(): - try: - await self.perform_handshake() - await self.perform_login(False) - except AuthenticationException: - await self.perform_handshake() - await self.perform_login(True) - - if isinstance(request, dict): - aes_method = next(iter(request)) - aes_params = request[aes_method] - else: - aes_method = request - aes_params = None - - aes_request = self.get_aes_request(aes_method, aes_params) - response_data = await self.send_secure_passthrough(aes_request) - - _LOGGER.debug( - "%s << %s", - self.host, - _LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data), - ) - - return response_data + async def send(self, request: str): + """Send the request.""" + if self.needs_handshake(): + raise SmartDeviceException( + "Handshake must be complete before trying to send" + ) + if self.needs_login(): + raise SmartDeviceException("Login must be complete before trying to send") + + resp_dict = await self.send_secure_passthrough(request) + if resp_dict["error_code"] != 0: + self.handshake_done = False + self.login_token = None + raise SmartDeviceException( + f"Could not complete send, response was {resp_dict}", + ) + return resp_dict async def close(self) -> None: """Close the protocol.""" @@ -416,83 +341,3 @@ def get_private_key(self) -> str: def get_public_key(self) -> str: """Get the public key.""" return self.public_key - - -class SnowflakeId: - """Class for generating snowflake ids.""" - - EPOCH = 1420041600000 # Custom epoch (in milliseconds) - WORKER_ID_BITS = 5 - DATA_CENTER_ID_BITS = 5 - SEQUENCE_BITS = 12 - - MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1 - MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1 - - SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1 - - def __init__(self, worker_id, data_center_id): - if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0: - raise ValueError( - "Worker ID can't be greater than " - + str(SnowflakeId.MAX_WORKER_ID) - + " or less than 0" - ) - if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0: - raise ValueError( - "Data center ID can't be greater than " - + str(SnowflakeId.MAX_DATA_CENTER_ID) - + " or less than 0" - ) - - self.worker_id = worker_id - self.data_center_id = data_center_id - self.sequence = 0 - self.last_timestamp = -1 - - def generate_id(self): - """Generate a snowflake id.""" - timestamp = self._current_millis() - - if timestamp < self.last_timestamp: - raise ValueError("Clock moved backwards. Refusing to generate ID.") - - if timestamp == self.last_timestamp: - # Within the same millisecond, increment the sequence number - self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK - if self.sequence == 0: - # Sequence exceeds its bit range, wait until the next millisecond - timestamp = self._wait_next_millis(self.last_timestamp) - else: - # New millisecond, reset the sequence number - self.sequence = 0 - - # Update the last timestamp - self.last_timestamp = timestamp - - # Generate and return the final ID - return ( - ( - (timestamp - SnowflakeId.EPOCH) - << ( - SnowflakeId.WORKER_ID_BITS - + SnowflakeId.SEQUENCE_BITS - + SnowflakeId.DATA_CENTER_ID_BITS - ) - ) - | ( - self.data_center_id - << (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS) - ) - | (self.worker_id << SnowflakeId.SEQUENCE_BITS) - | self.sequence - ) - - def _current_millis(self): - return round(time.time() * 1000) - - def _wait_next_millis(self, last_timestamp): - timestamp = self._current_millis() - while timestamp <= last_timestamp: - timestamp = self._current_millis() - return timestamp diff --git a/kasa/discover.py b/kasa/discover.py index 59849bc0e..d50445c9e 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -4,7 +4,7 @@ import ipaddress import logging import socket -from typing import Awaitable, Callable, Dict, Optional, Set, Type, cast +from typing import Awaitable, Callable, Dict, Optional, Set, Tuple, Type, cast # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout @@ -15,15 +15,16 @@ except ImportError: from pydantic import BaseModel, Field -from kasa.aesprotocol import TPLinkAes from kasa.credentials import Credentials from kasa.exceptions import UnsupportedDeviceException +from kasa.iotprotocol import TPLinkIotProtocol from kasa.json import dumps as json_dumps from kasa.json import loads as json_loads -from kasa.klapprotocol import TPLinkKlap -from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol +from kasa.klaptransport import TPLinkKlapTransport, TPlinkKlapTransportV2 +from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol, TPLinkTransport from kasa.smartdevice import SmartDevice, SmartDeviceException from kasa.smartplug import SmartPlug +from kasa.smartprotocol import TPLinkAesTransport, TPLinkSmartProtocol from kasa.tapo.tapoplug import TapoPlug from .device_factory import get_device_class_from_info @@ -391,9 +392,12 @@ def _get_device_instance( "SMART.KASAPLUG": TapoPlug, "IOT.SMARTPLUGSWITCH": SmartPlug, } - supported_device_protocols: dict[str, Type[TPLinkProtocol]] = { - "IOT.KLAP": TPLinkKlap, - "SMART.AES": TPLinkAes, + supported_device_protocols: dict[ + str, Tuple[Type[TPLinkProtocol], Type[TPLinkTransport]] + ] = { + "IOT.KLAP": (TPLinkIotProtocol, TPLinkKlapTransport), + "SMART.AES": (TPLinkSmartProtocol, TPLinkAesTransport), + "SMART.KLAP": (TPLinkSmartProtocol, TPlinkKlapTransportV2), } if (device_class := supported_device_types.get(type_)) is None: @@ -401,7 +405,9 @@ def _get_device_instance( raise UnsupportedDeviceException( f"Unsupported device {ip} of type {type_}: {info}" ) - if (protocol_class := supported_device_protocols.get(encrypt_type_)) is None: + if ( + protocol_transport_tuple := supported_device_protocols.get(encrypt_type_) + ) is None: _LOGGER.warning("Got unsupported device type: %s", encrypt_type_) raise UnsupportedDeviceException( f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}" @@ -409,7 +415,10 @@ def _get_device_instance( _LOGGER.debug("[DISCOVERY] %s << %s", ip, info) device = device_class(ip, port=port, credentials=credentials) - device.protocol = protocol_class(ip, credentials=credentials) + transport = protocol_transport_tuple[1](ip, credentials=credentials) + device.protocol = protocol_transport_tuple[0]( + ip, credentials=credentials, transport=transport + ) device.update_from_discover_info(discovery_result.get_dict()) return device diff --git a/kasa/iotprotocol.py b/kasa/iotprotocol.py new file mode 100755 index 000000000..808594540 --- /dev/null +++ b/kasa/iotprotocol.py @@ -0,0 +1,106 @@ +"""Module for the IOT legacy IOT KASA protocol.""" +import asyncio +import logging +from typing import Dict, Optional, Union + +import httpx + +from .credentials import Credentials +from .exceptions import AuthenticationException, SmartDeviceException +from .json import dumps as json_dumps +from .klaptransport import TPLinkKlapTransport +from .protocol import TPLinkProtocol, TPLinkTransport + +_LOGGER = logging.getLogger(__name__) + + +class TPLinkIotProtocol(TPLinkProtocol): + """Class for the legacy TPLink IOT KASA Protocol.""" + + DEFAULT_PORT = 80 + + def __init__( + self, + host: str, + *, + transport: Optional[TPLinkTransport] = None, + credentials: Optional[Credentials] = None, + timeout: Optional[int] = None, + ) -> None: + super().__init__(host=host, port=self.DEFAULT_PORT) + + self.credentials: Credentials = ( + credentials + if credentials and credentials.username and credentials.password + else Credentials(username="", password="") + ) + self.transport: TPLinkTransport = ( + transport + if transport + else TPLinkKlapTransport( + host, credentials=self.credentials, timeout=timeout + ) + ) + + self.query_lock = asyncio.Lock() + + async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: + """Query the device retrying for retry_count on failure.""" + if isinstance(request, dict): + request = json_dumps(request) + assert isinstance(request, str) # noqa: S101 + + async with self.query_lock: + return await self._query(request, retry_count) + + async def _query(self, request: str, retry_count: int = 3) -> Dict: + for retry in range(retry_count + 1): + try: + return await self._execute_query(request, retry) + except httpx.CloseError as sdex: + await self.close() + if retry >= retry_count: + _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + raise SmartDeviceException( + f"Unable to connect to the device: {self.host}: {sdex}" + ) from sdex + continue + except httpx.ConnectError as cex: + await self.close() + raise SmartDeviceException( + f"Unable to connect to the device: {self.host}: {cex}" + ) from cex + except TimeoutError as tex: + await self.close() + raise SmartDeviceException( + f"Unable to connect to the device, timed out: {self.host}: {tex}" + ) from tex + except AuthenticationException as auex: + _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host) + raise auex + except Exception as ex: + await self.close() + if retry >= retry_count: + _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + raise SmartDeviceException( + f"Unable to connect to the device: {self.host}: {ex}" + ) from ex + continue + + # make mypy happy, this should never be reached.. + raise SmartDeviceException("Query reached somehow to unreachable") + + async def _execute_query(self, request: str, retry_count: int) -> Dict: + if self.transport.needs_handshake(): + await self.transport.handshake() + + if self.transport.needs_login(): # This shouln't happen + raise SmartDeviceException( + "IOT Protocol needs to login to transport but is not login aware" + ) + + return await self.transport.send(request) + + async def close(self) -> None: + """Close the protocol.""" + await self.transport.close() diff --git a/kasa/klapprotocol.py b/kasa/klaptransport.py old mode 100755 new mode 100644 similarity index 78% rename from kasa/klapprotocol.py rename to kasa/klaptransport.py index 36a42c589..4ad37d906 --- a/kasa/klapprotocol.py +++ b/kasa/klaptransport.py @@ -47,7 +47,7 @@ import secrets import time from pprint import pformat as pf -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Tuple import httpx from cryptography.hazmat.primitives import hashes, padding @@ -55,33 +55,33 @@ from .credentials import Credentials from .exceptions import AuthenticationException, SmartDeviceException -from .json import dumps as json_dumps from .json import loads as json_loads -from .protocol import TPLinkProtocol +from .protocol import TPLinkTransport, md5 _LOGGER = logging.getLogger(__name__) logging.getLogger("httpx").propagate = False def _sha256(payload: bytes) -> bytes: - return hashlib.sha256(payload).digest() - - -def _md5(payload: bytes) -> bytes: - digest = hashes.Hash(hashes.MD5()) # noqa: S303 + digest = hashes.Hash(hashes.SHA256()) # noqa: S303 digest.update(payload) hash = digest.finalize() return hash -class TPLinkKlap(TPLinkProtocol): +def _sha1(payload: bytes) -> bytes: + digest = hashes.Hash(hashes.SHA1()) # noqa: S303 + digest.update(payload) + return digest.finalize() + + +class TPLinkKlapTransport(TPLinkTransport): """Implementation of the KLAP encryption protocol. KLAP is the name used in device discovery for TP-Link's new encryption protocol, used by newer firmware versions. """ - DEFAULT_PORT = 80 DEFAULT_TIMEOUT = 5 DISCOVERY_QUERY = {"system": {"get_sysinfo": None}} KASA_SETUP_EMAIL = "kasa@tp-link.net" @@ -95,14 +95,13 @@ def __init__( credentials: Optional[Credentials] = None, timeout: Optional[int] = None, ) -> None: - super().__init__(host=host, port=self.DEFAULT_PORT) + super().__init__(host=host) self.credentials = ( credentials if credentials and credentials.username and credentials.password else Credentials(username="", password="") ) - self._local_seed: Optional[bytes] = None self.local_auth_hash = self.generate_auth_hash(self.credentials) self.local_auth_owner = self.generate_owner_hash(self.credentials).hex() @@ -117,7 +116,7 @@ def __init__( self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT self.session_cookie = None - self.http_client: Optional[httpx.AsyncClient] = None + self.http_client: httpx.AsyncClient = httpx.AsyncClient() _LOGGER.debug("Created KLAP object for %s", self.host) @@ -183,7 +182,9 @@ async def perform_handshake1(self) -> Tuple[bytes, bytes, bytes]: server_hash.hex(), ) - local_seed_auth_hash = _sha256(local_seed + self.local_auth_hash) + local_seed_auth_hash = self.handshake1_seed_auth_hash( + local_seed, remote_seed, self.local_auth_hash + ) # type: ignore # Check the response from the device with local credentials if local_seed_auth_hash == server_hash: @@ -193,14 +194,17 @@ async def perform_handshake1(self) -> Tuple[bytes, bytes, bytes]: # Now check against the default kasa setup credentials if not self.kasa_setup_auth_hash: kasa_setup_creds = Credentials( - username=TPLinkKlap.KASA_SETUP_EMAIL, - password=TPLinkKlap.KASA_SETUP_PASSWORD, + username=self.KASA_SETUP_EMAIL, + password=self.KASA_SETUP_PASSWORD, ) - self.kasa_setup_auth_hash = TPLinkKlap.generate_auth_hash(kasa_setup_creds) + self.kasa_setup_auth_hash = self.generate_auth_hash(kasa_setup_creds) - kasa_setup_seed_auth_hash = _sha256( - local_seed + self.kasa_setup_auth_hash # type: ignore + kasa_setup_seed_auth_hash = self.handshake1_seed_auth_hash( + local_seed, + remote_seed, + self.kasa_setup_auth_hash, # type: ignore ) + if kasa_setup_seed_auth_hash == server_hash: _LOGGER.debug( "Server response doesn't match our expected hash on ip %s" @@ -212,8 +216,14 @@ async def perform_handshake1(self) -> Tuple[bytes, bytes, bytes]: # Finally check against blank credentials if not already blank if self.credentials != (blank_creds := Credentials(username="", password="")): if not self.blank_auth_hash: - self.blank_auth_hash = TPLinkKlap.generate_auth_hash(blank_creds) - blank_seed_auth_hash = _sha256(local_seed + self.blank_auth_hash) # type: ignore + self.blank_auth_hash = self.generate_auth_hash(blank_creds) + + blank_seed_auth_hash = self.handshake1_seed_auth_hash( + local_seed, + remote_seed, + self.blank_auth_hash, # type: ignore + ) + if blank_seed_auth_hash == server_hash: _LOGGER.debug( "Server response doesn't match our expected hash on ip %s" @@ -235,7 +245,7 @@ async def perform_handshake2( url = f"http://{self.host}/app/handshake2" - payload = _sha256(remote_seed + auth_hash) + payload = self.handshake2_seed_auth_hash(local_seed, remote_seed, auth_hash) response_status, response_data = await self.client_post(url, data=payload) @@ -256,6 +266,24 @@ async def perform_handshake2( return KlapEncryptionSession(local_seed, remote_seed, auth_hash) + def needs_login(self) -> bool: + """Will return false as KLAP does not do a login.""" + return False + + async def login(self, request: str) -> None: + """Will raise and exception as KLAP does not do a login.""" + raise SmartDeviceException( + "KLAP does not perform logins and return needs_login == False" + ) + + def needs_handshake(self): + """Return true if the transport needs to do a handshake.""" + return not self.handshake_done or self.handshake_session_expired() + + async def handshake(self) -> None: + """Perform the encryption handshake.""" + await self.perform_handshake() + async def perform_handshake(self) -> Any: """Perform handshake1 and handshake2. @@ -268,7 +296,7 @@ async def perform_handshake(self) -> Any: local_seed, remote_seed, auth_hash = await self.perform_handshake1() self.session_cookie = self.http_client.cookies.get( # type: ignore - TPLinkKlap.SESSION_COOKIE_NAME + self.SESSION_COOKIE_NAME ) # The device returns a TIMEOUT cookie on handshake1 which # it doesn't like to get back so we store the one we want @@ -287,80 +315,14 @@ def handshake_session_expired(self): self.session_expire_at is None or self.session_expire_at - time.time() <= 0 ) - @staticmethod - def generate_auth_hash(creds: Credentials): - """Generate an md5 auth hash for the protocol on the supplied credentials.""" - un = creds.username or "" - pw = creds.password or "" - return _md5(_md5(un.encode()) + _md5(pw.encode())) - - @staticmethod - def generate_owner_hash(creds: Credentials): - """Return the MD5 hash of the username in this object.""" - un = creds.username or "" - return _md5(un.encode()) - - async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: - """Query the device retrying for retry_count on failure.""" - if isinstance(request, dict): - request = json_dumps(request) - assert isinstance(request, str) # noqa: S101 - - async with self.query_lock: - return await self._query(request, retry_count) - - async def _query(self, request: str, retry_count: int = 3) -> Dict: - for retry in range(retry_count + 1): - try: - return await self._execute_query(request, retry) - except httpx.CloseError as sdex: - await self.close() - if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) - raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {sdex}" - ) from sdex - continue - except httpx.ConnectError as cex: - await self.close() - raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {cex}" - ) from cex - except TimeoutError as tex: - await self.close() - raise SmartDeviceException( - f"Unable to connect to the device, timed out: {self.host}: {tex}" - ) from tex - except AuthenticationException as auex: - _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host) - raise auex - except Exception as ex: - await self.close() - if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) - raise SmartDeviceException( - f"Unable to connect to the device: {self.host}: {ex}" - ) from ex - continue - - # make mypy happy, this should never be reached.. - raise SmartDeviceException("Query reached somehow to unreachable") - - async def _execute_query(self, request: str, retry_count: int) -> Dict: - if not self.http_client: - self.http_client = httpx.AsyncClient() - - if not self.handshake_done or self.handshake_session_expired(): - try: - await self.perform_handshake() - - except AuthenticationException as auex: - _LOGGER.debug( - "Unable to complete handshake for device %s, " - + "authentication failed", - self.host, - ) - raise auex + async def send(self, request: str): + """Send the request.""" + if self.needs_handshake(): + raise SmartDeviceException( + "Handshake must be complete before trying to send" + ) + if self.needs_login(): + raise SmartDeviceException("Login must be complete before trying to send") # Check for mypy if self.encryption_session is not None: @@ -376,7 +338,7 @@ async def _execute_query(self, request: str, retry_count: int) -> Dict: msg = ( f"at {datetime.datetime.now()}. Host is {self.host}, " - + f"Retry count is {retry_count}, Sequence is {seq}, " + + f"Sequence is {seq}, " + f"Response status is {response_status}, Request was {request}" ) if response_status != 200: @@ -411,12 +373,66 @@ async def _execute_query(self, request: str, retry_count: int) -> Dict: return json_payload async def close(self) -> None: - """Close the protocol.""" + """Close the transport.""" client = self.http_client self.http_client = None if client: await client.aclose() + @staticmethod + def generate_auth_hash(creds: Credentials): + """Generate an md5 auth hash for the protocol on the supplied credentials.""" + un = creds.username or "" + pw = creds.password or "" + + return md5(md5(un.encode()) + md5(pw.encode())) + + @staticmethod + def handshake1_seed_auth_hash( + local_seed: bytes, remote_seed: bytes, auth_hash: bytes + ): + """Generate an md5 auth hash for the protocol on the supplied credentials.""" + return _sha256(local_seed + auth_hash) + + @staticmethod + def handshake2_seed_auth_hash( + local_seed: bytes, remote_seed: bytes, auth_hash: bytes + ): + """Generate an md5 auth hash for the protocol on the supplied credentials.""" + return _sha256(remote_seed + auth_hash) + + @staticmethod + def generate_owner_hash(creds: Credentials): + """Return the MD5 hash of the username in this object.""" + un = creds.username or "" + return md5(un.encode()) + + +class TPlinkKlapTransportV2(TPLinkKlapTransport): + """Implementation of the KLAP encryption protocol with v2 hanshake hashes.""" + + @staticmethod + def generate_auth_hash(creds: Credentials): + """Generate an md5 auth hash for the protocol on the supplied credentials.""" + un = creds.username or "" + pw = creds.password or "" + + return _sha256(_sha1(un.encode()) + _sha1(pw.encode())) + + @staticmethod + def handshake1_seed_auth_hash( + local_seed: bytes, remote_seed: bytes, auth_hash: bytes + ): + """Generate an md5 auth hash for the protocol on the supplied credentials.""" + return _sha256(local_seed + remote_seed + auth_hash) + + @staticmethod + def handshake2_seed_auth_hash( + local_seed: bytes, remote_seed: bytes, auth_hash: bytes + ): + """Generate an md5 auth hash for the protocol on the supplied credentials.""" + return _sha256(remote_seed + local_seed + auth_hash) + class KlapEncryptionSession: """Class to represent an encryption session and it's internal state. diff --git a/kasa/protocol.py b/kasa/protocol.py index 6413ba5de..acd08723d 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -22,6 +22,7 @@ # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout from async_timeout import timeout as asyncio_timeout +from cryptography.hazmat.primitives import hashes from .credentials import Credentials from .exceptions import SmartDeviceException @@ -32,6 +33,54 @@ _NO_RETRY_ERRORS = {errno.EHOSTDOWN, errno.EHOSTUNREACH, errno.ECONNREFUSED} +def md5(payload: bytes) -> bytes: + """Return an md5 hash of the payload.""" + digest = hashes.Hash(hashes.MD5()) # noqa: S303 + digest.update(payload) + hash = digest.finalize() + return hash + + +class TPLinkTransport(ABC): + """Base class for all TP-Link KASA-KLAP and TAPO transports.""" + + def __init__( + self, + host: str, + *, + port: Optional[int] = None, + credentials: Optional[Credentials] = None, + ) -> None: + """Create a protocol object.""" + self.host = host + self.port = port + self.credentials = credentials + + @abstractmethod + def needs_handshake(self) -> bool: + """Return true if the transport needs to do a handshake.""" + + @abstractmethod + def needs_login(self) -> bool: + """Return true if the transport needs to do a login.""" + + @abstractmethod + async def login(self, request: str) -> None: + """Login to the device.""" + + @abstractmethod + async def handshake(self) -> None: + """Perform the encryption handshake.""" + + @abstractmethod + async def send(self, request: str) -> Dict: + """Send a message to the device and return a response.""" + + @abstractmethod + async def close(self) -> None: + """Close the transport. Abstract method to be overriden.""" + + class TPLinkProtocol(ABC): """Base class for all TP-Link Smart Home communication.""" @@ -41,6 +90,7 @@ def __init__( *, port: Optional[int] = None, credentials: Optional[Credentials] = None, + transport: Optional[TPLinkTransport] = None, ) -> None: """Create a protocol object.""" self.host = host diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py new file mode 100644 index 000000000..962c26ef7 --- /dev/null +++ b/kasa/smartprotocol.py @@ -0,0 +1,223 @@ +"""Implementation of the TP-Link AES Protocol. + +Based on the work of https://github.com/petretiandrea/plugp100 +under compatible GNU GPL3 license. +""" + +import asyncio +import base64 +import logging +import time +import uuid +from pprint import pformat as pf +from typing import Dict, Optional, Union + +import httpx + +from .aestransport import TPLinkAesTransport +from .credentials import Credentials +from .exceptions import AuthenticationException, SmartDeviceException +from .json import dumps as json_dumps +from .protocol import TPLinkProtocol, TPLinkTransport, md5 + +_LOGGER = logging.getLogger(__name__) +logging.getLogger("httpx").propagate = False + + +class TPLinkSmartProtocol(TPLinkProtocol): + """Class for the new TPLink SMART protocol.""" + + DEFAULT_PORT = 80 + + def __init__( + self, + host: str, + *, + transport: Optional[TPLinkTransport] = None, + credentials: Optional[Credentials] = None, + timeout: Optional[int] = None, + ) -> None: + super().__init__(host=host, port=self.DEFAULT_PORT) + + self.credentials: Credentials = ( + credentials + if credentials and credentials.username and credentials.password + else Credentials(username="", password="") + ) + self.transport: TPLinkTransport = ( + transport + if transport + else TPLinkAesTransport(host, credentials=self.credentials, timeout=timeout) + ) + self.terminal_uuid: Optional[str] = None + self.request_id_generator = SnowflakeId(1, 1) + self.query_lock = asyncio.Lock() + + def get_smart_request(self, method, params=None) -> str: + """Get a request message as a string.""" + request = { + "method": method, + "params": params, + "requestID": self.request_id_generator.generate_id(), + "request_time_milis": round(time.time() * 1000), + "terminal_uuid": self.terminal_uuid, + } + return json_dumps(request) + + async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: + """Query the device retrying for retry_count on failure.""" + async with self.query_lock: + resp_dict = await self._query(request, retry_count) + if "result" in resp_dict: + return resp_dict["result"] + return {} + + async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: + for retry in range(retry_count + 1): + try: + return await self._execute_query(request, retry) + except httpx.CloseError as sdex: + await self.close() + if retry >= retry_count: + _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + raise SmartDeviceException( + f"Unable to connect to the device: {self.host}: {sdex}" + ) from sdex + continue + except httpx.ConnectError as cex: + await self.close() + raise SmartDeviceException( + f"Unable to connect to the device: {self.host}: {cex}" + ) from cex + except TimeoutError as tex: + await self.close() + raise SmartDeviceException( + f"Unable to connect to the device, timed out: {self.host}: {tex}" + ) from tex + except AuthenticationException as auex: + _LOGGER.debug("Unable to authenticate with %s, not retrying", self.host) + raise auex + except Exception as ex: + await self.close() + if retry >= retry_count: + _LOGGER.debug("Giving up on %s after %s retries", self.host, retry) + raise SmartDeviceException( + f"Unable to connect to the device: {self.host}: {ex}" + ) from ex + continue + + # make mypy happy, this should never be reached.. + raise SmartDeviceException("Query reached somehow to unreachable") + + async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict: + if isinstance(request, dict): + smart_method = next(iter(request)) + smart_params = request[smart_method] + else: + smart_method = request + smart_params = None + + if self.transport.needs_handshake(): + await self.transport.handshake() + + if self.transport.needs_login(): + self.terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode( + "UTF-8" + ) + login_request = self.get_smart_request("login_device") + await self.transport.login(login_request) + + smart_request = self.get_smart_request(smart_method, smart_params) + response_data = await self.transport.send(smart_request) + + _LOGGER.debug( + "%s << %s", + self.host, + _LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data), + ) + + return response_data + + async def close(self) -> None: + """Close the protocol.""" + await self.transport.close() + + +class SnowflakeId: + """Class for generating snowflake ids.""" + + EPOCH = 1420041600000 # Custom epoch (in milliseconds) + WORKER_ID_BITS = 5 + DATA_CENTER_ID_BITS = 5 + SEQUENCE_BITS = 12 + + MAX_WORKER_ID = (1 << WORKER_ID_BITS) - 1 + MAX_DATA_CENTER_ID = (1 << DATA_CENTER_ID_BITS) - 1 + + SEQUENCE_MASK = (1 << SEQUENCE_BITS) - 1 + + def __init__(self, worker_id, data_center_id): + if worker_id > SnowflakeId.MAX_WORKER_ID or worker_id < 0: + raise ValueError( + "Worker ID can't be greater than " + + str(SnowflakeId.MAX_WORKER_ID) + + " or less than 0" + ) + if data_center_id > SnowflakeId.MAX_DATA_CENTER_ID or data_center_id < 0: + raise ValueError( + "Data center ID can't be greater than " + + str(SnowflakeId.MAX_DATA_CENTER_ID) + + " or less than 0" + ) + + self.worker_id = worker_id + self.data_center_id = data_center_id + self.sequence = 0 + self.last_timestamp = -1 + + def generate_id(self): + """Generate a snowflake id.""" + timestamp = self._current_millis() + + if timestamp < self.last_timestamp: + raise ValueError("Clock moved backwards. Refusing to generate ID.") + + if timestamp == self.last_timestamp: + # Within the same millisecond, increment the sequence number + self.sequence = (self.sequence + 1) & SnowflakeId.SEQUENCE_MASK + if self.sequence == 0: + # Sequence exceeds its bit range, wait until the next millisecond + timestamp = self._wait_next_millis(self.last_timestamp) + else: + # New millisecond, reset the sequence number + self.sequence = 0 + + # Update the last timestamp + self.last_timestamp = timestamp + + # Generate and return the final ID + return ( + ( + (timestamp - SnowflakeId.EPOCH) + << ( + SnowflakeId.WORKER_ID_BITS + + SnowflakeId.SEQUENCE_BITS + + SnowflakeId.DATA_CENTER_ID_BITS + ) + ) + | ( + self.data_center_id + << (SnowflakeId.SEQUENCE_BITS + SnowflakeId.WORKER_ID_BITS) + ) + | (self.worker_id << SnowflakeId.SEQUENCE_BITS) + | self.sequence + ) + + def _current_millis(self): + return round(time.time() * 1000) + + def _wait_next_millis(self, last_timestamp): + timestamp = self._current_millis() + while timestamp <= last_timestamp: + timestamp = self._current_millis() + return timestamp diff --git a/kasa/tapo/tapodevice.py b/kasa/tapo/tapodevice.py index 2ba039565..701063739 100644 --- a/kasa/tapo/tapodevice.py +++ b/kasa/tapo/tapodevice.py @@ -4,10 +4,10 @@ from datetime import datetime, timedelta, timezone from typing import Any, Dict, Optional, Set, cast -from ..aesprotocol import TPLinkAes from ..credentials import Credentials from ..exceptions import AuthenticationException from ..smartdevice import SmartDevice +from ..smartprotocol import TPLinkSmartProtocol _LOGGER = logging.getLogger(__name__) @@ -26,7 +26,9 @@ def __init__( super().__init__(host, port=port, credentials=credentials, timeout=timeout) self._state_information: Dict[str, Any] = {} self._discovery_info: Optional[Dict[str, Any]] = None - self.protocol = TPLinkAes(host, credentials=credentials, timeout=timeout) + self.protocol = TPLinkSmartProtocol( + host, credentials=credentials, timeout=timeout + ) async def update(self, update_children: bool = True): """Update the device.""" diff --git a/kasa/tests/fixtures/HS100(UK)_1.0_1.2.6.json b/kasa/tests/fixtures/HS100(UK)_1.0_1.2.6.json new file mode 100644 index 000000000..787f367e2 --- /dev/null +++ b/kasa/tests/fixtures/HS100(UK)_1.0_1.2.6.json @@ -0,0 +1,28 @@ +{ + "system": { + "get_sysinfo": { + "active_mode": "schedule", + "alias": "Living Room Lamp", + "dev_name": "Wi-Fi Smart Plug", + "deviceId": "0000000000000000000000000000000000000000", + "err_code": 0, + "feature": "TIM", + "fwId": "00000000000000000000000000000000", + "hwId": "00000000000000000000000000000000", + "hw_ver": "1.0", + "icon_hash": "", + "latitude": 0, + "led_off": 0, + "longitude": 0, + "mac": "00:00:00:00:00:00", + "model": "HS100(UK)", + "oemId": "00000000000000000000000000000000", + "on_time": 4102, + "relay_state": 1, + "rssi": -58, + "sw_ver": "1.2.6 Build 200727 Rel.120236", + "type": "IOT.SMARTPLUGSWITCH", + "updating": 0 + } + } +} diff --git a/kasa/tests/test_device_factory.py b/kasa/tests/test_device_factory.py index aca38e19d..991e75bdc 100644 --- a/kasa/tests/test_device_factory.py +++ b/kasa/tests/test_device_factory.py @@ -14,7 +14,7 @@ SmartPlug, ) from kasa.device_factory import connect -from kasa.klapprotocol import TPLinkKlap +from kasa.iotprotocol import TPLinkIotProtocol from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol @@ -81,7 +81,7 @@ async def test_connect_logs_connect_time( ("protocol_in", "protocol_result"), ( (None, TPLinkSmartHomeProtocol), - (TPLinkKlap, TPLinkKlap), + (TPLinkIotProtocol, TPLinkIotProtocol), (TPLinkSmartHomeProtocol, TPLinkSmartHomeProtocol), ), ) @@ -95,7 +95,7 @@ async def test_connect_pass_protocol( """Test that if the protocol is passed in it's gets set correctly.""" host = "127.0.0.1" mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - mocker.patch("kasa.TPLinkKlap.query", return_value=discovery_data) + mocker.patch("kasa.TPLinkIotProtocol.query", return_value=discovery_data) dev = await connect(host, device_type=device_type, protocol_class=protocol_in) assert isinstance(dev.protocol, protocol_result) diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index 991dbe6fa..2d3eb6c2a 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -12,7 +12,8 @@ from ..credentials import Credentials from ..exceptions import AuthenticationException, SmartDeviceException -from ..klapprotocol import KlapEncryptionSession, TPLinkKlap, _sha256 +from ..iotprotocol import TPLinkIotProtocol +from ..klaptransport import KlapEncryptionSession, TPLinkKlapTransport, _sha256 class _mock_response: @@ -24,34 +25,34 @@ def __init__(self, status_code, content: bytes): @pytest.mark.parametrize("retry_count", [1, 3, 5]) async def test_protocol_retries(mocker, retry_count): conn = mocker.patch.object( - TPLinkKlap, "client_post", side_effect=Exception("dummy exception") + TPLinkKlapTransport, "client_post", side_effect=Exception("dummy exception") ) with pytest.raises(SmartDeviceException): - await TPLinkKlap("127.0.0.1").query({}, retry_count=retry_count) + await TPLinkIotProtocol("127.0.0.1").query({}, retry_count=retry_count) assert conn.call_count == retry_count + 1 async def test_protocol_no_retry_on_connection_error(mocker): conn = mocker.patch.object( - TPLinkKlap, + TPLinkKlapTransport, "client_post", side_effect=httpx.ConnectError("foo"), ) with pytest.raises(SmartDeviceException): - await TPLinkKlap("127.0.0.1").query({}, retry_count=5) + await TPLinkIotProtocol("127.0.0.1").query({}, retry_count=5) assert conn.call_count == 1 async def test_protocol_retry_recoverable_error(mocker): conn = mocker.patch.object( - TPLinkKlap, + TPLinkKlapTransport, "client_post", side_effect=httpx.CloseError("foo"), ) with pytest.raises(SmartDeviceException): - await TPLinkKlap("127.0.0.1").query({}, retry_count=5) + await TPLinkIotProtocol("127.0.0.1").query({}, retry_count=5) assert conn.call_count == 6 @@ -70,14 +71,14 @@ def _fail_one_less_than_retry_count(*_, **__): return 200, encrypted seed = secrets.token_bytes(16) - auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar")) + auth_hash = TPLinkKlapTransport.generate_auth_hash(Credentials("foo", "bar")) encryption_session = KlapEncryptionSession(seed, seed, auth_hash) - protocol = TPLinkKlap("127.0.0.1") - protocol.handshake_done = True - protocol.session_expire_at = time.time() + 86400 - protocol.encryption_session = encryption_session + protocol = TPLinkIotProtocol("127.0.0.1") + protocol.transport.handshake_done = True + protocol.transport.session_expire_at = time.time() + 86400 + protocol.transport.encryption_session = encryption_session mocker.patch.object( - TPLinkKlap, "client_post", side_effect=_fail_one_less_than_retry_count + TPLinkKlapTransport, "client_post", side_effect=_fail_one_less_than_retry_count ) response = await protocol.query({}, retry_count=retry_count) @@ -96,14 +97,16 @@ def _return_encrypted(*_, **__): return 200, encrypted seed = secrets.token_bytes(16) - auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar")) + auth_hash = TPLinkKlapTransport.generate_auth_hash(Credentials("foo", "bar")) encryption_session = KlapEncryptionSession(seed, seed, auth_hash) - protocol = TPLinkKlap("127.0.0.1") + protocol = TPLinkIotProtocol("127.0.0.1") - protocol.handshake_done = True - protocol.session_expire_at = time.time() + 86400 - protocol.encryption_session = encryption_session - mocker.patch.object(TPLinkKlap, "client_post", side_effect=_return_encrypted) + protocol.transport.handshake_done = True + protocol.transport.session_expire_at = time.time() + 86400 + protocol.transport.encryption_session = encryption_session + mocker.patch.object( + TPLinkKlapTransport, "client_post", side_effect=_return_encrypted + ) response = await protocol.query({}) assert response == {"great": "success"} @@ -117,7 +120,7 @@ def test_encrypt(): d = json.dumps({"foo": 1, "bar": 2}) seed = secrets.token_bytes(16) - auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar")) + auth_hash = TPLinkKlapTransport.generate_auth_hash(Credentials("foo", "bar")) encryption_session = KlapEncryptionSession(seed, seed, auth_hash) encrypted, seq = encryption_session.encrypt(d) @@ -129,7 +132,7 @@ def test_encrypt_unicode(): d = "{'snowman': '\u2603'}" seed = secrets.token_bytes(16) - auth_hash = TPLinkKlap.generate_auth_hash(Credentials("foo", "bar")) + auth_hash = TPLinkKlapTransport.generate_auth_hash(Credentials("foo", "bar")) encryption_session = KlapEncryptionSession(seed, seed, auth_hash) encrypted, seq = encryption_session.encrypt(d) @@ -145,7 +148,10 @@ def test_encrypt_unicode(): (Credentials("foo", "bar"), does_not_raise()), (Credentials("", ""), does_not_raise()), ( - Credentials(TPLinkKlap.KASA_SETUP_EMAIL, TPLinkKlap.KASA_SETUP_PASSWORD), + Credentials( + TPLinkKlapTransport.KASA_SETUP_EMAIL, + TPLinkKlapTransport.KASA_SETUP_PASSWORD, + ), does_not_raise(), ), ( @@ -167,21 +173,21 @@ async def _return_handshake1_response(url, params=None, data=None, *_, **__): client_seed = None server_seed = secrets.token_bytes(16) client_credentials = Credentials("foo", "bar") - device_auth_hash = TPLinkKlap.generate_auth_hash(device_credentials) + device_auth_hash = TPLinkKlapTransport.generate_auth_hash(device_credentials) mocker.patch.object( httpx.AsyncClient, "post", side_effect=_return_handshake1_response ) - protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials) + protocol = TPLinkIotProtocol("127.0.0.1", credentials=client_credentials) - protocol.http_client = httpx.AsyncClient() + protocol.transport.http_client = httpx.AsyncClient() with expectation: ( local_seed, device_remote_seed, auth_hash, - ) = await protocol.perform_handshake1() + ) = await protocol.transport.perform_handshake1() assert local_seed == client_seed assert device_remote_seed == server_seed @@ -204,23 +210,23 @@ async def _return_handshake_response(url, params=None, data=None, *_, **__): client_seed = None server_seed = secrets.token_bytes(16) client_credentials = Credentials("foo", "bar") - device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials) + device_auth_hash = TPLinkKlapTransport.generate_auth_hash(client_credentials) mocker.patch.object( httpx.AsyncClient, "post", side_effect=_return_handshake_response ) - protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials) - protocol.http_client = httpx.AsyncClient() + protocol = TPLinkIotProtocol("127.0.0.1", credentials=client_credentials) + protocol.transport.http_client = httpx.AsyncClient() response_status = 200 - await protocol.perform_handshake() - assert protocol.handshake_done is True + await protocol.transport.perform_handshake() + assert protocol.transport.handshake_done is True response_status = 403 with pytest.raises(AuthenticationException): - await protocol.perform_handshake() - assert protocol.handshake_done is False + await protocol.transport.perform_handshake() + assert protocol.transport.handshake_done is False await protocol.close() @@ -237,9 +243,9 @@ async def _return_response(url, params=None, data=None, *_, **__): return _mock_response(200, b"") elif url == "http://127.0.0.1/app/request": encryption_session = KlapEncryptionSession( - protocol.encryption_session.local_seed, - protocol.encryption_session.remote_seed, - protocol.encryption_session.user_hash, + protocol.transport.encryption_session.local_seed, + protocol.transport.encryption_session.remote_seed, + protocol.transport.encryption_session.user_hash, ) seq = params.get("seq") encryption_session._seq = seq - 1 @@ -252,11 +258,11 @@ async def _return_response(url, params=None, data=None, *_, **__): seq = None server_seed = secrets.token_bytes(16) client_credentials = Credentials("foo", "bar") - device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials) + device_auth_hash = TPLinkKlapTransport.generate_auth_hash(client_credentials) mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) - protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials) + protocol = TPLinkIotProtocol("127.0.0.1", credentials=client_credentials) for _ in range(10): resp = await protocol.query({}) @@ -296,11 +302,11 @@ async def _return_response(url, params=None, data=None, *_, **__): server_seed = secrets.token_bytes(16) client_credentials = Credentials("foo", "bar") - device_auth_hash = TPLinkKlap.generate_auth_hash(client_credentials) + device_auth_hash = TPLinkKlapTransport.generate_auth_hash(client_credentials) mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) - protocol = TPLinkKlap("127.0.0.1", credentials=client_credentials) + protocol = TPLinkIotProtocol("127.0.0.1", credentials=client_credentials) with expectation: await protocol.query({}) From 30effeca6f926e9f4c169731cb06148beace881c Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Fri, 1 Dec 2023 15:12:26 +0000 Subject: [PATCH 2/4] Add tests and some review changes --- devtools/dump_devinfo.py | 10 +- kasa/__init__.py | 2 + kasa/aestransport.py | 8 +- kasa/device_factory.py | 46 ++++- kasa/discover.py | 53 +++--- kasa/smartdevice.py | 2 +- kasa/tests/conftest.py | 147 ++++++++++++-- kasa/tests/fixtures/HS100(UK)_1.0_1.2.6.json | 28 --- kasa/tests/fixtures/P110.smart_1.0_1.3.0.json | 180 ++++++++++++++++++ kasa/tests/newfakes.py | 37 +++- kasa/tests/test_cli.py | 35 ++-- kasa/tests/test_device_factory.py | 97 +++++++--- kasa/tests/test_discovery.py | 94 ++++----- kasa/tests/test_klapprotocol.py | 76 +++++--- kasa/tests/test_plug.py | 13 +- kasa/tests/test_smartdevice.py | 23 ++- 16 files changed, 638 insertions(+), 213 deletions(-) delete mode 100644 kasa/tests/fixtures/HS100(UK)_1.0_1.2.6.json create mode 100644 kasa/tests/fixtures/P110.smart_1.0_1.3.0.json diff --git a/devtools/dump_devinfo.py b/devtools/dump_devinfo.py index 76d750dcf..777ee1050 100644 --- a/devtools/dump_devinfo.py +++ b/devtools/dump_devinfo.py @@ -189,12 +189,16 @@ async def get_legacy_fixture(device): async def get_smart_fixture(device: SmartDevice): """Get fixture for new TAPO style protocol.""" items = [ + Call(module="component_nego", method="component_nego"), + Call(module="device_info", method="get_device_info"), + Call(module="device_usage", method="get_device_usage"), + Call(module="device_time", method="get_device_time"), + Call(module="energy_usage", method="get_energy_usage"), + Call(module="current_power", method="get_current_power"), Call( module="child_device_component_list", method="get_child_device_component_list", ), - Call(module="device_info", method="get_device_info"), - Call(module="device_usage", method="get_device_usage"), ] successes = [] @@ -246,7 +250,7 @@ async def get_smart_fixture(device: SmartDevice): model = final["get_device_info"]["model"] sw_version = sw_version.split(" ", maxsplit=1)[0] - return f"{model}_{hw_version}_{sw_version}.json", final + return f"{model}.smart_{hw_version}_{sw_version}.json", final if __name__ == "__main__": diff --git a/kasa/__init__.py b/kasa/__init__.py index 96576062e..cb7b18f74 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -28,6 +28,7 @@ from kasa.smartdimmer import SmartDimmer from kasa.smartlightstrip import SmartLightStrip from kasa.smartplug import SmartPlug +from kasa.smartprotocol import TPLinkSmartProtocol from kasa.smartstrip import SmartStrip __version__ = version("python-kasa") @@ -38,6 +39,7 @@ "TPLinkSmartHomeProtocol", "TPLinkProtocol", "TPLinkIotProtocol", + "TPLinkSmartProtocol", "SmartBulb", "SmartBulbPreset", "TurnOnBehaviors", diff --git a/kasa/aestransport.py b/kasa/aestransport.py index b90dccd83..3e725d997 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -26,7 +26,7 @@ _LOGGER = logging.getLogger(__name__) -def _sha1_hex(payload: bytes) -> str: +def _sha1(payload: bytes) -> str: sha1_algo = hashlib.sha1() # noqa: S324 sha1_algo.update(payload) return sha1_algo.hexdigest() @@ -84,14 +84,14 @@ def hash_credentials(self, credentials, try_login_version2): """Hash the credentials.""" if try_login_version2: un = base64.b64encode( - _sha1_hex(credentials.username.encode()).encode() + _sha1(credentials.username.encode()).encode() ).decode() pw = base64.b64encode( - _sha1_hex(credentials.password.encode()).encode() + _sha1(credentials.password.encode()).encode() ).decode() else: un = base64.b64encode( - _sha1_hex(credentials.username.encode()).encode() + _sha1(credentials.username.encode()).encode() ).decode() pw = base64.b64encode(credentials.password.encode()).decode() return un, pw diff --git a/kasa/device_factory.py b/kasa/device_factory.py index 9122003c8..cb1649f12 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -2,17 +2,21 @@ import logging import time -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, Optional, Tuple, Type +from .aestransport import TPLinkAesTransport from .credentials import Credentials from .device_type import DeviceType from .exceptions import UnsupportedDeviceException -from .protocol import TPLinkProtocol +from .iotprotocol import TPLinkIotProtocol +from .klaptransport import TPLinkKlapTransport, TPlinkKlapTransportV2 +from .protocol import TPLinkProtocol, TPLinkTransport from .smartbulb import SmartBulb from .smartdevice import SmartDevice, SmartDeviceException from .smartdimmer import SmartDimmer from .smartlightstrip import SmartLightStrip from .smartplug import SmartPlug +from .smartprotocol import TPLinkSmartProtocol from .smartstrip import SmartStrip from .tapo.tapoplug import TapoPlug @@ -87,7 +91,7 @@ async def connect( if protocol_class is not None: unknown_dev.protocol = protocol_class(host, credentials=credentials) await unknown_dev.update() - device_class = get_device_class_from_info(unknown_dev.internal_state) + device_class = get_device_class_from_sys_info(unknown_dev.internal_state) dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout) # Reuse the connection from the unknown device # so we don't have to reconnect @@ -104,7 +108,7 @@ async def connect( return dev -def get_device_class_from_info(info: Dict[str, Any]) -> Type[SmartDevice]: +def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]: """Find SmartDevice subclass for device described by passed data.""" if "system" not in info or "get_sysinfo" not in info["system"]: raise SmartDeviceException("No 'system' or 'get_sysinfo' in response") @@ -129,3 +133,37 @@ def get_device_class_from_info(info: Dict[str, Any]) -> Type[SmartDevice]: return SmartBulb raise UnsupportedDeviceException("Unknown device type: %s" % type_) + + +def get_device_class_from_type_name(device_type: str) -> Optional[Type[SmartDevice]]: + """Return the device class from the type name.""" + supported_device_types: dict[str, Type[SmartDevice]] = { + "SMART.TAPOPLUG": TapoPlug, + "SMART.KASAPLUG": TapoPlug, + "IOT.SMARTPLUGSWITCH": SmartPlug, + } + return supported_device_types.get(device_type) + + +def get_protocol_from_connection_name( + connection_name: str, host: str, credentials: Optional[Credentials] = None +) -> Optional[TPLinkProtocol]: + """Return the protocol from the connection name.""" + supported_device_protocols: dict[ + str, Tuple[Type[TPLinkProtocol], Type[TPLinkTransport]] + ] = { + "IOT.KLAP": (TPLinkIotProtocol, TPLinkKlapTransport), + "SMART.AES": (TPLinkSmartProtocol, TPLinkAesTransport), + "SMART.KLAP": (TPLinkSmartProtocol, TPlinkKlapTransportV2), + } + if connection_name not in supported_device_protocols: + return None + + protocol_transport_tuple = supported_device_protocols.get(connection_name) + transport: TPLinkTransport = protocol_transport_tuple[1]( # type: ignore + host, credentials=credentials + ) + protocol: TPLinkProtocol = protocol_transport_tuple[0]( # type: ignore + host, credentials=credentials, transport=transport + ) + return protocol diff --git a/kasa/discover.py b/kasa/discover.py index d50445c9e..2038369b4 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -4,7 +4,7 @@ import ipaddress import logging import socket -from typing import Awaitable, Callable, Dict, Optional, Set, Tuple, Type, cast +from typing import Awaitable, Callable, Dict, Optional, Set, Type, cast # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout @@ -17,17 +17,16 @@ from kasa.credentials import Credentials from kasa.exceptions import UnsupportedDeviceException -from kasa.iotprotocol import TPLinkIotProtocol from kasa.json import dumps as json_dumps from kasa.json import loads as json_loads -from kasa.klaptransport import TPLinkKlapTransport, TPlinkKlapTransportV2 -from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol, TPLinkTransport +from kasa.protocol import TPLinkSmartHomeProtocol from kasa.smartdevice import SmartDevice, SmartDeviceException -from kasa.smartplug import SmartPlug -from kasa.smartprotocol import TPLinkAesTransport, TPLinkSmartProtocol -from kasa.tapo.tapoplug import TapoPlug -from .device_factory import get_device_class_from_info +from .device_factory import ( + get_device_class_from_sys_info, + get_device_class_from_type_name, + get_protocol_from_connection_name, +) _LOGGER = logging.getLogger(__name__) @@ -349,7 +348,16 @@ async def discover_single( @staticmethod def _get_device_class(info: dict) -> Type[SmartDevice]: """Find SmartDevice subclass for device described by passed data.""" - return get_device_class_from_info(info) + if "result" in info: + discovery_result = DiscoveryResult(**info["result"]) + dev_class = get_device_class_from_type_name(discovery_result.device_type) + if not dev_class: + raise UnsupportedDeviceException( + "Unknown device type: %s" % discovery_result.device_type + ) + return dev_class + else: + return get_device_class_from_sys_info(info) @staticmethod def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice: @@ -385,28 +393,16 @@ def _get_device_instance( encrypt_type_ = ( f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}" ) - device_class = None - - supported_device_types: dict[str, Type[SmartDevice]] = { - "SMART.TAPOPLUG": TapoPlug, - "SMART.KASAPLUG": TapoPlug, - "IOT.SMARTPLUGSWITCH": SmartPlug, - } - supported_device_protocols: dict[ - str, Tuple[Type[TPLinkProtocol], Type[TPLinkTransport]] - ] = { - "IOT.KLAP": (TPLinkIotProtocol, TPLinkKlapTransport), - "SMART.AES": (TPLinkSmartProtocol, TPLinkAesTransport), - "SMART.KLAP": (TPLinkSmartProtocol, TPlinkKlapTransportV2), - } - - if (device_class := supported_device_types.get(type_)) is None: + + if (device_class := get_device_class_from_type_name(type_)) is None: _LOGGER.warning("Got unsupported device type: %s", type_) raise UnsupportedDeviceException( f"Unsupported device {ip} of type {type_}: {info}" ) if ( - protocol_transport_tuple := supported_device_protocols.get(encrypt_type_) + protocol := get_protocol_from_connection_name( + encrypt_type_, ip, credentials=credentials + ) ) is None: _LOGGER.warning("Got unsupported device type: %s", encrypt_type_) raise UnsupportedDeviceException( @@ -415,10 +411,7 @@ def _get_device_instance( _LOGGER.debug("[DISCOVERY] %s << %s", ip, info) device = device_class(ip, port=port, credentials=credentials) - transport = protocol_transport_tuple[1](ip, credentials=credentials) - device.protocol = protocol_transport_tuple[0]( - ip, credentials=credentials, transport=transport - ) + device.protocol = protocol device.update_from_discover_info(discovery_result.get_dict()) return device diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index af6a2c7f0..342d1c4a6 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -365,6 +365,7 @@ async def _modular_update(self, req: dict) -> None: def update_from_discover_info(self, info: Dict[str, Any]) -> None: """Update state from info from the discover call.""" + self._discovery_info = info if "system" in info and (sys_info := info["system"].get("get_sysinfo")): self._last_update = info self._set_sys_info(sys_info) @@ -372,7 +373,6 @@ def update_from_discover_info(self, info: Dict[str, Any]) -> None: # This allows setting of some info properties directly # from partial discovery info that will then be found # by the requires_update decorator - self._discovery_info = info self._set_sys_info(info) def _set_sys_info(self, sys_info: Dict[str, Any]) -> None: diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 2b2adc7dd..387cc116d 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -2,23 +2,28 @@ import glob import json import os +from dataclasses import dataclass +from json import dumps as json_dumps from os.path import basename from pathlib import Path, PurePath -from typing import Dict +from typing import Dict, Optional from unittest.mock import MagicMock import pytest # type: ignore # see https://github.com/pytest-dev/pytest/issues/3342 from kasa import ( + Credentials, Discover, SmartBulb, SmartDimmer, SmartLightStrip, SmartPlug, SmartStrip, + TPLinkSmartHomeProtocol, ) +from kasa.tapo import TapoPlug -from .newfakes import FakeTransportProtocol +from .newfakes import FakeSmartProtocol, FakeTransportProtocol SUPPORTED_DEVICES = glob.glob( os.path.dirname(os.path.abspath(__file__)) + "/fixtures/*.json" @@ -55,22 +60,31 @@ "KP401", "KS200M", } + STRIPS = {"HS107", "HS300", "KP303", "KP200", "KP400", "EP40"} DIMMERS = {"ES20M", "HS220", "KS220M", "KS230", "KP405"} DIMMABLE = {*BULBS, *DIMMERS} WITH_EMETER = {"HS110", "HS300", "KP115", "KP125", *BULBS} -ALL_DEVICES = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS) +ALL_DEVICES_IOT = BULBS.union(PLUGS).union(STRIPS).union(DIMMERS) + +PLUGS_SMART = {"P110"} +ALL_DEVICES_SMART = PLUGS_SMART + +ALL_DEVICES = ALL_DEVICES_IOT.union(ALL_DEVICES_SMART) IP_MODEL_CACHE: Dict[str, str] = {} -def filter_model(desc, filter): +def filter_model(desc, filter, is_smart_protocol=False): filtered = list() for dev in SUPPORTED_DEVICES: for filt in filter: - if filt in basename(dev): + model_name = filt + if is_smart_protocol: + model_name = model_name + ".smart" + if model_name in basename(dev).split("_")[0]: filtered.append(dev) filtered_basenames = [basename(f) for f in filtered] @@ -78,14 +92,14 @@ def filter_model(desc, filter): return filtered -def parametrize(desc, devices, ids=None): +def parametrize(desc, devices, ids=None, is_smart_protocol=False): return pytest.mark.parametrize( "dev", filter_model(desc, devices), indirect=True, ids=ids ) has_emeter = parametrize("has emeter", WITH_EMETER) -no_emeter = parametrize("no emeter", ALL_DEVICES - WITH_EMETER) +no_emeter = parametrize("no emeter", ALL_DEVICES_IOT - WITH_EMETER) bulb = parametrize("bulbs", BULBS, ids=basename) plug = parametrize("plugs", PLUGS, ids=basename) @@ -101,6 +115,55 @@ def parametrize(desc, devices, ids=None): color_bulb = parametrize("color bulbs", COLOR_BULBS) non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS) +plug_smart = parametrize( + "plug devices smart", PLUGS_SMART, ids=basename, is_smart_protocol=True +) +device_smart = parametrize( + "devices smart", ALL_DEVICES_SMART, ids=basename, is_smart_protocol=True +) +device_iot = parametrize( + "devices iot", ALL_DEVICES_IOT, ids=basename, is_smart_protocol=False +) + + +def get_fixture_data(): + """Return raw discovery file contents as JSON. Used for discovery tests.""" + fixture_data = {} + for file in SUPPORTED_DEVICES: + p = Path(file) + if not p.is_absolute(): + p = Path(__file__).parent / "fixtures" / file + + with open(p) as f: + fixture_data[basename(p)] = json.load(f) + return fixture_data + + +FIXTURE_DATA = get_fixture_data() + + +def filter_fixtures(desc, root_filter): + filtered = {} + for key, val in FIXTURE_DATA.items(): + if root_filter in val: + filtered[key] = val + + print(f"{desc}: {filtered.keys()}") + return filtered + + +def parametrize_discovery(desc, root_key): + filtered_fixtures = filter_fixtures(desc, root_key) + return pytest.mark.parametrize( + "discovery_data", + filtered_fixtures.values(), + indirect=True, + ids=filtered_fixtures.keys(), + ) + + +new_discovery = parametrize_discovery("new discovery", "discovery_result") + def check_categories(): """Check that every fixture file is categorized.""" @@ -110,6 +173,7 @@ def check_categories(): + plug.args[1] + bulb.args[1] + lightstrip.args[1] + + plug_smart.args[1] ) diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures) if diff: @@ -118,7 +182,7 @@ def check_categories(): "No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)" % file ) - raise Exception("Missing category for %s" % diff) + raise Exception(f"Missing category for {diff}") check_categories() @@ -156,6 +220,10 @@ def device_for_file(model): if d in model: return SmartDimmer + for d in PLUGS_SMART: + if d + ".smart" in model: + return TapoPlug + raise Exception("Unable to find type for %s", model) @@ -185,7 +253,11 @@ def load_file(): model = basename(file) d = device_for_file(model)(host="127.0.0.123") - d.protocol = FakeTransportProtocol(sysinfo) + if ".smart" in model.split("_")[0]: + d.protocol = FakeSmartProtocol(sysinfo) + d.credentials = Credentials("", "") + else: + d.protocol = FakeTransportProtocol(sysinfo) await _update_and_close(d) return d @@ -213,16 +285,59 @@ async def dev(request): return await get_device_for_file(file) -@pytest.fixture(params=SUPPORTED_DEVICES, scope="session") +@pytest.fixture +def discovery_mock(discovery_data, mocker): + @dataclass + class _DiscoveryMock: + ip: str + default_port: int + discovery_data: dict + port_override: Optional[int] = None + + if "result" in discovery_data: + datagram = ( + b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" + + json_dumps(discovery_data).encode() + ) + dm = _DiscoveryMock("127.0.0.123", 20002, discovery_data) + else: + datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:] + dm = _DiscoveryMock("127.0.0.123", 9999, discovery_data) + + def mock_discover(self): + port = ( + dm.port_override + if dm.port_override and dm.default_port != 20002 + else dm.default_port + ) + self.datagram_received( + datagram, + (dm.ip, port), + ) + + mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover) + mocker.patch( + "socket.getaddrinfo", + side_effect=lambda *_, **__: [(None, None, None, None, (dm.ip, 0))], + ) + yield dm + + +@pytest.fixture(params=FIXTURE_DATA.values(), ids=FIXTURE_DATA.keys(), scope="session") def discovery_data(request): """Return raw discovery file contents as JSON. Used for discovery tests.""" - file = request.param - p = Path(file) - if not p.is_absolute(): - p = Path(__file__).parent / "fixtures" / file + fixture_data = request.param + if "discovery_result" in fixture_data: + return {"result": fixture_data["discovery_result"]} + else: + return {"system": {"get_sysinfo": fixture_data["system"]["get_sysinfo"]}} + - with open(p) as f: - return json.load(f) +@pytest.fixture(params=FIXTURE_DATA.values(), ids=FIXTURE_DATA.keys(), scope="session") +def all_fixture_data(request): + """Return raw fixture file contents as JSON. Used for discovery tests.""" + fixture_data = request.param + return fixture_data def pytest_addoption(parser): diff --git a/kasa/tests/fixtures/HS100(UK)_1.0_1.2.6.json b/kasa/tests/fixtures/HS100(UK)_1.0_1.2.6.json deleted file mode 100644 index 787f367e2..000000000 --- a/kasa/tests/fixtures/HS100(UK)_1.0_1.2.6.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "system": { - "get_sysinfo": { - "active_mode": "schedule", - "alias": "Living Room Lamp", - "dev_name": "Wi-Fi Smart Plug", - "deviceId": "0000000000000000000000000000000000000000", - "err_code": 0, - "feature": "TIM", - "fwId": "00000000000000000000000000000000", - "hwId": "00000000000000000000000000000000", - "hw_ver": "1.0", - "icon_hash": "", - "latitude": 0, - "led_off": 0, - "longitude": 0, - "mac": "00:00:00:00:00:00", - "model": "HS100(UK)", - "oemId": "00000000000000000000000000000000", - "on_time": 4102, - "relay_state": 1, - "rssi": -58, - "sw_ver": "1.2.6 Build 200727 Rel.120236", - "type": "IOT.SMARTPLUGSWITCH", - "updating": 0 - } - } -} diff --git a/kasa/tests/fixtures/P110.smart_1.0_1.3.0.json b/kasa/tests/fixtures/P110.smart_1.0_1.3.0.json new file mode 100644 index 000000000..99fd3f133 --- /dev/null +++ b/kasa/tests/fixtures/P110.smart_1.0_1.3.0.json @@ -0,0 +1,180 @@ +{ + "component_nego": { + "component_list": [ + { + "id": "device", + "ver_code": 2 + }, + { + "id": "firmware", + "ver_code": 2 + }, + { + "id": "quick_setup", + "ver_code": 3 + }, + { + "id": "time", + "ver_code": 1 + }, + { + "id": "wireless", + "ver_code": 1 + }, + { + "id": "schedule", + "ver_code": 2 + }, + { + "id": "countdown", + "ver_code": 2 + }, + { + "id": "antitheft", + "ver_code": 1 + }, + { + "id": "account", + "ver_code": 1 + }, + { + "id": "synchronize", + "ver_code": 1 + }, + { + "id": "sunrise_sunset", + "ver_code": 1 + }, + { + "id": "led", + "ver_code": 1 + }, + { + "id": "cloud_connect", + "ver_code": 1 + }, + { + "id": "iot_cloud", + "ver_code": 1 + }, + { + "id": "device_local_time", + "ver_code": 1 + }, + { + "id": "default_states", + "ver_code": 1 + }, + { + "id": "auto_off", + "ver_code": 2 + }, + { + "id": "localSmart", + "ver_code": 1 + }, + { + "id": "energy_monitoring", + "ver_code": 2 + }, + { + "id": "power_protection", + "ver_code": 1 + }, + { + "id": "current_protection", + "ver_code": 1 + } + ] + }, + "discovery_result": { + "device_id": "00000000000000000000000000000000", + "device_model": "P110(UK)", + "device_type": "SMART.TAPOPLUG", + "factory_default": false, + "ip": "127.0.0.123", + "is_support_iot_cloud": true, + "mac": "00-00-00-00-00-00", + "mgt_encrypt_schm": { + "encrypt_type": "KLAP", + "http_port": 80, + "is_support_https": false, + "lv": 2 + }, + "obd_src": "tplink", + "owner": "00000000000000000000000000000000" + }, + "get_current_power": { + "current_power": 0 + }, + "get_device_info": { + "auto_off_remain_time": 0, + "auto_off_status": "off", + "avatar": "plug", + "default_states": { + "state": {}, + "type": "last_states" + }, + "device_id": "0000000000000000000000000000000000000000", + "device_on": true, + "fw_id": "00000000000000000000000000000000", + "fw_ver": "1.3.0 Build 230905 Rel.152200", + "has_set_location_info": true, + "hw_id": "00000000000000000000000000000000", + "hw_ver": "1.0", + "ip": "127.0.0.123", + "lang": "en_US", + "latitude": 0, + "longitude": 0, + "mac": "00-00-00-00-00-00", + "model": "P110", + "nickname": "VGFwaSBTbWFydCBQbHVnIDE=", + "oem_id": "00000000000000000000000000000000", + "on_time": 119335, + "overcurrent_status": "normal", + "overheated": false, + "power_protection_status": "normal", + "region": "Europe/London", + "rssi": -57, + "signal_level": 2, + "specs": "", + "ssid": "IyNNQVNLRUROQU1FIyM=", + "time_diff": 0, + "type": "SMART.TAPOPLUG" + }, + "get_device_time": { + "region": "Europe/London", + "time_diff": 0, + "timestamp": 1701370224 + }, + "get_device_usage": { + "power_usage": { + "past30": 75, + "past7": 69, + "today": 0 + }, + "saved_power": { + "past30": 2029, + "past7": 1964, + "today": 1130 + }, + "time_usage": { + "past30": 2104, + "past7": 2033, + "today": 1130 + } + }, + "get_energy_usage": { + "current_power": 0, + "electricity_charge": [ + 0, + 0, + 0 + ], + "local_time": "2023-11-30 18:50:24", + "month_energy": 75, + "month_runtime": 2104, + "today_energy": 0, + "today_runtime": 1130 + } +} diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index ee679cae8..ceb2f5cbb 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -1,6 +1,7 @@ import copy import logging import re +from json import loads as json_loads from voluptuous import ( REMOVE_EXTRA, @@ -13,7 +14,8 @@ Schema, ) -from ..protocol import TPLinkSmartHomeProtocol +from ..protocol import TPLinkSmartHomeProtocol, TPLinkTransport +from ..smartprotocol import TPLinkSmartProtocol _LOGGER = logging.getLogger(__name__) @@ -285,6 +287,39 @@ def success(res): } +class FakeSmartProtocol(TPLinkSmartProtocol): + def __init__(self, info): + super().__init__("127.0.0.123", transport=FakeSmartTransport(info)) + + +class FakeSmartTransport(TPLinkTransport): + def __init__(self, info): + self.info = info + + def needs_handshake(self) -> bool: + return False + + def needs_login(self) -> bool: + return False + + async def login(self, request: str) -> None: + pass + + async def handshake(self) -> None: + pass + + async def send(self, request: str): + request_dict = json_loads(request) + method = request_dict["method"] + if method == "component_nego" or method[:4] == "get_": + return self.info[method] + elif method[:4] == "set_": + pass + + async def close(self) -> None: + pass + + class FakeTransportProtocol(TPLinkSmartHomeProtocol): def __init__(self, info): self.discovery_data = info diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index f590808f8..514a93ab0 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -6,12 +6,15 @@ from kasa import SmartDevice, TPLinkSmartHomeProtocol from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle +from kasa.device_factory import DEVICE_TYPE_TO_CLASS from kasa.discover import Discover +from kasa.smartprotocol import TPLinkSmartProtocol -from .conftest import handle_turn_on, turn_on -from .newfakes import FakeTransportProtocol +from .conftest import device_iot, handle_turn_on, new_discovery, turn_on +from .newfakes import FakeSmartProtocol, FakeTransportProtocol +@device_iot async def test_sysinfo(dev): runner = CliRunner() res = await runner.invoke(sysinfo, obj=dev) @@ -19,6 +22,7 @@ async def test_sysinfo(dev): assert dev.alias in res.output +@device_iot @turn_on async def test_state(dev, turn_on): await handle_turn_on(dev, turn_on) @@ -32,6 +36,7 @@ async def test_state(dev, turn_on): assert "Device state: False" in res.output +@device_iot @turn_on async def test_toggle(dev, turn_on, mocker): await handle_turn_on(dev, turn_on) @@ -44,6 +49,7 @@ async def test_toggle(dev, turn_on, mocker): assert dev.is_on +@device_iot async def test_alias(dev): runner = CliRunner() @@ -62,6 +68,7 @@ async def test_alias(dev): await dev.set_alias(old_alias) +@device_iot async def test_raw_command(dev): runner = CliRunner() res = await runner.invoke(raw_command, ["system", "get_sysinfo"], obj=dev) @@ -74,6 +81,7 @@ async def test_raw_command(dev): assert "Usage" in res.output +@device_iot async def test_emeter(dev: SmartDevice, mocker): runner = CliRunner() @@ -99,6 +107,7 @@ async def test_emeter(dev: SmartDevice, mocker): daily.assert_called_with(year=1900, month=12) +@device_iot async def test_brightness(dev): runner = CliRunner() res = await runner.invoke(brightness, obj=dev) @@ -116,6 +125,7 @@ async def test_brightness(dev): assert "Brightness: 12" in res.output +@device_iot async def test_json_output(dev: SmartDevice, mocker): """Test that the json output produces correct output.""" mocker.patch("kasa.Discover.discover", return_value=[dev]) @@ -125,13 +135,9 @@ async def test_json_output(dev: SmartDevice, mocker): assert json.loads(res.output) == dev.internal_state -async def test_credentials(discovery_data: dict, mocker): +@new_discovery +async def test_credentials(discovery_mock, mocker): """Test credentials are passed correctly from cli to device.""" - # As this is testing the device constructor need to explicitly wire in - # the FakeTransportProtocol - ftp = FakeTransportProtocol(discovery_data) - mocker.patch.object(TPLinkSmartHomeProtocol, "query", ftp.query) - # Patch state to echo username and password pass_dev = click.make_pass_decorator(SmartDevice) @@ -143,18 +149,15 @@ async def _state(dev: SmartDevice): ) mocker.patch("kasa.cli.state", new=_state) - cli_device_type = Discover._get_device_class(discovery_data)( - "any" - ).device_type.value + for subclass in DEVICE_TYPE_TO_CLASS.values(): + mocker.patch.object(subclass, "update") runner = CliRunner() res = await runner.invoke( cli, [ "--host", - "127.0.0.1", - "--type", - cli_device_type, + "127.0.0.123", "--username", "foo", "--password", @@ -162,9 +165,11 @@ async def _state(dev: SmartDevice): ], ) assert res.exit_code == 0 - assert res.output == "Username:foo Password:bar\n" + + assert "Username:foo Password:bar\n" in res.output +@device_iot async def test_without_device_type(discovery_data: dict, dev, mocker): """Test connecting without the device type.""" runner = CliRunner() diff --git a/kasa/tests/test_device_factory.py b/kasa/tests/test_device_factory.py index 991e75bdc..4009f247a 100644 --- a/kasa/tests/test_device_factory.py +++ b/kasa/tests/test_device_factory.py @@ -5,7 +5,9 @@ import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 from kasa import ( + Credentials, DeviceType, + Discover, SmartBulb, SmartDevice, SmartDeviceException, @@ -13,7 +15,12 @@ SmartLightStrip, SmartPlug, ) -from kasa.device_factory import connect +from kasa.device_factory import ( + DEVICE_TYPE_TO_CLASS, + connect, + get_protocol_from_connection_name, +) +from kasa.discover import DiscoveryResult from kasa.iotprotocol import TPLinkIotProtocol from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol @@ -22,11 +29,15 @@ async def test_connect(discovery_data: dict, mocker, custom_port): """Make sure that connect returns an initialized SmartDevice instance.""" host = "127.0.0.1" - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - dev = await connect(host, port=custom_port) - assert issubclass(dev.__class__, SmartDevice) - assert dev.port == custom_port or dev.port == 9999 + if "result" in discovery_data: + with pytest.raises(SmartDeviceException): + dev = await connect(host, port=custom_port) + else: + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) + dev = await connect(host, port=custom_port) + assert issubclass(dev.__class__, SmartDevice) + assert dev.port == custom_port or dev.port == 9999 @pytest.mark.parametrize("custom_port", [123, None]) @@ -49,11 +60,15 @@ async def test_connect_passed_device_type( ): """Make sure that connect with a passed device type.""" host = "127.0.0.1" - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - dev = await connect(host, port=custom_port, device_type=device_type) - assert isinstance(dev, klass) - assert dev.port == custom_port or dev.port == 9999 + if "result" in discovery_data: + with pytest.raises(SmartDeviceException): + dev = await connect(host, port=custom_port) + else: + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) + dev = await connect(host, port=custom_port, device_type=device_type) + assert isinstance(dev, klass) + assert dev.port == custom_port or dev.port == 9999 async def test_connect_query_fails(discovery_data: dict, mocker): @@ -70,32 +85,52 @@ async def test_connect_logs_connect_time( ): """Test that the connect time is logged when debug logging is enabled.""" host = "127.0.0.1" - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - logging.getLogger("kasa").setLevel(logging.DEBUG) - await connect(host) - assert "seconds to connect" in caplog.text + if "result" in discovery_data: + with pytest.raises(SmartDeviceException): + await connect(host) + else: + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) + logging.getLogger("kasa").setLevel(logging.DEBUG) + await connect(host) + assert "seconds to connect" in caplog.text -@pytest.mark.parametrize("device_type", [DeviceType.Plug, None]) -@pytest.mark.parametrize( - ("protocol_in", "protocol_result"), - ( - (None, TPLinkSmartHomeProtocol), - (TPLinkIotProtocol, TPLinkIotProtocol), - (TPLinkSmartHomeProtocol, TPLinkSmartHomeProtocol), - ), -) async def test_connect_pass_protocol( - discovery_data: dict, + all_fixture_data: dict, mocker, - device_type: DeviceType, - protocol_in: Type[TPLinkProtocol], - protocol_result: Type[TPLinkProtocol], ): """Test that if the protocol is passed in it's gets set correctly.""" + if "discovery_result" in all_fixture_data: + discovery_info = {"result": all_fixture_data["discovery_result"]} + device_class = Discover._get_device_class(discovery_info) + else: + device_class = Discover._get_device_class(all_fixture_data) + + device_type = list(DEVICE_TYPE_TO_CLASS.keys())[ + list(DEVICE_TYPE_TO_CLASS.values()).index(device_class) + ] host = "127.0.0.1" - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - mocker.patch("kasa.TPLinkIotProtocol.query", return_value=discovery_data) - - dev = await connect(host, device_type=device_type, protocol_class=protocol_in) - assert isinstance(dev.protocol, protocol_result) + if "discovery_result" in all_fixture_data: + mocker.patch("kasa.TPLinkIotProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.TPLinkSmartProtocol.query", return_value=all_fixture_data) + + dr = DiscoveryResult(**discovery_info["result"]) + connection_name = ( + dr.device_type.split(".")[0] + "." + dr.mgt_encrypt_schm.encrypt_type + ) + protocol_class = get_protocol_from_connection_name( + connection_name, host + ).__class__ + else: + mocker.patch( + "kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data + ) + protocol_class = TPLinkSmartHomeProtocol + + dev = await connect( + host, + device_type=device_type, + protocol_class=protocol_class, + credentials=Credentials("", ""), + ) + assert isinstance(dev.protocol, protocol_class) diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 626afd180..ea97d94ad 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -17,6 +17,27 @@ from .conftest import bulb, dimmer, lightstrip, plug, strip +UNSUPPORTED = { + "result": { + "device_id": "xx", + "owner": "xx", + "device_type": "SMART.TAPOXMASTREE", + "device_model": "P110(EU)", + "ip": "127.0.0.1", + "mac": "48-22xxx", + "is_support_iot_cloud": True, + "obd_src": "tplink", + "factory_default": False, + "mgt_encrypt_schm": { + "is_support_https": False, + "encrypt_type": "AES", + "http_port": 80, + "lv": 2, + }, + }, + "error_code": 0, +} + @plug async def test_type_detection_plug(dev: SmartDevice): @@ -62,76 +83,40 @@ async def test_type_unknown(): @pytest.mark.parametrize("custom_port", [123, None]) -async def test_discover_single(discovery_data: dict, mocker, custom_port): +# @pytest.mark.parametrize("discovery_mock", [("127.0.0.1",123), ("127.0.0.1",None)], indirect=True) +async def test_discover_single(discovery_mock, custom_port, mocker): """Make sure that discover_single returns an initialized SmartDevice instance.""" host = "127.0.0.1" - info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}} - query_mock = mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info) - - def mock_discover(self): - self.datagram_received( - protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(info))[4:], - (host, custom_port or 9999), - ) - - mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover) + discovery_mock.ip = host + discovery_mock.port_override = custom_port + update_mock = mocker.patch.object(SmartStrip, "update") x = await Discover.discover_single(host, port=custom_port) assert issubclass(x.__class__, SmartDevice) - assert x._sys_info is not None - assert x.port == custom_port or x.port == 9999 - assert (query_mock.call_count > 0) == isinstance(x, SmartStrip) + assert x._discovery_info is not None + assert x.port == custom_port or x.port == discovery_mock.default_port + assert (update_mock.call_count > 0) == isinstance(x, SmartStrip) -async def test_discover_single_hostname(discovery_data: dict, mocker): +async def test_discover_single_hostname(discovery_mock, mocker): """Make sure that discover_single returns an initialized SmartDevice instance.""" host = "foobar" ip = "127.0.0.1" - info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}} - query_mock = mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=info) - - def mock_discover(self): - self.datagram_received( - protocol.TPLinkSmartHomeProtocol.encrypt(json_dumps(info))[4:], - (ip, 9999), - ) - mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover) - mocker.patch("socket.getaddrinfo", return_value=[(None, None, None, None, (ip, 0))]) + discovery_mock.ip = ip + update_mock = mocker.patch.object(SmartStrip, "update") x = await Discover.discover_single(host) assert issubclass(x.__class__, SmartDevice) - assert x._sys_info is not None + assert x._discovery_info is not None assert x.host == host - assert (query_mock.call_count > 0) == isinstance(x, SmartStrip) + assert (update_mock.call_count > 0) == isinstance(x, SmartStrip) mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror()) with pytest.raises(SmartDeviceException): x = await Discover.discover_single(host) -UNSUPPORTED = { - "result": { - "device_id": "xx", - "owner": "xx", - "device_type": "SMART.TAPOXMASTREE", - "device_model": "P110(EU)", - "ip": "127.0.0.1", - "mac": "48-22xxx", - "is_support_iot_cloud": True, - "obd_src": "tplink", - "factory_default": False, - "mgt_encrypt_schm": { - "is_support_https": False, - "encrypt_type": "AES", - "http_port": 80, - "lv": 2, - }, - }, - "error_code": 0, -} - - async def test_discover_single_unsupported(mocker): """Make sure that discover_single handles unsupported devices correctly.""" host = "127.0.0.1" @@ -201,14 +186,17 @@ async def test_discover_send(mocker): async def test_discover_datagram_received(mocker, discovery_data): """Verify that datagram received fills discovered_devices.""" proto = _DiscoverProtocol() - info = {"system": {"get_sysinfo": discovery_data["system"]["get_sysinfo"]}} - mocker.patch("kasa.discover.json_loads", return_value=info) - mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "encrypt") + mocker.patch.object(protocol.TPLinkSmartHomeProtocol, "decrypt") addr = "127.0.0.1" - proto.datagram_received("", (addr, 9999)) + port = 20002 if "result" in discovery_data else 9999 + + mocker.patch("kasa.discover.json_loads", return_value=discovery_data) + proto.datagram_received("", (addr, port)) + addr2 = "127.0.0.2" + mocker.patch("kasa.discover.json_loads", return_value=UNSUPPORTED) proto.datagram_received("", (addr2, 20002)) # Check that device in discovered_devices is initialized correctly diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index 2d3eb6c2a..8dd49013b 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -10,10 +10,14 @@ import httpx import pytest +from ..aestransport import TPLinkAesTransport from ..credentials import Credentials from ..exceptions import AuthenticationException, SmartDeviceException from ..iotprotocol import TPLinkIotProtocol from ..klaptransport import KlapEncryptionSession, TPLinkKlapTransport, _sha256 +from ..smartprotocol import TPLinkSmartProtocol + +DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} class _mock_response: @@ -22,67 +26,89 @@ def __init__(self, status_code, content: bytes): self.content = content +@pytest.mark.parametrize("transport_class", [TPLinkAesTransport, TPLinkKlapTransport]) +@pytest.mark.parametrize("protocol_class", [TPLinkIotProtocol, TPLinkSmartProtocol]) @pytest.mark.parametrize("retry_count", [1, 3, 5]) -async def test_protocol_retries(mocker, retry_count): +async def test_protocol_retries(mocker, retry_count, protocol_class, transport_class): + host = "127.0.0.1" conn = mocker.patch.object( - TPLinkKlapTransport, "client_post", side_effect=Exception("dummy exception") + transport_class, "client_post", side_effect=Exception("dummy exception") ) with pytest.raises(SmartDeviceException): - await TPLinkIotProtocol("127.0.0.1").query({}, retry_count=retry_count) + await protocol_class(host, transport=transport_class(host)).query( + DUMMY_QUERY, retry_count=retry_count + ) assert conn.call_count == retry_count + 1 -async def test_protocol_no_retry_on_connection_error(mocker): +@pytest.mark.parametrize("transport_class", [TPLinkAesTransport, TPLinkKlapTransport]) +@pytest.mark.parametrize("protocol_class", [TPLinkIotProtocol, TPLinkSmartProtocol]) +async def test_protocol_no_retry_on_connection_error( + mocker, protocol_class, transport_class +): + host = "127.0.0.1" conn = mocker.patch.object( - TPLinkKlapTransport, + transport_class, "client_post", side_effect=httpx.ConnectError("foo"), ) with pytest.raises(SmartDeviceException): - await TPLinkIotProtocol("127.0.0.1").query({}, retry_count=5) + await protocol_class(host, transport=transport_class(host)).query( + DUMMY_QUERY, retry_count=5 + ) assert conn.call_count == 1 -async def test_protocol_retry_recoverable_error(mocker): +@pytest.mark.parametrize("transport_class", [TPLinkAesTransport, TPLinkKlapTransport]) +@pytest.mark.parametrize("protocol_class", [TPLinkIotProtocol, TPLinkSmartProtocol]) +async def test_protocol_retry_recoverable_error( + mocker, protocol_class, transport_class +): + host = "127.0.0.1" conn = mocker.patch.object( - TPLinkKlapTransport, + transport_class, "client_post", side_effect=httpx.CloseError("foo"), ) with pytest.raises(SmartDeviceException): - await TPLinkIotProtocol("127.0.0.1").query({}, retry_count=5) + await protocol_class(host, transport=transport_class(host)).query( + DUMMY_QUERY, retry_count=5 + ) assert conn.call_count == 6 +@pytest.mark.parametrize("transport_class", [TPLinkAesTransport, TPLinkKlapTransport]) +@pytest.mark.parametrize("protocol_class", [TPLinkIotProtocol, TPLinkSmartProtocol]) @pytest.mark.parametrize("retry_count", [1, 3, 5]) -async def test_protocol_reconnect(mocker, retry_count): +async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport_class): + host = "127.0.0.1" remaining = retry_count + mock_response = {"result": {"great": "success"}} def _fail_one_less_than_retry_count(*_, **__): - nonlocal remaining, encryption_session + nonlocal remaining remaining -= 1 if remaining: raise Exception("Simulated post failure") - # Do the encrypt just before returning the value so the incrementing sequence number is correct - encrypted, seq = encryption_session.encrypt('{"great":"success"}') - return 200, encrypted - seed = secrets.token_bytes(16) - auth_hash = TPLinkKlapTransport.generate_auth_hash(Credentials("foo", "bar")) - encryption_session = KlapEncryptionSession(seed, seed, auth_hash) - protocol = TPLinkIotProtocol("127.0.0.1") - protocol.transport.handshake_done = True - protocol.transport.session_expire_at = time.time() + 86400 - protocol.transport.encryption_session = encryption_session - mocker.patch.object( - TPLinkKlapTransport, "client_post", side_effect=_fail_one_less_than_retry_count + return mock_response + + mocker.patch.object(transport_class, "needs_handshake", return_value=False) + mocker.patch.object(transport_class, "needs_login", return_value=False) + send_mock = mocker.patch.object( + transport_class, + "send", + side_effect=_fail_one_less_than_retry_count, ) - response = await protocol.query({}, retry_count=retry_count) - assert response == {"great": "success"} + response = await protocol_class(host, transport=transport_class(host)).query( + DUMMY_QUERY, retry_count=retry_count + ) + assert "result" in response or "great" in response + assert send_mock.call_count == retry_count @pytest.mark.parametrize("log_level", [logging.WARNING, logging.DEBUG]) diff --git a/kasa/tests/test_plug.py b/kasa/tests/test_plug.py index e97043101..e9e1592f9 100644 --- a/kasa/tests/test_plug.py +++ b/kasa/tests/test_plug.py @@ -1,6 +1,6 @@ from kasa import DeviceType -from .conftest import plug +from .conftest import plug, plug_smart from .newfakes import PLUG_SCHEMA @@ -28,3 +28,14 @@ async def test_led(dev): assert dev.led await dev.set_led(original) + + +@plug_smart +async def test_plug_device_info(dev): + assert dev._info is not None + # PLUG_SCHEMA(dev.sys_info) + + assert dev.model is not None + + assert dev.device_type == DeviceType.Plug or dev.device_type == DeviceType.Strip + # assert dev.is_plug or dev.is_strip diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 85dc358df..33c9f4483 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -8,7 +8,7 @@ from kasa import Credentials, SmartDevice, SmartDeviceException from kasa.smartdevice import DeviceType -from .conftest import handle_turn_on, has_emeter, no_emeter, turn_on +from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter, turn_on from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol # List of all SmartXXX classes including the SmartDevice base class @@ -22,11 +22,13 @@ ] +@device_iot async def test_state_info(dev): assert isinstance(dev.state_information, dict) @pytest.mark.requires_dummy +@device_iot async def test_invalid_connection(dev): with patch.object( FakeTransportProtocol, "query", side_effect=SmartDeviceException @@ -58,12 +60,14 @@ async def test_initial_update_no_emeter(dev, mocker): assert spy.call_count == 2 +@device_iot async def test_query_helper(dev): with pytest.raises(SmartDeviceException): await dev._query_helper("test", "testcmd", {}) # TODO check for unwrapping? +@device_iot @turn_on async def test_state(dev, turn_on): await handle_turn_on(dev, turn_on) @@ -90,6 +94,7 @@ async def test_state(dev, turn_on): assert dev.is_off +@device_iot async def test_alias(dev): test_alias = "TEST1234" original = dev.alias @@ -104,6 +109,7 @@ async def test_alias(dev): assert dev.alias == original +@device_iot @turn_on async def test_on_since(dev, turn_on): await handle_turn_on(dev, turn_on) @@ -116,30 +122,37 @@ async def test_on_since(dev, turn_on): assert dev.on_since is None +@device_iot async def test_time(dev): assert isinstance(await dev.get_time(), datetime) +@device_iot async def test_timezone(dev): TZ_SCHEMA(await dev.get_timezone()) +@device_iot async def test_hw_info(dev): PLUG_SCHEMA(dev.hw_info) +@device_iot async def test_location(dev): PLUG_SCHEMA(dev.location) +@device_iot async def test_rssi(dev): PLUG_SCHEMA({"rssi": dev.rssi}) # wrapping for vol +@device_iot async def test_mac(dev): PLUG_SCHEMA({"mac": dev.mac}) # wrapping for val +@device_iot async def test_representation(dev): import re @@ -147,6 +160,7 @@ async def test_representation(dev): assert pattern.match(str(dev)) +@device_iot async def test_childrens(dev): """Make sure that children property is exposed by every device.""" if dev.is_strip: @@ -155,6 +169,7 @@ async def test_childrens(dev): assert len(dev.children) == 0 +@device_iot async def test_children(dev): """Make sure that children property is exposed by every device.""" if dev.is_strip: @@ -165,11 +180,13 @@ async def test_children(dev): assert dev.has_children is False +@device_iot async def test_internal_state(dev): """Make sure the internal state returns the last update results.""" assert dev.internal_state == dev._last_update +@device_iot async def test_features(dev): """Make sure features is always accessible.""" sysinfo = dev._last_update["system"]["get_sysinfo"] @@ -179,11 +196,13 @@ async def test_features(dev): assert dev.features == set() +@device_iot async def test_max_device_response_size(dev): """Make sure every device return has a set max response size.""" assert dev.max_device_response_size > 0 +@device_iot async def test_estimated_response_sizes(dev): """Make sure every module has an estimated response size set.""" for mod in dev.modules.values(): @@ -202,6 +221,7 @@ def test_device_class_ctors(device_class): assert dev.credentials == credentials +@device_iot async def test_modules_preserved(dev: SmartDevice): """Make modules that are not being updated are preserved between updates.""" dev._last_update["some_module_not_being_updated"] = "should_be_kept" @@ -237,6 +257,7 @@ async def test_create_thin_wrapper(): ) +@device_iot async def test_modules_not_supported(dev: SmartDevice): """Test that unsupported modules do not break the device.""" for module in dev.modules.values(): From 90bf240c8a811e7c8c026f813e1bc2386af18d41 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Sun, 3 Dec 2023 18:21:10 +0000 Subject: [PATCH 3/4] Update following review --- kasa/__init__.py | 8 +- kasa/aestransport.py | 4 +- kasa/device_factory.py | 26 ++-- kasa/iotprotocol.py | 16 +- kasa/klaptransport.py | 6 +- kasa/protocol.py | 4 +- kasa/smartprotocol.py | 14 +- kasa/tapo/tapodevice.py | 6 +- kasa/tests/conftest.py | 138 +++++++++++------- .../P110_1.0_1.3.0.json} | 0 kasa/tests/newfakes.py | 8 +- kasa/tests/test_cli.py | 2 +- kasa/tests/test_device_factory.py | 6 +- kasa/tests/test_klapprotocol.py | 56 ++++--- kasa/tests/test_readme_examples.py | 14 +- 15 files changed, 161 insertions(+), 147 deletions(-) rename kasa/tests/fixtures/{P110.smart_1.0_1.3.0.json => smart/P110_1.0_1.3.0.json} (100%) diff --git a/kasa/__init__.py b/kasa/__init__.py index cb7b18f74..7de394c11 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -21,14 +21,14 @@ SmartDeviceException, UnsupportedDeviceException, ) -from kasa.iotprotocol import TPLinkIotProtocol +from kasa.iotprotocol import IotProtocol from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol from kasa.smartbulb import SmartBulb, SmartBulbPreset, TurnOnBehavior, TurnOnBehaviors from kasa.smartdevice import DeviceType, SmartDevice from kasa.smartdimmer import SmartDimmer from kasa.smartlightstrip import SmartLightStrip from kasa.smartplug import SmartPlug -from kasa.smartprotocol import TPLinkSmartProtocol +from kasa.smartprotocol import SmartProtocol from kasa.smartstrip import SmartStrip __version__ = version("python-kasa") @@ -38,8 +38,8 @@ "Discover", "TPLinkSmartHomeProtocol", "TPLinkProtocol", - "TPLinkIotProtocol", - "TPLinkSmartProtocol", + "IotProtocol", + "SmartProtocol", "SmartBulb", "SmartBulbPreset", "TurnOnBehaviors", diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 3e725d997..2fdbee5ad 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -21,7 +21,7 @@ from .exceptions import AuthenticationException, SmartDeviceException from .json import dumps as json_dumps from .json import loads as json_loads -from .protocol import TPLinkTransport +from .protocol import BaseTransport _LOGGER = logging.getLogger(__name__) @@ -32,7 +32,7 @@ def _sha1(payload: bytes) -> str: return sha1_algo.hexdigest() -class TPLinkAesTransport(TPLinkTransport): +class AesTransport(BaseTransport): """Implementation of the AES encryption protocol. AES is the name used in device discovery for TP-Link's TAPO encryption diff --git a/kasa/device_factory.py b/kasa/device_factory.py index cb1649f12..be293ee27 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -4,19 +4,19 @@ import time from typing import Any, Dict, Optional, Tuple, Type -from .aestransport import TPLinkAesTransport +from .aestransport import AesTransport from .credentials import Credentials from .device_type import DeviceType from .exceptions import UnsupportedDeviceException -from .iotprotocol import TPLinkIotProtocol -from .klaptransport import TPLinkKlapTransport, TPlinkKlapTransportV2 -from .protocol import TPLinkProtocol, TPLinkTransport +from .iotprotocol import IotProtocol +from .klaptransport import KlapTransport, TPlinkKlapTransportV2 +from .protocol import BaseTransport, TPLinkProtocol from .smartbulb import SmartBulb from .smartdevice import SmartDevice, SmartDeviceException from .smartdimmer import SmartDimmer from .smartlightstrip import SmartLightStrip from .smartplug import SmartPlug -from .smartprotocol import TPLinkSmartProtocol +from .smartprotocol import SmartProtocol from .smartstrip import SmartStrip from .tapo.tapoplug import TapoPlug @@ -150,20 +150,18 @@ def get_protocol_from_connection_name( ) -> Optional[TPLinkProtocol]: """Return the protocol from the connection name.""" supported_device_protocols: dict[ - str, Tuple[Type[TPLinkProtocol], Type[TPLinkTransport]] + str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]] ] = { - "IOT.KLAP": (TPLinkIotProtocol, TPLinkKlapTransport), - "SMART.AES": (TPLinkSmartProtocol, TPLinkAesTransport), - "SMART.KLAP": (TPLinkSmartProtocol, TPlinkKlapTransportV2), + "IOT.KLAP": (IotProtocol, KlapTransport), + "SMART.AES": (SmartProtocol, AesTransport), + "SMART.KLAP": (SmartProtocol, TPlinkKlapTransportV2), } if connection_name not in supported_device_protocols: return None - protocol_transport_tuple = supported_device_protocols.get(connection_name) - transport: TPLinkTransport = protocol_transport_tuple[1]( # type: ignore - host, credentials=credentials - ) - protocol: TPLinkProtocol = protocol_transport_tuple[0]( # type: ignore + protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore + transport: BaseTransport = transport_class(host, credentials=credentials) + protocol: TPLinkProtocol = protocol_class( host, credentials=credentials, transport=transport ) return protocol diff --git a/kasa/iotprotocol.py b/kasa/iotprotocol.py index 808594540..54ba3e510 100755 --- a/kasa/iotprotocol.py +++ b/kasa/iotprotocol.py @@ -8,13 +8,13 @@ from .credentials import Credentials from .exceptions import AuthenticationException, SmartDeviceException from .json import dumps as json_dumps -from .klaptransport import TPLinkKlapTransport -from .protocol import TPLinkProtocol, TPLinkTransport +from .klaptransport import KlapTransport +from .protocol import BaseTransport, TPLinkProtocol _LOGGER = logging.getLogger(__name__) -class TPLinkIotProtocol(TPLinkProtocol): +class IotProtocol(TPLinkProtocol): """Class for the legacy TPLink IOT KASA Protocol.""" DEFAULT_PORT = 80 @@ -23,7 +23,7 @@ def __init__( self, host: str, *, - transport: Optional[TPLinkTransport] = None, + transport: Optional[BaseTransport] = None, credentials: Optional[Credentials] = None, timeout: Optional[int] = None, ) -> None: @@ -34,12 +34,8 @@ def __init__( if credentials and credentials.username and credentials.password else Credentials(username="", password="") ) - self.transport: TPLinkTransport = ( - transport - if transport - else TPLinkKlapTransport( - host, credentials=self.credentials, timeout=timeout - ) + self.transport: BaseTransport = transport or KlapTransport( + host, credentials=self.credentials, timeout=timeout ) self.query_lock = asyncio.Lock() diff --git a/kasa/klaptransport.py b/kasa/klaptransport.py index 4ad37d906..72d49cbfe 100644 --- a/kasa/klaptransport.py +++ b/kasa/klaptransport.py @@ -56,7 +56,7 @@ from .credentials import Credentials from .exceptions import AuthenticationException, SmartDeviceException from .json import loads as json_loads -from .protocol import TPLinkTransport, md5 +from .protocol import BaseTransport, md5 _LOGGER = logging.getLogger(__name__) logging.getLogger("httpx").propagate = False @@ -75,7 +75,7 @@ def _sha1(payload: bytes) -> bytes: return digest.finalize() -class TPLinkKlapTransport(TPLinkTransport): +class KlapTransport(BaseTransport): """Implementation of the KLAP encryption protocol. KLAP is the name used in device discovery for TP-Link's new encryption @@ -408,7 +408,7 @@ def generate_owner_hash(creds: Credentials): return md5(un.encode()) -class TPlinkKlapTransportV2(TPLinkKlapTransport): +class TPlinkKlapTransportV2(KlapTransport): """Implementation of the KLAP encryption protocol with v2 hanshake hashes.""" @staticmethod diff --git a/kasa/protocol.py b/kasa/protocol.py index acd08723d..f52dbc879 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -41,7 +41,7 @@ def md5(payload: bytes) -> bytes: return hash -class TPLinkTransport(ABC): +class BaseTransport(ABC): """Base class for all TP-Link KASA-KLAP and TAPO transports.""" def __init__( @@ -90,7 +90,7 @@ def __init__( *, port: Optional[int] = None, credentials: Optional[Credentials] = None, - transport: Optional[TPLinkTransport] = None, + transport: Optional[BaseTransport] = None, ) -> None: """Create a protocol object.""" self.host = host diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 962c26ef7..2ec6df893 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -14,17 +14,17 @@ import httpx -from .aestransport import TPLinkAesTransport +from .aestransport import AesTransport from .credentials import Credentials from .exceptions import AuthenticationException, SmartDeviceException from .json import dumps as json_dumps -from .protocol import TPLinkProtocol, TPLinkTransport, md5 +from .protocol import BaseTransport, TPLinkProtocol, md5 _LOGGER = logging.getLogger(__name__) logging.getLogger("httpx").propagate = False -class TPLinkSmartProtocol(TPLinkProtocol): +class SmartProtocol(TPLinkProtocol): """Class for the new TPLink SMART protocol.""" DEFAULT_PORT = 80 @@ -33,7 +33,7 @@ def __init__( self, host: str, *, - transport: Optional[TPLinkTransport] = None, + transport: Optional[BaseTransport] = None, credentials: Optional[Credentials] = None, timeout: Optional[int] = None, ) -> None: @@ -44,10 +44,8 @@ def __init__( if credentials and credentials.username and credentials.password else Credentials(username="", password="") ) - self.transport: TPLinkTransport = ( - transport - if transport - else TPLinkAesTransport(host, credentials=self.credentials, timeout=timeout) + self.transport: BaseTransport = transport or AesTransport( + host, credentials=self.credentials, timeout=timeout ) self.terminal_uuid: Optional[str] = None self.request_id_generator = SnowflakeId(1, 1) diff --git a/kasa/tapo/tapodevice.py b/kasa/tapo/tapodevice.py index 701063739..6c643a6af 100644 --- a/kasa/tapo/tapodevice.py +++ b/kasa/tapo/tapodevice.py @@ -7,7 +7,7 @@ from ..credentials import Credentials from ..exceptions import AuthenticationException from ..smartdevice import SmartDevice -from ..smartprotocol import TPLinkSmartProtocol +from ..smartprotocol import SmartProtocol _LOGGER = logging.getLogger(__name__) @@ -26,9 +26,7 @@ def __init__( super().__init__(host, port=port, credentials=credentials, timeout=timeout) self._state_information: Dict[str, Any] = {} self._discovery_info: Optional[Dict[str, Any]] = None - self.protocol = TPLinkSmartProtocol( - host, credentials=credentials, timeout=timeout - ) + self.protocol = SmartProtocol(host, credentials=credentials, timeout=timeout) async def update(self, update_children: bool = True): """Update the device.""" diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 387cc116d..50d2f0def 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -21,13 +21,26 @@ SmartStrip, TPLinkSmartHomeProtocol, ) -from kasa.tapo import TapoPlug +from kasa.tapo import TapoDevice, TapoPlug from .newfakes import FakeSmartProtocol, FakeTransportProtocol -SUPPORTED_DEVICES = glob.glob( - os.path.dirname(os.path.abspath(__file__)) + "/fixtures/*.json" -) +SUPPORTED_IOT_DEVICES = [ + (device, "IOT") + for device in glob.glob( + os.path.dirname(os.path.abspath(__file__)) + "/fixtures/*.json" + ) +] + +SUPPORTED_SMART_DEVICES = [ + (device, "SMART") + for device in glob.glob( + os.path.dirname(os.path.abspath(__file__)) + "/fixtures/smart/*.json" + ) +] + + +SUPPORTED_DEVICES = SUPPORTED_IOT_DEVICES + SUPPORTED_SMART_DEVICES LIGHT_STRIPS = {"KL400", "KL430", "KL420"} @@ -77,35 +90,42 @@ IP_MODEL_CACHE: Dict[str, str] = {} -def filter_model(desc, filter, is_smart_protocol=False): +def idgenerator(paramtuple): + return basename(paramtuple[0]) + ( + "" if paramtuple[1] == "IOT" else "-" + paramtuple[1] + ) + + +def filter_model(desc, model_filter, protocol_filter=None): + if not protocol_filter: + protocol_filter = {"IOT"} filtered = list() - for dev in SUPPORTED_DEVICES: - for filt in filter: - model_name = filt - if is_smart_protocol: - model_name = model_name + ".smart" - if model_name in basename(dev).split("_")[0]: - filtered.append(dev) - - filtered_basenames = [basename(f) for f in filtered] + for file, protocol in SUPPORTED_DEVICES: + if protocol in protocol_filter: + file_model = basename(file).split("_")[0] + for model in model_filter: + if model in file_model: + filtered.append((file, protocol)) + + filtered_basenames = [basename(f) + "-" + p for f, p in filtered] print(f"{desc}: {filtered_basenames}") return filtered -def parametrize(desc, devices, ids=None, is_smart_protocol=False): +def parametrize(desc, devices, protocol_filter=None, ids=None): return pytest.mark.parametrize( - "dev", filter_model(desc, devices), indirect=True, ids=ids + "dev", filter_model(desc, devices, protocol_filter), indirect=True, ids=ids ) has_emeter = parametrize("has emeter", WITH_EMETER) no_emeter = parametrize("no emeter", ALL_DEVICES_IOT - WITH_EMETER) -bulb = parametrize("bulbs", BULBS, ids=basename) -plug = parametrize("plugs", PLUGS, ids=basename) -strip = parametrize("strips", STRIPS, ids=basename) -dimmer = parametrize("dimmers", DIMMERS, ids=basename) -lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=basename) +bulb = parametrize("bulbs", BULBS, ids=idgenerator) +plug = parametrize("plugs", PLUGS, ids=idgenerator) +strip = parametrize("strips", STRIPS, ids=idgenerator) +dimmer = parametrize("dimmers", DIMMERS, ids=idgenerator) +lightstrip = parametrize("lightstrips", LIGHT_STRIPS, ids=idgenerator) # bulb types dimmable = parametrize("dimmable", DIMMABLE) @@ -116,23 +136,26 @@ def parametrize(desc, devices, ids=None, is_smart_protocol=False): non_color_bulb = parametrize("non-color bulbs", BULBS - COLOR_BULBS) plug_smart = parametrize( - "plug devices smart", PLUGS_SMART, ids=basename, is_smart_protocol=True + "plug devices smart", PLUGS_SMART, protocol_filter={"SMART"}, ids=idgenerator ) device_smart = parametrize( - "devices smart", ALL_DEVICES_SMART, ids=basename, is_smart_protocol=True + "devices smart", ALL_DEVICES_SMART, protocol_filter={"SMART"}, ids=idgenerator ) device_iot = parametrize( - "devices iot", ALL_DEVICES_IOT, ids=basename, is_smart_protocol=False + "devices iot", ALL_DEVICES_IOT, protocol_filter={"IOT"}, ids=idgenerator ) def get_fixture_data(): """Return raw discovery file contents as JSON. Used for discovery tests.""" fixture_data = {} - for file in SUPPORTED_DEVICES: + for file, protocol in SUPPORTED_DEVICES: p = Path(file) if not p.is_absolute(): - p = Path(__file__).parent / "fixtures" / file + folder = Path(__file__).parent / "fixtures" + if protocol == "SMART": + folder = folder / "smart" + p = folder / file with open(p) as f: fixture_data[basename(p)] = json.load(f) @@ -177,10 +200,9 @@ def check_categories(): ) diff = set(SUPPORTED_DEVICES) - set(categorized_fixtures) if diff: - for file in diff: + for file, protocol in diff: print( - "No category for file %s, add to the corresponding set (BULBS, PLUGS, ..)" - % file + f"No category for file {file} protocol {protocol}, add to the corresponding set (BULBS, PLUGS, ..)" ) raise Exception(f"Missing category for {diff}") @@ -198,31 +220,32 @@ async def handle_turn_on(dev, turn_on): await dev.turn_off() -def device_for_file(model): - for d in STRIPS: - if d in model: - return SmartStrip - - for d in PLUGS: - if d in model: - return SmartPlug +def device_for_file(model, protocol): + if protocol == "SMART": + for d in PLUGS_SMART: + if d in model: + return TapoPlug + else: + for d in STRIPS: + if d in model: + return SmartStrip - # Light strips are recognized also as bulbs, so this has to go first - for d in LIGHT_STRIPS: - if d in model: - return SmartLightStrip + for d in PLUGS: + if d in model: + return SmartPlug - for d in BULBS: - if d in model: - return SmartBulb + # Light strips are recognized also as bulbs, so this has to go first + for d in LIGHT_STRIPS: + if d in model: + return SmartLightStrip - for d in DIMMERS: - if d in model: - return SmartDimmer + for d in BULBS: + if d in model: + return SmartBulb - for d in PLUGS_SMART: - if d + ".smart" in model: - return TapoPlug + for d in DIMMERS: + if d in model: + return SmartDimmer raise Exception("Unable to find type for %s", model) @@ -238,11 +261,14 @@ async def _discover_update_and_close(ip): return await _update_and_close(d) -async def get_device_for_file(file): +async def get_device_for_file(file, protocol): # if the wanted file is not an absolute path, prepend the fixtures directory p = Path(file) if not p.is_absolute(): - p = Path(__file__).parent / "fixtures" / file + folder = Path(__file__).parent / "fixtures" + if protocol == "SMART": + folder = folder / "smart" + p = folder / file def load_file(): with open(p) as f: @@ -252,8 +278,8 @@ def load_file(): sysinfo = await loop.run_in_executor(None, load_file) model = basename(file) - d = device_for_file(model)(host="127.0.0.123") - if ".smart" in model.split("_")[0]: + d = device_for_file(model, protocol)(host="127.0.0.123") + if protocol == "SMART": d.protocol = FakeSmartProtocol(sysinfo) d.credentials = Credentials("", "") else: @@ -269,7 +295,7 @@ async def dev(request): Provides a device (given --ip) or parametrized fixture for the supported devices. The initial update is called automatically before returning the device. """ - file = request.param + file, protocol = request.param ip = request.config.getoption("--ip") if ip: @@ -282,7 +308,7 @@ async def dev(request): pytest.skip(f"skipping file {file}") return d if d else await _discover_update_and_close(ip) - return await get_device_for_file(file) + return await get_device_for_file(file, protocol) @pytest.fixture diff --git a/kasa/tests/fixtures/P110.smart_1.0_1.3.0.json b/kasa/tests/fixtures/smart/P110_1.0_1.3.0.json similarity index 100% rename from kasa/tests/fixtures/P110.smart_1.0_1.3.0.json rename to kasa/tests/fixtures/smart/P110_1.0_1.3.0.json diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index ceb2f5cbb..0111c3d75 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -14,8 +14,8 @@ Schema, ) -from ..protocol import TPLinkSmartHomeProtocol, TPLinkTransport -from ..smartprotocol import TPLinkSmartProtocol +from ..protocol import BaseTransport, TPLinkSmartHomeProtocol +from ..smartprotocol import SmartProtocol _LOGGER = logging.getLogger(__name__) @@ -287,12 +287,12 @@ def success(res): } -class FakeSmartProtocol(TPLinkSmartProtocol): +class FakeSmartProtocol(SmartProtocol): def __init__(self, info): super().__init__("127.0.0.123", transport=FakeSmartTransport(info)) -class FakeSmartTransport(TPLinkTransport): +class FakeSmartTransport(BaseTransport): def __init__(self, info): self.info = info diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index 514a93ab0..55e3977af 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -8,7 +8,7 @@ from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle from kasa.device_factory import DEVICE_TYPE_TO_CLASS from kasa.discover import Discover -from kasa.smartprotocol import TPLinkSmartProtocol +from kasa.smartprotocol import SmartProtocol from .conftest import device_iot, handle_turn_on, new_discovery, turn_on from .newfakes import FakeSmartProtocol, FakeTransportProtocol diff --git a/kasa/tests/test_device_factory.py b/kasa/tests/test_device_factory.py index 4009f247a..eb12b3b0d 100644 --- a/kasa/tests/test_device_factory.py +++ b/kasa/tests/test_device_factory.py @@ -21,7 +21,7 @@ get_protocol_from_connection_name, ) from kasa.discover import DiscoveryResult -from kasa.iotprotocol import TPLinkIotProtocol +from kasa.iotprotocol import IotProtocol from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol @@ -111,8 +111,8 @@ async def test_connect_pass_protocol( ] host = "127.0.0.1" if "discovery_result" in all_fixture_data: - mocker.patch("kasa.TPLinkIotProtocol.query", return_value=all_fixture_data) - mocker.patch("kasa.TPLinkSmartProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) dr = DiscoveryResult(**discovery_info["result"]) connection_name = ( diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index 8dd49013b..6807fcff2 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -10,12 +10,12 @@ import httpx import pytest -from ..aestransport import TPLinkAesTransport +from ..aestransport import AesTransport from ..credentials import Credentials from ..exceptions import AuthenticationException, SmartDeviceException -from ..iotprotocol import TPLinkIotProtocol -from ..klaptransport import KlapEncryptionSession, TPLinkKlapTransport, _sha256 -from ..smartprotocol import TPLinkSmartProtocol +from ..iotprotocol import IotProtocol +from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256 +from ..smartprotocol import SmartProtocol DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} @@ -26,8 +26,8 @@ def __init__(self, status_code, content: bytes): self.content = content -@pytest.mark.parametrize("transport_class", [TPLinkAesTransport, TPLinkKlapTransport]) -@pytest.mark.parametrize("protocol_class", [TPLinkIotProtocol, TPLinkSmartProtocol]) +@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) +@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) @pytest.mark.parametrize("retry_count", [1, 3, 5]) async def test_protocol_retries(mocker, retry_count, protocol_class, transport_class): host = "127.0.0.1" @@ -42,8 +42,8 @@ async def test_protocol_retries(mocker, retry_count, protocol_class, transport_c assert conn.call_count == retry_count + 1 -@pytest.mark.parametrize("transport_class", [TPLinkAesTransport, TPLinkKlapTransport]) -@pytest.mark.parametrize("protocol_class", [TPLinkIotProtocol, TPLinkSmartProtocol]) +@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) +@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) async def test_protocol_no_retry_on_connection_error( mocker, protocol_class, transport_class ): @@ -61,8 +61,8 @@ async def test_protocol_no_retry_on_connection_error( assert conn.call_count == 1 -@pytest.mark.parametrize("transport_class", [TPLinkAesTransport, TPLinkKlapTransport]) -@pytest.mark.parametrize("protocol_class", [TPLinkIotProtocol, TPLinkSmartProtocol]) +@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) +@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) async def test_protocol_retry_recoverable_error( mocker, protocol_class, transport_class ): @@ -80,8 +80,8 @@ async def test_protocol_retry_recoverable_error( assert conn.call_count == 6 -@pytest.mark.parametrize("transport_class", [TPLinkAesTransport, TPLinkKlapTransport]) -@pytest.mark.parametrize("protocol_class", [TPLinkIotProtocol, TPLinkSmartProtocol]) +@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) +@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) @pytest.mark.parametrize("retry_count", [1, 3, 5]) async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport_class): host = "127.0.0.1" @@ -123,16 +123,14 @@ def _return_encrypted(*_, **__): return 200, encrypted seed = secrets.token_bytes(16) - auth_hash = TPLinkKlapTransport.generate_auth_hash(Credentials("foo", "bar")) + auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar")) encryption_session = KlapEncryptionSession(seed, seed, auth_hash) - protocol = TPLinkIotProtocol("127.0.0.1") + protocol = IotProtocol("127.0.0.1") protocol.transport.handshake_done = True protocol.transport.session_expire_at = time.time() + 86400 protocol.transport.encryption_session = encryption_session - mocker.patch.object( - TPLinkKlapTransport, "client_post", side_effect=_return_encrypted - ) + mocker.patch.object(KlapTransport, "client_post", side_effect=_return_encrypted) response = await protocol.query({}) assert response == {"great": "success"} @@ -146,7 +144,7 @@ def test_encrypt(): d = json.dumps({"foo": 1, "bar": 2}) seed = secrets.token_bytes(16) - auth_hash = TPLinkKlapTransport.generate_auth_hash(Credentials("foo", "bar")) + auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar")) encryption_session = KlapEncryptionSession(seed, seed, auth_hash) encrypted, seq = encryption_session.encrypt(d) @@ -158,7 +156,7 @@ def test_encrypt_unicode(): d = "{'snowman': '\u2603'}" seed = secrets.token_bytes(16) - auth_hash = TPLinkKlapTransport.generate_auth_hash(Credentials("foo", "bar")) + auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar")) encryption_session = KlapEncryptionSession(seed, seed, auth_hash) encrypted, seq = encryption_session.encrypt(d) @@ -175,8 +173,8 @@ def test_encrypt_unicode(): (Credentials("", ""), does_not_raise()), ( Credentials( - TPLinkKlapTransport.KASA_SETUP_EMAIL, - TPLinkKlapTransport.KASA_SETUP_PASSWORD, + KlapTransport.KASA_SETUP_EMAIL, + KlapTransport.KASA_SETUP_PASSWORD, ), does_not_raise(), ), @@ -199,13 +197,13 @@ async def _return_handshake1_response(url, params=None, data=None, *_, **__): client_seed = None server_seed = secrets.token_bytes(16) client_credentials = Credentials("foo", "bar") - device_auth_hash = TPLinkKlapTransport.generate_auth_hash(device_credentials) + device_auth_hash = KlapTransport.generate_auth_hash(device_credentials) mocker.patch.object( httpx.AsyncClient, "post", side_effect=_return_handshake1_response ) - protocol = TPLinkIotProtocol("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol("127.0.0.1", credentials=client_credentials) protocol.transport.http_client = httpx.AsyncClient() with expectation: @@ -236,13 +234,13 @@ async def _return_handshake_response(url, params=None, data=None, *_, **__): client_seed = None server_seed = secrets.token_bytes(16) client_credentials = Credentials("foo", "bar") - device_auth_hash = TPLinkKlapTransport.generate_auth_hash(client_credentials) + device_auth_hash = KlapTransport.generate_auth_hash(client_credentials) mocker.patch.object( httpx.AsyncClient, "post", side_effect=_return_handshake_response ) - protocol = TPLinkIotProtocol("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol("127.0.0.1", credentials=client_credentials) protocol.transport.http_client = httpx.AsyncClient() response_status = 200 @@ -284,11 +282,11 @@ async def _return_response(url, params=None, data=None, *_, **__): seq = None server_seed = secrets.token_bytes(16) client_credentials = Credentials("foo", "bar") - device_auth_hash = TPLinkKlapTransport.generate_auth_hash(client_credentials) + device_auth_hash = KlapTransport.generate_auth_hash(client_credentials) mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) - protocol = TPLinkIotProtocol("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol("127.0.0.1", credentials=client_credentials) for _ in range(10): resp = await protocol.query({}) @@ -328,11 +326,11 @@ async def _return_response(url, params=None, data=None, *_, **__): server_seed = secrets.token_bytes(16) client_credentials = Credentials("foo", "bar") - device_auth_hash = TPLinkKlapTransport.generate_auth_hash(client_credentials) + device_auth_hash = KlapTransport.generate_auth_hash(client_credentials) mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) - protocol = TPLinkIotProtocol("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol("127.0.0.1", credentials=client_credentials) with expectation: await protocol.query({}) diff --git a/kasa/tests/test_readme_examples.py b/kasa/tests/test_readme_examples.py index 13c6e9944..5772ba42c 100644 --- a/kasa/tests/test_readme_examples.py +++ b/kasa/tests/test_readme_examples.py @@ -9,7 +9,7 @@ def test_bulb_examples(mocker): """Use KL130 (bulb with all features) to test the doctests.""" - p = asyncio.run(get_device_for_file("KL130(US)_1.0_1.8.11.json")) + p = asyncio.run(get_device_for_file("KL130(US)_1.0_1.8.11.json", "IOT")) mocker.patch("kasa.smartbulb.SmartBulb", return_value=p) mocker.patch("kasa.smartbulb.SmartBulb.update") res = xdoctest.doctest_module("kasa.smartbulb", "all") @@ -18,7 +18,7 @@ def test_bulb_examples(mocker): def test_smartdevice_examples(mocker): """Use HS110 for emeter examples.""" - p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json")) + p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json", "IOT")) mocker.patch("kasa.smartdevice.SmartDevice", return_value=p) mocker.patch("kasa.smartdevice.SmartDevice.update") res = xdoctest.doctest_module("kasa.smartdevice", "all") @@ -27,7 +27,7 @@ def test_smartdevice_examples(mocker): def test_plug_examples(mocker): """Test plug examples.""" - p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json")) + p = asyncio.run(get_device_for_file("HS110(EU)_1.0_1.2.5.json", "IOT")) mocker.patch("kasa.smartplug.SmartPlug", return_value=p) mocker.patch("kasa.smartplug.SmartPlug.update") res = xdoctest.doctest_module("kasa.smartplug", "all") @@ -36,7 +36,7 @@ def test_plug_examples(mocker): def test_strip_examples(mocker): """Test strip examples.""" - p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json")) + p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json", "IOT")) mocker.patch("kasa.smartstrip.SmartStrip", return_value=p) mocker.patch("kasa.smartstrip.SmartStrip.update") res = xdoctest.doctest_module("kasa.smartstrip", "all") @@ -45,7 +45,7 @@ def test_strip_examples(mocker): def test_dimmer_examples(mocker): """Test dimmer examples.""" - p = asyncio.run(get_device_for_file("HS220(US)_1.0_1.5.7.json")) + p = asyncio.run(get_device_for_file("HS220(US)_1.0_1.5.7.json", "IOT")) mocker.patch("kasa.smartdimmer.SmartDimmer", return_value=p) mocker.patch("kasa.smartdimmer.SmartDimmer.update") res = xdoctest.doctest_module("kasa.smartdimmer", "all") @@ -54,7 +54,7 @@ def test_dimmer_examples(mocker): def test_lightstrip_examples(mocker): """Test lightstrip examples.""" - p = asyncio.run(get_device_for_file("KL430(US)_1.0_1.0.10.json")) + p = asyncio.run(get_device_for_file("KL430(US)_1.0_1.0.10.json", "IOT")) mocker.patch("kasa.smartlightstrip.SmartLightStrip", return_value=p) mocker.patch("kasa.smartlightstrip.SmartLightStrip.update") res = xdoctest.doctest_module("kasa.smartlightstrip", "all") @@ -63,7 +63,7 @@ def test_lightstrip_examples(mocker): def test_discovery_examples(mocker): """Test discovery examples.""" - p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json")) + p = asyncio.run(get_device_for_file("KP303(UK)_1.0_1.0.3.json", "IOT")) mocker.patch("kasa.discover.Discover.discover", return_value=[p]) res = xdoctest.doctest_module("kasa.discover", "all") From 63310d32dba813d13d4af7359d7f7390456cfadb Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Mon, 4 Dec 2023 18:21:41 +0000 Subject: [PATCH 4/4] Updates following review --- kasa/aestransport.py | 125 +++++++++++++++----------------- kasa/iotprotocol.py | 24 +++--- kasa/klaptransport.py | 105 +++++++++++++-------------- kasa/protocol.py | 4 +- kasa/smartprotocol.py | 36 +++++---- kasa/tests/newfakes.py | 4 +- kasa/tests/test_klapprotocol.py | 33 +++++---- 7 files changed, 164 insertions(+), 167 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 2fdbee5ad..6757013da 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -4,12 +4,11 @@ under compatible GNU GPL3 license. """ -import asyncio import base64 import hashlib import logging import time -from typing import Optional +from typing import Optional, Union import httpx from cryptography.hazmat.primitives import padding, serialization @@ -56,60 +55,51 @@ def __init__( ) -> None: super().__init__(host=host) - self.credentials = ( - credentials - if credentials and credentials.username and credentials.password - else Credentials(username="", password="") - ) - - self._local_seed: Optional[bytes] = None - self.kasa_setup_auth_hash = None - self.blank_auth_hash = None - self.handshake_lock = asyncio.Lock() + self._credentials = credentials or Credentials(username="", password="") - self.handshake_done = False + self._handshake_done = False - self.encryption_session: Optional[AesEncyptionSession] = None - self.session_expire_at: Optional[float] = None + self._encryption_session: Optional[AesEncyptionSession] = None + self._session_expire_at: Optional[float] = None - self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT - self.session_cookie = None + self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT + self._session_cookie = None - self.http_client: httpx.AsyncClient = httpx.AsyncClient() - self.login_token = None + self._http_client: httpx.AsyncClient = httpx.AsyncClient() + self._login_token = None _LOGGER.debug("Created AES object for %s", self.host) - def hash_credentials(self, credentials, try_login_version2): + def hash_credentials(self, login_v2): """Hash the credentials.""" - if try_login_version2: + if login_v2: un = base64.b64encode( - _sha1(credentials.username.encode()).encode() + _sha1(self._credentials.username.encode()).encode() ).decode() pw = base64.b64encode( - _sha1(credentials.password.encode()).encode() + _sha1(self._credentials.password.encode()).encode() ).decode() else: un = base64.b64encode( - _sha1(credentials.username.encode()).encode() + _sha1(self._credentials.username.encode()).encode() ).decode() - pw = base64.b64encode(credentials.password.encode()).decode() + pw = base64.b64encode(self._credentials.password.encode()).decode() return un, pw async def client_post(self, url, params=None, data=None, json=None, headers=None): """Send an http post request to the device.""" response_data = None cookies = None - if self.session_cookie: + if self._session_cookie: cookies = httpx.Cookies() - cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie) - self.http_client.cookies.clear() - resp = await self.http_client.post( + cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie) + self._http_client.cookies.clear() + resp = await self._http_client.post( url, params=params, data=data, json=json, - timeout=self.timeout, + timeout=self._timeout, cookies=cookies, headers=self.COMMON_HEADERS, ) @@ -121,10 +111,10 @@ async def client_post(self, url, params=None, data=None, json=None, headers=None async def send_secure_passthrough(self, request: str): """Send encrypted message as passthrough.""" url = f"http://{self.host}/app" - if self.login_token: - url += f"?token={self.login_token}" + if self._login_token: + url += f"?token={self._login_token}" - encrypted_payload = self.encryption_session.encrypt(request.encode()) # type: ignore + encrypted_payload = self._encryption_session.encrypt(request.encode()) # type: ignore passthrough_request = { "method": "securePassthrough", "params": {"request": encrypted_payload.decode()}, @@ -132,52 +122,56 @@ async def send_secure_passthrough(self, request: str): status_code, resp_dict = await self.client_post(url, json=passthrough_request) _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}") if status_code == 200 and resp_dict["error_code"] == 0: - response = self.encryption_session.decrypt( # type: ignore + response = self._encryption_session.decrypt( # type: ignore resp_dict["result"]["response"].encode() ) _LOGGER.debug(f"decrypted secure_passthrough response is {response}") resp_dict = json_loads(response) return resp_dict else: - self.handshake_done = False - self.login_token = None + self._handshake_done = False + self._login_token = None raise AuthenticationException("Could not complete send") - async def perform_login(self, login_request, login_v2): + async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool): """Login to the device.""" - self.login_token = None + self._login_token = None if isinstance(login_request, str): - login_request = json_loads(login_request) + login_request_dict: dict = json_loads(login_request) + else: + login_request_dict = login_request - un, pw = self.hash_credentials(self.credentials, login_v2) - login_request["params"] = {"password": pw, "username": un} - request = json_dumps(login_request) + un, pw = self.hash_credentials(login_v2) + login_request_dict["params"] = {"password": pw, "username": un} + request = json_dumps(login_request_dict) try: resp_dict = await self.send_secure_passthrough(request) except SmartDeviceException as ex: raise AuthenticationException(ex) from ex - self.login_token = resp_dict["result"]["token"] + self._login_token = resp_dict["result"]["token"] + @property def needs_login(self) -> bool: """Return true if the transport needs to do a login.""" - return self.login_token is None + return self._login_token is None async def login(self, request: str) -> None: """Login to the device.""" try: - if self.needs_handshake(): + if self.needs_handshake: raise SmartDeviceException( "Handshake must be complete before trying to login" ) - await self.perform_login(request, False) + await self.perform_login(request, login_v2=False) except AuthenticationException: await self.perform_handshake() - await self.perform_login(request, True) + await self.perform_login(request, login_v2=True) + @property def needs_handshake(self) -> bool: """Return true if the transport needs to do a handshake.""" - return not self.handshake_done or self.handshake_session_expired() + return not self._handshake_done or self._handshake_session_expired() async def handshake(self) -> None: """Perform the encryption handshake.""" @@ -188,9 +182,9 @@ async def perform_handshake(self): _LOGGER.debug("Will perform handshaking...") _LOGGER.debug("Generating keypair") - self.handshake_done = False - self.session_expire_at = None - self.session_cookie = None + self._handshake_done = False + self._session_expire_at = None + self._session_cookie = None url = f"http://{self.host}/app" key_pair = KeyPair.create_key_pair() @@ -215,45 +209,46 @@ async def perform_handshake(self): _LOGGER.debug("Decoding handshake key...") handshake_key = resp_dict["result"]["key"] - self.session_cookie = self.http_client.cookies.get( # type: ignore + self._session_cookie = self._http_client.cookies.get( # type: ignore self.SESSION_COOKIE_NAME ) - if not self.session_cookie: - self.session_cookie = self.http_client.cookies.get( # type: ignore + if not self._session_cookie: + self._session_cookie = self._http_client.cookies.get( # type: ignore "SESSIONID" ) - self.session_expire_at = time.time() + 86400 - self.encryption_session = AesEncyptionSession.create_from_keypair( + self._session_expire_at = time.time() + 86400 + self._encryption_session = AesEncyptionSession.create_from_keypair( handshake_key, key_pair ) - self.handshake_done = True + self._handshake_done = True _LOGGER.debug("Handshake with %s complete", self.host) else: raise AuthenticationException("Could not complete handshake") - def handshake_session_expired(self): + def _handshake_session_expired(self): """Return true if session has expired.""" return ( - self.session_expire_at is None or self.session_expire_at - time.time() <= 0 + self._session_expire_at is None + or self._session_expire_at - time.time() <= 0 ) async def send(self, request: str): """Send the request.""" - if self.needs_handshake(): + if self.needs_handshake: raise SmartDeviceException( "Handshake must be complete before trying to send" ) - if self.needs_login(): + if self.needs_login: raise SmartDeviceException("Login must be complete before trying to send") resp_dict = await self.send_secure_passthrough(request) if resp_dict["error_code"] != 0: - self.handshake_done = False - self.login_token = None + self._handshake_done = False + self._login_token = None raise SmartDeviceException( f"Could not complete send, response was {resp_dict}", ) @@ -261,8 +256,8 @@ async def send(self, request: str): async def close(self) -> None: """Close the protocol.""" - client = self.http_client - self.http_client = None + client = self._http_client + self._http_client = None if client: await client.aclose() diff --git a/kasa/iotprotocol.py b/kasa/iotprotocol.py index 54ba3e510..2b7f422db 100755 --- a/kasa/iotprotocol.py +++ b/kasa/iotprotocol.py @@ -29,16 +29,14 @@ def __init__( ) -> None: super().__init__(host=host, port=self.DEFAULT_PORT) - self.credentials: Credentials = ( - credentials - if credentials and credentials.username and credentials.password - else Credentials(username="", password="") + self._credentials: Credentials = credentials or Credentials( + username="", password="" ) - self.transport: BaseTransport = transport or KlapTransport( - host, credentials=self.credentials, timeout=timeout + self._transport: BaseTransport = transport or KlapTransport( + host, credentials=self._credentials, timeout=timeout ) - self.query_lock = asyncio.Lock() + self._query_lock = asyncio.Lock() async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: """Query the device retrying for retry_count on failure.""" @@ -46,7 +44,7 @@ async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: request = json_dumps(request) assert isinstance(request, str) # noqa: S101 - async with self.query_lock: + async with self._query_lock: return await self._query(request, retry_count) async def _query(self, request: str, retry_count: int = 3) -> Dict: @@ -87,16 +85,16 @@ async def _query(self, request: str, retry_count: int = 3) -> Dict: raise SmartDeviceException("Query reached somehow to unreachable") async def _execute_query(self, request: str, retry_count: int) -> Dict: - if self.transport.needs_handshake(): - await self.transport.handshake() + if self._transport.needs_handshake: + await self._transport.handshake() - if self.transport.needs_login(): # This shouln't happen + if self._transport.needs_login: # This shouln't happen raise SmartDeviceException( "IOT Protocol needs to login to transport but is not login aware" ) - return await self.transport.send(request) + return await self._transport.send(request) async def close(self) -> None: """Close the protocol.""" - await self.transport.close() + await self._transport.close() diff --git a/kasa/klaptransport.py b/kasa/klaptransport.py index 72d49cbfe..c28cb0354 100644 --- a/kasa/klaptransport.py +++ b/kasa/klaptransport.py @@ -97,26 +97,22 @@ def __init__( ) -> None: super().__init__(host=host) - self.credentials = ( - credentials - if credentials and credentials.username and credentials.password - else Credentials(username="", password="") - ) + self._credentials = credentials or Credentials(username="", password="") self._local_seed: Optional[bytes] = None - self.local_auth_hash = self.generate_auth_hash(self.credentials) - self.local_auth_owner = self.generate_owner_hash(self.credentials).hex() - self.kasa_setup_auth_hash = None - self.blank_auth_hash = None - self.handshake_lock = asyncio.Lock() - self.query_lock = asyncio.Lock() - self.handshake_done = False + self._local_auth_hash = self.generate_auth_hash(self._credentials) + self._local_auth_owner = self.generate_owner_hash(self._credentials).hex() + self._kasa_setup_auth_hash = None + self._blank_auth_hash = None + self._handshake_lock = asyncio.Lock() + self._query_lock = asyncio.Lock() + self._handshake_done = False - self.encryption_session: Optional[KlapEncryptionSession] = None - self.session_expire_at: Optional[float] = None + self._encryption_session: Optional[KlapEncryptionSession] = None + self._session_expire_at: Optional[float] = None - self.timeout = timeout if timeout else self.DEFAULT_TIMEOUT - self.session_cookie = None - self.http_client: httpx.AsyncClient = httpx.AsyncClient() + self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT + self._session_cookie = None + self._http_client: httpx.AsyncClient = httpx.AsyncClient() _LOGGER.debug("Created KLAP object for %s", self.host) @@ -124,15 +120,15 @@ async def client_post(self, url, params=None, data=None): """Send an http post request to the device.""" response_data = None cookies = None - if self.session_cookie: + if self._session_cookie: cookies = httpx.Cookies() - cookies.set(self.SESSION_COOKIE_NAME, self.session_cookie) - self.http_client.cookies.clear() - resp = await self.http_client.post( + cookies.set(self.SESSION_COOKIE_NAME, self._session_cookie) + self._http_client.cookies.clear() + resp = await self._http_client.post( url, params=params, data=data, - timeout=self.timeout, + timeout=self._timeout, cookies=cookies, ) if resp.status_code == 200: @@ -183,26 +179,26 @@ async def perform_handshake1(self) -> Tuple[bytes, bytes, bytes]: ) local_seed_auth_hash = self.handshake1_seed_auth_hash( - local_seed, remote_seed, self.local_auth_hash + local_seed, remote_seed, self._local_auth_hash ) # type: ignore # Check the response from the device with local credentials if local_seed_auth_hash == server_hash: _LOGGER.debug("handshake1 hashes match with expected credentials") - return local_seed, remote_seed, self.local_auth_hash # type: ignore + return local_seed, remote_seed, self._local_auth_hash # type: ignore # Now check against the default kasa setup credentials - if not self.kasa_setup_auth_hash: + if not self._kasa_setup_auth_hash: kasa_setup_creds = Credentials( username=self.KASA_SETUP_EMAIL, password=self.KASA_SETUP_PASSWORD, ) - self.kasa_setup_auth_hash = self.generate_auth_hash(kasa_setup_creds) + self._kasa_setup_auth_hash = self.generate_auth_hash(kasa_setup_creds) kasa_setup_seed_auth_hash = self.handshake1_seed_auth_hash( local_seed, remote_seed, - self.kasa_setup_auth_hash, # type: ignore + self._kasa_setup_auth_hash, # type: ignore ) if kasa_setup_seed_auth_hash == server_hash: @@ -211,17 +207,17 @@ async def perform_handshake1(self) -> Tuple[bytes, bytes, bytes]: + " but an authentication with kasa setup credentials matched", self.host, ) - return local_seed, remote_seed, self.kasa_setup_auth_hash # type: ignore + return local_seed, remote_seed, self._kasa_setup_auth_hash # type: ignore # Finally check against blank credentials if not already blank - if self.credentials != (blank_creds := Credentials(username="", password="")): - if not self.blank_auth_hash: - self.blank_auth_hash = self.generate_auth_hash(blank_creds) + if self._credentials != (blank_creds := Credentials(username="", password="")): + if not self._blank_auth_hash: + self._blank_auth_hash = self.generate_auth_hash(blank_creds) blank_seed_auth_hash = self.handshake1_seed_auth_hash( local_seed, remote_seed, - self.blank_auth_hash, # type: ignore + self._blank_auth_hash, # type: ignore ) if blank_seed_auth_hash == server_hash: @@ -230,7 +226,7 @@ async def perform_handshake1(self) -> Tuple[bytes, bytes, bytes]: + " but an authentication with blank credentials matched", self.host, ) - return local_seed, remote_seed, self.blank_auth_hash # type: ignore + return local_seed, remote_seed, self._blank_auth_hash # type: ignore msg = f"Server response doesn't match our challenge on ip {self.host}" _LOGGER.debug(msg) @@ -266,6 +262,7 @@ async def perform_handshake2( return KlapEncryptionSession(local_seed, remote_seed, auth_hash) + @property def needs_login(self) -> bool: """Will return false as KLAP does not do a login.""" return False @@ -276,9 +273,10 @@ async def login(self, request: str) -> None: "KLAP does not perform logins and return needs_login == False" ) - def needs_handshake(self): + @property + def needs_handshake(self) -> bool: """Return true if the transport needs to do a handshake.""" - return not self.handshake_done or self.handshake_session_expired() + return not self._handshake_done or self._handshake_session_expired() async def handshake(self) -> None: """Perform the encryption handshake.""" @@ -290,43 +288,44 @@ async def perform_handshake(self) -> Any: Sets the encryption_session if successful. """ _LOGGER.debug("Starting handshake with %s", self.host) - self.handshake_done = False - self.session_expire_at = None - self.session_cookie = None + self._handshake_done = False + self._session_expire_at = None + self._session_cookie = None local_seed, remote_seed, auth_hash = await self.perform_handshake1() - self.session_cookie = self.http_client.cookies.get( # type: ignore + self._session_cookie = self._http_client.cookies.get( # type: ignore self.SESSION_COOKIE_NAME ) # The device returns a TIMEOUT cookie on handshake1 which # it doesn't like to get back so we store the one we want - self.session_expire_at = time.time() + 86400 - self.encryption_session = await self.perform_handshake2( + self._session_expire_at = time.time() + 86400 + self._encryption_session = await self.perform_handshake2( local_seed, remote_seed, auth_hash ) - self.handshake_done = True + self._handshake_done = True _LOGGER.debug("Handshake with %s complete", self.host) - def handshake_session_expired(self): + def _handshake_session_expired(self): """Return true if session has expired.""" return ( - self.session_expire_at is None or self.session_expire_at - time.time() <= 0 + self._session_expire_at is None + or self._session_expire_at - time.time() <= 0 ) async def send(self, request: str): """Send the request.""" - if self.needs_handshake(): + if self.needs_handshake: raise SmartDeviceException( "Handshake must be complete before trying to send" ) - if self.needs_login(): + if self.needs_login: raise SmartDeviceException("Login must be complete before trying to send") # Check for mypy - if self.encryption_session is not None: - payload, seq = self.encryption_session.encrypt(request.encode()) + if self._encryption_session is not None: + payload, seq = self._encryption_session.encrypt(request.encode()) url = f"http://{self.host}/app/request" @@ -345,7 +344,7 @@ async def send(self, request: str): _LOGGER.error("Query failed after succesful authentication " + msg) # If we failed with a security error, force a new handshake next time. if response_status == 403: - self.handshake_done = False + self._handshake_done = False raise AuthenticationException( f"Got a security error from {self.host} after handshake " + "completed" @@ -359,8 +358,8 @@ async def send(self, request: str): _LOGGER.debug("Query posted " + msg) # Check for mypy - if self.encryption_session is not None: - decrypted_response = self.encryption_session.decrypt(response_data) + if self._encryption_session is not None: + decrypted_response = self._encryption_session.decrypt(response_data) json_payload = json_loads(decrypted_response) @@ -374,8 +373,8 @@ async def send(self, request: str): async def close(self) -> None: """Close the transport.""" - client = self.http_client - self.http_client = None + client = self._http_client + self._http_client = None if client: await client.aclose() diff --git a/kasa/protocol.py b/kasa/protocol.py index f52dbc879..62cd5fb63 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -42,7 +42,7 @@ def md5(payload: bytes) -> bytes: class BaseTransport(ABC): - """Base class for all TP-Link KASA-KLAP and TAPO transports.""" + """Base class for all TP-Link protocol transports.""" def __init__( self, @@ -56,10 +56,12 @@ def __init__( self.port = port self.credentials = credentials + @property @abstractmethod def needs_handshake(self) -> bool: """Return true if the transport needs to do a handshake.""" + @property @abstractmethod def needs_login(self) -> bool: """Return true if the transport needs to do a login.""" diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 2ec6df893..98d1a86d8 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -39,32 +39,30 @@ def __init__( ) -> None: super().__init__(host=host, port=self.DEFAULT_PORT) - self.credentials: Credentials = ( - credentials - if credentials and credentials.username and credentials.password - else Credentials(username="", password="") + self._credentials: Credentials = credentials or Credentials( + username="", password="" ) - self.transport: BaseTransport = transport or AesTransport( - host, credentials=self.credentials, timeout=timeout + self._transport: BaseTransport = transport or AesTransport( + host, credentials=self._credentials, timeout=timeout ) - self.terminal_uuid: Optional[str] = None - self.request_id_generator = SnowflakeId(1, 1) - self.query_lock = asyncio.Lock() + self._terminal_uuid: Optional[str] = None + self._request_id_generator = SnowflakeId(1, 1) + self._query_lock = asyncio.Lock() def get_smart_request(self, method, params=None) -> str: """Get a request message as a string.""" request = { "method": method, "params": params, - "requestID": self.request_id_generator.generate_id(), + "requestID": self._request_id_generator.generate_id(), "request_time_milis": round(time.time() * 1000), - "terminal_uuid": self.terminal_uuid, + "terminal_uuid": self._terminal_uuid, } return json_dumps(request) async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: """Query the device retrying for retry_count on failure.""" - async with self.query_lock: + async with self._query_lock: resp_dict = await self._query(request, retry_count) if "result" in resp_dict: return resp_dict["result"] @@ -115,18 +113,18 @@ async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> D smart_method = request smart_params = None - if self.transport.needs_handshake(): - await self.transport.handshake() + if self._transport.needs_handshake: + await self._transport.handshake() - if self.transport.needs_login(): - self.terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode( + if self._transport.needs_login: + self._terminal_uuid = base64.b64encode(md5(uuid.uuid4().bytes)).decode( "UTF-8" ) login_request = self.get_smart_request("login_device") - await self.transport.login(login_request) + await self._transport.login(login_request) smart_request = self.get_smart_request(smart_method, smart_params) - response_data = await self.transport.send(smart_request) + response_data = await self._transport.send(smart_request) _LOGGER.debug( "%s << %s", @@ -138,7 +136,7 @@ async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> D async def close(self) -> None: """Close the protocol.""" - await self.transport.close() + await self._transport.close() class SnowflakeId: diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index 0111c3d75..c5bf238f8 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -296,9 +296,11 @@ class FakeSmartTransport(BaseTransport): def __init__(self, info): self.info = info + @property def needs_handshake(self) -> bool: return False + @property def needs_login(self) -> bool: return False @@ -314,7 +316,7 @@ async def send(self, request: str): if method == "component_nego" or method[:4] == "get_": return self.info[method] elif method[:4] == "set_": - pass + _LOGGER.debug("Call %s not implemented, doing nothing", method) async def close(self) -> None: pass diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index 6807fcff2..fe4d1a6ca 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -96,8 +96,11 @@ def _fail_one_less_than_retry_count(*_, **__): return mock_response - mocker.patch.object(transport_class, "needs_handshake", return_value=False) - mocker.patch.object(transport_class, "needs_login", return_value=False) + mocker.patch.object( + transport_class, "needs_handshake", property(lambda self: False) + ) + mocker.patch.object(transport_class, "needs_login", property(lambda self: False)) + send_mock = mocker.patch.object( transport_class, "send", @@ -127,9 +130,9 @@ def _return_encrypted(*_, **__): encryption_session = KlapEncryptionSession(seed, seed, auth_hash) protocol = IotProtocol("127.0.0.1") - protocol.transport.handshake_done = True - protocol.transport.session_expire_at = time.time() + 86400 - protocol.transport.encryption_session = encryption_session + protocol._transport._handshake_done = True + protocol._transport._session_expire_at = time.time() + 86400 + protocol._transport._encryption_session = encryption_session mocker.patch.object(KlapTransport, "client_post", side_effect=_return_encrypted) response = await protocol.query({}) @@ -205,13 +208,13 @@ async def _return_handshake1_response(url, params=None, data=None, *_, **__): protocol = IotProtocol("127.0.0.1", credentials=client_credentials) - protocol.transport.http_client = httpx.AsyncClient() + protocol._transport.http_client = httpx.AsyncClient() with expectation: ( local_seed, device_remote_seed, auth_hash, - ) = await protocol.transport.perform_handshake1() + ) = await protocol._transport.perform_handshake1() assert local_seed == client_seed assert device_remote_seed == server_seed @@ -241,16 +244,16 @@ async def _return_handshake_response(url, params=None, data=None, *_, **__): ) protocol = IotProtocol("127.0.0.1", credentials=client_credentials) - protocol.transport.http_client = httpx.AsyncClient() + protocol._transport.http_client = httpx.AsyncClient() response_status = 200 - await protocol.transport.perform_handshake() - assert protocol.transport.handshake_done is True + await protocol._transport.perform_handshake() + assert protocol._transport._handshake_done is True response_status = 403 with pytest.raises(AuthenticationException): - await protocol.transport.perform_handshake() - assert protocol.transport.handshake_done is False + await protocol._transport.perform_handshake() + assert protocol._transport._handshake_done is False await protocol.close() @@ -267,9 +270,9 @@ async def _return_response(url, params=None, data=None, *_, **__): return _mock_response(200, b"") elif url == "http://127.0.0.1/app/request": encryption_session = KlapEncryptionSession( - protocol.transport.encryption_session.local_seed, - protocol.transport.encryption_session.remote_seed, - protocol.transport.encryption_session.user_hash, + protocol._transport._encryption_session.local_seed, + protocol._transport._encryption_session.remote_seed, + protocol._transport._encryption_session.user_hash, ) seq = params.get("seq") encryption_session._seq = seq - 1