diff --git a/src/snagrecover/protocols/fastboot.py b/src/snagrecover/protocols/fastboot.py index 45cb3d9..99a9fe8 100644 --- a/src/snagrecover/protocols/fastboot.py +++ b/src/snagrecover/protocols/fastboot.py @@ -21,6 +21,8 @@ import usb import time import tempfile +from typing import Optional, Union + from snagrecover import utils from snagflash.android_sparse_file.utils import split @@ -35,10 +37,16 @@ for more information on fastboot support in U-Boot. """ +FASTBOOT_UNSUPPORTED_CMD_RESPONSE = b"Unsupported command" +FASTBOOT_UNRECOGNIZED_CMD_RESPONSE = b"unrecognized command" + +CHECK_OEM_RUN_CMD_SUPPORT = "oem run:version\x00" + class FastbootError(Exception): - def __init__(self, message): + def __init__(self, message, data=None): self.message = message + self.data = data super().__init__(self.message) def __str__(self): @@ -86,40 +94,43 @@ def __init__(self, dev: usb.core.Device, timeout: int = 10000): self.max_size = MAX_LIBUSB_TRANSFER_SIZE - def cmd(self, packet: bytes): - self.dev.write(self.ep_out, packet, timeout=self.timeout) - status = "" - t0 = time.time() - while time.time() - t0 < 10 * self.timeout: - ret = self.dev.read(self.ep_in, 256, timeout=self.timeout) - status = bytes(ret[:4]) - if status == b"INFO": - logger.debug(f"(bootloader) {bytes(ret[4:256])}") - elif status == b"TEXT": - logger.debug(f"(bootloader) {bytes(ret[4:256])}", end="") - elif status == b"FAIL": - raise FastbootError(f"Fastboot fail with message: {bytes(ret[4:256])}") - elif status == b"OKAY": - logger.debug("fastboot OKAY") - return bytes(ret[4:]) - elif status == b"DATA": - length = int("0x" + (bytes(ret[4:12]).decode("ascii")), base=16) - logger.debug(f"fastboot DATA length: {length}") - return length - raise FastbootError("Timeout while completing fastboot transaction") + # The support of OEM commands is depending on the configuration + # u-boot was built with, so they need to be probed at runtime. + # The command handler for oem run is actually ucmd in the sources, + # therefore it's safe to use this as fallback. + self.oem_run_basecmd = ( + "oem run" if self._is_cmd_supported(CHECK_OEM_RUN_CMD_SUPPORT) else "UCmd" + ) - def response(self): + def _is_cmd_supported(self, cmd: str) -> bool: + try: + self.cmd(cmd) + except FastbootError as e: + if e.data and FASTBOOT_UNSUPPORTED_CMD_RESPONSE in e.data: + return False + return True + + def cmd( + self, packet: Optional[bytes] = None, loglevel=logging.DEBUG + ) -> Union[bytes, int]: + if packet is not None: + self.dev.write(self.ep_out, packet, timeout=self.timeout) t0 = time.time() while time.time() - t0 < 10 * self.timeout: ret = self.dev.read(self.ep_in, 256, timeout=self.timeout) status = bytes(ret[:4]) + data = bytes(ret[4:256]) if status in [b"INFO", b"TEXT"]: - logger.info(f"(bootloader) {bytes(ret[4:256])}", end="") + logger.log(loglevel, f"(bootloader) {data}", end="") elif status == b"FAIL": - raise FastbootError(f"Fastboot fail with message: {bytes(ret[4:256])}") + raise FastbootError(f"Fastboot fail with message: {data}", data) elif status == b"OKAY": - logger.info("fastboot OKAY") - return bytes(ret[4:]) + logger.log(loglevel, "fastboot OKAY") + return data + elif packet is not None and status == b"DATA": + length = int("0x" + (data.decode("ascii")), base=16) + logger.log(loglevel, f"fastboot DATA length: {length}") + return length raise FastbootError("Timeout while completing fastboot transaction") def getvar(self, var: str): @@ -133,7 +144,7 @@ def send(self, blob: bytes, padding: int = 0): self.cmd(packet) for chunk in utils.dnload_iter(blob + b"\x00" * padding, self.max_size): self.dev.write(self.ep_out, chunk, timeout=self.timeout) - self.response() + self.cmd(loglevel=logging.INFO) def download(self, path: str, padding: int = 0): with open(path, "rb") as file: @@ -191,7 +202,7 @@ def oem_run(self, cmd: str): """ Execute an arbitrary U-Boot command """ - packet = f"oem run:{cmd}\x00" + packet = f"{self.oem_run_basecmd}:{cmd}\x00" self.cmd(packet) def oem_format(self): diff --git a/tests.py b/tests.py index cd98c87..8fb7097 100644 --- a/tests.py +++ b/tests.py @@ -34,6 +34,7 @@ # Skip USB tests if no backend is available (e.g., in Windows CI env) print(f"Skipping USB tests: {e}") +print("Executing unit tests") unit_tests = unittest.TestLoader().discover("tests", "*.py") unit_runner = unittest.TextTestRunner() diff --git a/tests/protocols_fastboot.py b/tests/protocols_fastboot.py new file mode 100644 index 0000000..e37a122 --- /dev/null +++ b/tests/protocols_fastboot.py @@ -0,0 +1,113 @@ +import unittest +from unittest.mock import MagicMock +from enum import Enum, auto + +from snagrecover.protocols.fastboot import Fastboot, FastbootError + +DUMMY_CMD: str = "test_cmd" +DUMMY_DATA: bytes = b"test_data" + + +class ResponseType(Enum): + INFO = auto() + TEXT = auto() + FAIL = auto() + OKAY = auto() + DATA = auto() + TIMEOUT = auto() + + +class TestFastboot(unittest.TestCase): + @staticmethod + def _get_usb_device_mock() -> MagicMock: + # Mock endpoint with bulk IN/OUT attributes + mock_ep_in = MagicMock() + mock_ep_in.bmAttributes = 0x02 # ENDPOINT_TYPE_BULK + mock_ep_in.bEndpointAddress = 0x81 # ENDPOINT_IN + mock_ep_out = MagicMock() + mock_ep_out.bmAttributes = 0x02 # ENDPOINT_TYPE_BULK + mock_ep_out.bEndpointAddress = 0x01 # ENDPOINT_OUT + + # Mock interface with endpoints + mock_intf = MagicMock() + mock_intf.endpoints.return_value = [mock_ep_in, mock_ep_out] + + # Mock configuration with interfaces + mock_cfg = MagicMock() + mock_cfg.interfaces.return_value = [mock_intf] + + mock_device = MagicMock() + mock_device.get_active_configuration.return_value = mock_cfg + + return mock_device + + def _setup_fastboot(self, has_oem_run: bool = True) -> Fastboot: + self.mock_device = self._get_usb_device_mock() + self.mock_device.read.side_effect = [ + b"FAILunrecognized command\x00" + if has_oem_run + else b"FAILUnsupported command\x00" + ] + self.fastboot = Fastboot(self.mock_device) + self.assert_device_write("oem run:version\x00") + self.mock_device.reset_mock() + + def setUp(self) -> None: + self._setup_fastboot() + + def assert_device_write(self, expected_cmd: str) -> None: + self.mock_device.write.assert_called_once_with( + self.fastboot.ep_out, expected_cmd, timeout=self.fastboot.timeout + ) + + def expect_device_response( + self, response_type: ResponseType, data: bytes = b"" + ) -> None: + read_value = response_type.name.encode() + data + if response_type == ResponseType.TIMEOUT: + self.mock_device.read.side_effect = TimeoutError() + return + + self.mock_device.read.side_effect = [read_value] + + # --- Generic cmd tests --- + + def test_cmd_okay(self) -> None: + self.expect_device_response(ResponseType.OKAY, DUMMY_DATA) + result = self.fastboot.cmd(DUMMY_CMD) + self.assert_device_write(DUMMY_CMD) + self.assertEqual(result, DUMMY_DATA) + + def test_cmd_fail(self) -> None: + self.expect_device_response(ResponseType.FAIL, DUMMY_DATA) + with self.assertRaises(FastbootError): + self.fastboot.cmd(DUMMY_CMD) + self.assert_device_write(DUMMY_CMD) + + def test_cmd_info_or_text(self) -> None: + for response in [ResponseType.INFO, ResponseType.TEXT]: + with self.subTest(response=response): + self.mock_device.reset_mock() + self.mock_device.read.side_effect = [ + response.name.encode() + DUMMY_DATA, + ResponseType.OKAY.name.encode() + b" waited for OKAY", + ] + self.fastboot.cmd(DUMMY_CMD) + self.assert_device_write(DUMMY_CMD) + + # --- oem_run --- + + def test_oem_run_accepted(self) -> None: + subcommand = "test_oem_subcommand" + self.expect_device_response(ResponseType.OKAY) + self.fastboot.oem_run(subcommand) + expected_cmd = f"oem run:{subcommand}\x00" + self.assert_device_write(expected_cmd) + + def test_oem_run_unavailable_fallback_to_ucmd(self) -> None: + self._setup_fastboot(has_oem_run=False) + + subcommand = "test_oem_subcommand" + self.expect_device_response(ResponseType.OKAY, b" executed via UCmd") + self.fastboot.oem_run(subcommand) + self.assert_device_write(f"UCmd:{subcommand}\x00")