Fix reliability issues from code review, add error-path tests

Critical fixes:
- Separate notification and command connections to eliminate dual-reader
  race condition on the TCL RPC stream (C-1)
- Fix _get_or_create_loop() swallowing its own RuntimeError, causing
  deadlock when sync API called from async context (C-2)
- Add bounds checking to config string parser (C-3)
- Clean up OpenOCD subprocess on connection failure in Session.start (H-1)

Defense in depth:
- Add MAX_RESPONSE_SIZE (10MB) guard against unbounded buffer growth
- Preserve bytes after separator in _read_until_separator remainder buffer
- Set notification_failed flag when listener crashes, warn on next send
- Standardize error detection to case-insensitive across all modules
- Escape TCL special characters in RTT channelwrite to prevent injection
- Redirect OpenOCD stdout to DEVNULL to prevent pipe buffer deadlock
- Run SVD XML parsing in asyncio.to_thread to avoid blocking event loop

Consistency:
- Cache SyncSession subsystem wrappers (match async Session pattern)
- Make DecodedRegister frozen (match all other dataclasses)
- Add py.typed marker for PEP 561 type checker support
- Accept list[str] config in OpenOCDProcess.start for paths with spaces

Tests:
- Add 50 error-path tests covering connection, target, memory, register,
  flash, breakpoint, session, process, and notification failure modes
This commit is contained in:
Ryan Malloy 2026-02-12 18:52:38 -07:00
parent 7e1eac5e2d
commit bc7cb77ec4
13 changed files with 833 additions and 59 deletions

View File

@ -4,6 +4,13 @@ OpenOCD's TCL RPC uses a simple framing protocol:
- Client sends: command_string + \\x1a - Client sends: command_string + \\x1a
- Server replies: response_string + \\x1a - Server replies: response_string + \\x1a
The \\x1a (ASCII SUB / Ctrl-Z) byte acts as an unambiguous delimiter. The \\x1a (ASCII SUB / Ctrl-Z) byte acts as an unambiguous delimiter.
Notifications use a **separate connection** to avoid dual-reader race
conditions on the command stream. When ``enable_notifications()`` is
called, a second TCP connection is opened to the same host:port. That
connection sends ``tcl_notifications on`` and then exclusively reads
unsolicited events, leaving the primary connection free for
request/response commands.
""" """
from __future__ import annotations from __future__ import annotations
@ -21,21 +28,34 @@ log = logging.getLogger(__name__)
SEPARATOR = b"\x1a" SEPARATOR = b"\x1a"
DEFAULT_TIMEOUT = 10.0 DEFAULT_TIMEOUT = 10.0
MAX_RESPONSE_SIZE = 10 * 1024 * 1024 # 10 MB — guard against runaway reads
class TclRpcConnection(Connection): class TclRpcConnection(Connection):
"""Async TCP client speaking OpenOCD's TCL RPC protocol.""" """Async TCP client speaking OpenOCD's TCL RPC protocol.
The command connection and the notification connection are kept on
**separate sockets** so that unsolicited events never corrupt
the request/response stream.
"""
def __init__(self, timeout: float = DEFAULT_TIMEOUT) -> None: def __init__(self, timeout: float = DEFAULT_TIMEOUT) -> None:
# Primary command connection
self._reader: asyncio.StreamReader | None = None self._reader: asyncio.StreamReader | None = None
self._writer: asyncio.StreamWriter | None = None self._writer: asyncio.StreamWriter | None = None
self._timeout = timeout self._timeout = timeout
self._notification_callbacks: list[Callable[[str], None]] = []
self._notification_task: asyncio.Task[None] | None = None
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._remainder = bytearray() # leftover bytes after separator
self._host: str = "" self._host: str = ""
self._port: int = 0 self._port: int = 0
# Notification connection (separate socket)
self._notif_reader: asyncio.StreamReader | None = None
self._notif_writer: asyncio.StreamWriter | None = None
self._notification_callbacks: list[Callable[[str], None]] = []
self._notification_task: asyncio.Task[None] | None = None
self._notification_failed: bool = False
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Connection lifecycle # Connection lifecycle
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@ -59,18 +79,28 @@ class TclRpcConnection(Connection):
log.debug("Connected to OpenOCD TCL RPC at %s:%d", host, port) log.debug("Connected to OpenOCD TCL RPC at %s:%d", host, port)
async def close(self) -> None: async def close(self) -> None:
# Tear down notification connection first
if self._notification_task and not self._notification_task.done(): if self._notification_task and not self._notification_task.done():
self._notification_task.cancel() self._notification_task.cancel()
with contextlib.suppress(asyncio.CancelledError): with contextlib.suppress(asyncio.CancelledError):
await self._notification_task await self._notification_task
self._notification_task = None self._notification_task = None
if self._notif_writer:
self._notif_writer.close()
with contextlib.suppress(OSError):
await self._notif_writer.wait_closed()
self._notif_writer = None
self._notif_reader = None
# Tear down primary command connection
if self._writer: if self._writer:
self._writer.close() self._writer.close()
with contextlib.suppress(OSError): with contextlib.suppress(OSError):
await self._writer.wait_closed() await self._writer.wait_closed()
self._writer = None self._writer = None
self._reader = None self._reader = None
self._remainder.clear()
log.debug("TCL RPC connection closed") log.debug("TCL RPC connection closed")
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@ -86,6 +116,9 @@ class TclRpcConnection(Connection):
if not self._writer or not self._reader: if not self._writer or not self._reader:
raise ConnectionError("Not connected — call connect() first") raise ConnectionError("Not connected — call connect() first")
if self._notification_failed:
log.warning("Notification listener has stopped — events may be missed")
async with self._lock: async with self._lock:
payload = command.encode("utf-8") + SEPARATOR payload = command.encode("utf-8") + SEPARATOR
self._writer.write(payload) self._writer.write(payload)
@ -107,39 +140,100 @@ class TclRpcConnection(Connection):
return response return response
async def _read_until_separator(self) -> bytes: async def _read_until_separator(self) -> bytes:
"""Read from the stream until the \\x1a separator is found.""" """Read from the command stream until the \\x1a separator is found.
Preserves any bytes received after the separator for the next call.
Raises ``ConnectionError`` if the response exceeds ``MAX_RESPONSE_SIZE``.
"""
assert self._reader is not None assert self._reader is not None
buf = bytearray() buf = self._remainder
self._remainder = bytearray()
# Check if remainder already contains a complete response
idx = buf.find(SEPARATOR)
if idx != -1:
result = bytes(buf[:idx])
self._remainder = bytearray(buf[idx + 1 :])
return result
while True: while True:
chunk = await self._reader.read(4096) chunk = await self._reader.read(4096)
if not chunk: if not chunk:
raise ConnectionError("OpenOCD closed the connection") raise ConnectionError("OpenOCD closed the connection")
buf.extend(chunk) buf.extend(chunk)
if len(buf) > MAX_RESPONSE_SIZE:
raise ConnectionError(
f"Response exceeded {MAX_RESPONSE_SIZE} bytes without separator — "
"is this an OpenOCD TCL RPC port?"
)
idx = buf.find(SEPARATOR) idx = buf.find(SEPARATOR)
if idx != -1: if idx != -1:
return bytes(buf[:idx]) result = bytes(buf[:idx])
self._remainder = bytearray(buf[idx + 1 :])
return result
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Notifications (async events from OpenOCD) # Notifications (separate connection)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
async def enable_notifications(self) -> None: async def enable_notifications(self) -> None:
"""Enable TCL event notifications and start the listener loop. """Open a dedicated notification connection and start the listener.
Sends ``tcl_notifications on`` which causes OpenOCD to push A **separate TCP connection** to the same OpenOCD instance is
target-state-change events over the same socket. used for notifications. This avoids the dual-reader race
condition that would occur if notifications and command
responses shared the same stream.
""" """
await self.send("tcl_notifications on") if not self._host:
raise ConnectionError("Not connected — call connect() first")
try:
self._notif_reader, self._notif_writer = await asyncio.wait_for(
asyncio.open_connection(self._host, self._port),
timeout=self._timeout,
)
except OSError as exc:
raise ConnectionError(
f"Cannot open notification connection to {self._host}:{self._port}: {exc}"
) from exc
except TimeoutError as exc:
raise OcdTimeoutError(
f"Timed out opening notification connection to {self._host}:{self._port}"
) from exc
# Enable notifications on the dedicated connection
enable_cmd = b"tcl_notifications on" + SEPARATOR
self._notif_writer.write(enable_cmd)
await self._notif_writer.drain()
# Read and discard the acknowledgement
ack_buf = bytearray()
while True:
chunk = await asyncio.wait_for(
self._notif_reader.read(4096), timeout=self._timeout
)
if not chunk:
raise ConnectionError(
"Notification connection closed during setup"
)
ack_buf.extend(chunk)
if ack_buf.find(SEPARATOR) != -1:
break
log.debug("Notification connection established to %s:%d", self._host, self._port)
self._notification_failed = False
self._notification_task = asyncio.create_task(self._notification_loop()) self._notification_task = asyncio.create_task(self._notification_loop())
async def _notification_loop(self) -> None: async def _notification_loop(self) -> None:
"""Background task that reads unsolicited notifications.""" """Background task that reads unsolicited notifications from
assert self._reader is not None the dedicated notification connection."""
assert self._notif_reader is not None
buf = bytearray() buf = bytearray()
try: try:
while True: while True:
chunk = await self._reader.read(4096) chunk = await self._notif_reader.read(4096)
if not chunk: if not chunk:
log.warning("Notification connection closed by OpenOCD")
break break
buf.extend(chunk) buf.extend(chunk)
while True: while True:
@ -158,6 +252,8 @@ class TclRpcConnection(Connection):
return return
except Exception: except Exception:
log.exception("Notification loop crashed") log.exception("Notification loop crashed")
finally:
self._notification_failed = True
def on_notification(self, callback: Callable[[str], None]) -> None: def on_notification(self, callback: Callable[[str], None]) -> None:
self._notification_callbacks.append(callback) self._notification_callbacks.append(callback)

View File

@ -19,6 +19,7 @@ log = logging.getLogger(__name__)
PROMPT = b"> " PROMPT = b"> "
DEFAULT_TIMEOUT = 10.0 DEFAULT_TIMEOUT = 10.0
MAX_RESPONSE_SIZE = 10 * 1024 * 1024 # 10 MB
class TelnetConnection(Connection): class TelnetConnection(Connection):
@ -90,6 +91,10 @@ class TelnetConnection(Connection):
if not chunk: if not chunk:
raise ConnectionError("OpenOCD closed the connection") raise ConnectionError("OpenOCD closed the connection")
buf.extend(chunk) buf.extend(chunk)
if len(buf) > MAX_RESPONSE_SIZE:
raise ConnectionError(
f"Response exceeded {MAX_RESPONSE_SIZE} bytes without prompt"
)
if buf.endswith(PROMPT): if buf.endswith(PROMPT):
return bytes(buf[: -len(PROMPT)]) return bytes(buf[: -len(PROMPT)])

View File

@ -48,5 +48,5 @@ async def xsvf(conn: Connection, tap: str, path: Path) -> None:
def _check_error(response: str, command: str) -> None: def _check_error(response: str, command: str) -> None:
if "Error" in response or "error" in response.split("\n")[0]: if "error" in response.lower():
raise JTAGError(f"{command} failed: {response.strip()}") raise JTAGError(f"{command} failed: {response.strip()}")

View File

@ -32,7 +32,7 @@ _CHAIN_ROW_RE = re.compile(
async def scan_chain(conn: Connection) -> list[TAPInfo]: async def scan_chain(conn: Connection) -> list[TAPInfo]:
"""Query the JTAG scan chain and return a list of discovered TAPs.""" """Query the JTAG scan chain and return a list of discovered TAPs."""
resp = await conn.send("scan_chain") resp = await conn.send("scan_chain")
if "Error" in resp: if "error" in resp.lower():
raise JTAGError(f"scan_chain failed: {resp.strip()}") raise JTAGError(f"scan_chain failed: {resp.strip()}")
return _parse_scan_chain(resp) return _parse_scan_chain(resp)
@ -57,7 +57,7 @@ async def new_tap(
if expected_id is not None: if expected_id is not None:
parts.extend(["-expected-id", f"0x{expected_id:08x}"]) parts.extend(["-expected-id", f"0x{expected_id:08x}"])
resp = await conn.send(" ".join(parts)) resp = await conn.send(" ".join(parts))
if "Error" in resp: if "error" in resp.lower():
raise JTAGError(f"newtap failed: {resp.strip()}") raise JTAGError(f"newtap failed: {resp.strip()}")

View File

@ -54,5 +54,5 @@ async def runtest(conn: Connection, cycles: int) -> None:
def _check_error(response: str, command: str) -> None: def _check_error(response: str, command: str) -> None:
"""Raise JTAGError if OpenOCD reported an error.""" """Raise JTAGError if OpenOCD reported an error."""
if "Error" in response or "error" in response.split("\n")[0]: if "error" in response.lower():
raise JTAGError(f"{command} failed: {response.strip()}") raise JTAGError(f"{command} failed: {response.strip()}")

View File

@ -22,5 +22,5 @@ async def pathmove(conn: Connection, states: list[JTAGState]) -> None:
def _check_error(response: str, command: str) -> None: def _check_error(response: str, command: str) -> None:
if "Error" in response or "error" in response.split("\n")[0]: if "error" in response.lower():
raise JTAGError(f"{command} failed: {response.strip()}") raise JTAGError(f"{command} failed: {response.strip()}")

View File

@ -40,7 +40,7 @@ class OpenOCDProcess:
async def start( async def start(
self, self,
config: str, config: str | list[str],
extra_args: list[str] | None = None, extra_args: list[str] | None = None,
tcl_port: int = DEFAULT_TCL_PORT, tcl_port: int = DEFAULT_TCL_PORT,
openocd_bin: str | None = None, openocd_bin: str | None = None,
@ -48,9 +48,10 @@ class OpenOCDProcess:
"""Start OpenOCD with the given configuration. """Start OpenOCD with the given configuration.
Args: Args:
config: Config file path or inline ``-f`` / ``-c`` arguments. config: Config file path, inline ``-f`` / ``-c`` arguments string,
Multiple files can be separated by spaces with ``-f`` prefixes, or a list of arguments (preferred for paths with spaces).
e.g. ``"interface/cmsis-dap.cfg -f target/stm32f1x.cfg"``. String form: ``"interface/cmsis-dap.cfg -f target/stm32f1x.cfg"``
List form: ``["-f", "my config/board.cfg", "-f", "target/stm32f1x.cfg"]``
extra_args: Additional CLI arguments. extra_args: Additional CLI arguments.
tcl_port: TCL RPC port (default 6666). tcl_port: TCL RPC port (default 6666).
openocd_bin: Path to OpenOCD binary (auto-detected if None). openocd_bin: Path to OpenOCD binary (auto-detected if None).
@ -63,17 +64,27 @@ class OpenOCDProcess:
) )
args = [binary] args = [binary]
# Parse the config string — support both bare paths and -f/-c flags
config_parts = config.split() # Accept either a pre-split list or a string to parse
i = 0 if isinstance(config, list):
while i < len(config_parts): args.extend(config)
part = config_parts[i] else:
if part in ("-f", "-c"): config_parts = config.split()
args.extend([part, config_parts[i + 1]]) if not config_parts:
i += 2 raise ProcessError("Empty config string")
else: i = 0
args.extend(["-f", part]) while i < len(config_parts):
i += 1 part = config_parts[i]
if part in ("-f", "-c"):
if i + 1 >= len(config_parts):
raise ProcessError(
f"Config flag '{part}' requires an argument"
)
args.extend([part, config_parts[i + 1]])
i += 2
else:
args.extend(["-f", part])
i += 1
args.extend(["-c", f"tcl_port {tcl_port}"]) args.extend(["-c", f"tcl_port {tcl_port}"])
@ -84,7 +95,7 @@ class OpenOCDProcess:
try: try:
self._proc = await asyncio.create_subprocess_exec( self._proc = await asyncio.create_subprocess_exec(
*args, *args,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
) )
except FileNotFoundError as exc: except FileNotFoundError as exc:

0
src/openocd/py.typed Normal file
View File

View File

@ -124,7 +124,14 @@ class RTTManager:
Raises: Raises:
OpenOCDError: If the write command fails. OpenOCDError: If the write command fails.
""" """
cmd = f'rtt channelwrite {channel} "{data}"' # Escape TCL special characters to prevent injection
escaped = (
data.replace("\\", "\\\\")
.replace('"', '\\"')
.replace("[", "\\[")
.replace("$", "\\$")
)
cmd = f'rtt channelwrite {channel} "{escaped}"'
response = await self._conn.send(cmd) response = await self._conn.send(cmd)
_check_rtt_response(response, cmd) _check_rtt_response(response, cmd)

View File

@ -74,8 +74,12 @@ class Session:
) )
await proc.wait_ready(timeout=timeout) await proc.wait_ready(timeout=timeout)
conn = TclRpcConnection(timeout=timeout) try:
await conn.connect("localhost", tcl_port) conn = TclRpcConnection(timeout=timeout)
await conn.connect("localhost", tcl_port)
except Exception:
await proc.stop()
raise
return cls(connection=conn, process=proc) return cls(connection=conn, process=proc)
@ -229,6 +233,13 @@ class SyncSession:
def __init__(self, session: Session, loop: asyncio.AbstractEventLoop) -> None: def __init__(self, session: Session, loop: asyncio.AbstractEventLoop) -> None:
self._session = session self._session = session
self._loop = loop self._loop = loop
self._target: SyncTarget | None = None
self._memory: SyncMemory | None = None
self._registers: SyncRegisters | None = None
self._flash: SyncFlash | None = None
self._jtag: SyncJTAGController | None = None
self._breakpoints: SyncBreakpointManager | None = None
self._svd: SyncSVDManager | None = None
def __enter__(self) -> SyncSession: def __enter__(self) -> SyncSession:
return self return self
@ -241,38 +252,52 @@ class SyncSession:
@property @property
def target(self) -> SyncTarget: def target(self) -> SyncTarget:
from openocd.target import SyncTarget if self._target is None:
return SyncTarget(self._session.target, self._loop) from openocd.target import SyncTarget
self._target = SyncTarget(self._session.target, self._loop)
return self._target
@property @property
def memory(self) -> SyncMemory: def memory(self) -> SyncMemory:
from openocd.memory import SyncMemory if self._memory is None:
return SyncMemory(self._session.memory, self._loop) from openocd.memory import SyncMemory
self._memory = SyncMemory(self._session.memory, self._loop)
return self._memory
@property @property
def registers(self) -> SyncRegisters: def registers(self) -> SyncRegisters:
from openocd.registers import SyncRegisters if self._registers is None:
return SyncRegisters(self._session.registers, self._loop) from openocd.registers import SyncRegisters
self._registers = SyncRegisters(self._session.registers, self._loop)
return self._registers
@property @property
def flash(self) -> SyncFlash: def flash(self) -> SyncFlash:
from openocd.flash import SyncFlash if self._flash is None:
return SyncFlash(self._session.flash, self._loop) from openocd.flash import SyncFlash
self._flash = SyncFlash(self._session.flash, self._loop)
return self._flash
@property @property
def jtag(self) -> SyncJTAGController: def jtag(self) -> SyncJTAGController:
from openocd.jtag import SyncJTAGController if self._jtag is None:
return SyncJTAGController(self._session.jtag, self._loop) from openocd.jtag import SyncJTAGController
self._jtag = SyncJTAGController(self._session.jtag, self._loop)
return self._jtag
@property @property
def breakpoints(self) -> SyncBreakpointManager: def breakpoints(self) -> SyncBreakpointManager:
from openocd.breakpoints import SyncBreakpointManager if self._breakpoints is None:
return SyncBreakpointManager(self._session.breakpoints, self._loop) from openocd.breakpoints import SyncBreakpointManager
self._breakpoints = SyncBreakpointManager(self._session.breakpoints, self._loop)
return self._breakpoints
@property @property
def svd(self) -> SyncSVDManager: def svd(self) -> SyncSVDManager:
from openocd.svd import SyncSVDManager if self._svd is None:
return SyncSVDManager(self._session.svd, self._loop) from openocd.svd import SyncSVDManager
self._svd = SyncSVDManager(self._session.svd, self._loop)
return self._svd
# ====================================================================== # ======================================================================
@ -280,16 +305,20 @@ class SyncSession:
# ====================================================================== # ======================================================================
def _get_or_create_loop() -> asyncio.AbstractEventLoop: def _get_or_create_loop() -> asyncio.AbstractEventLoop:
"""Get the running event loop, or create a new one if there isn't one.""" """Get or create an event loop for synchronous usage.
Raises RuntimeError if called from within an already-running async
context (where ``run_until_complete`` would deadlock).
"""
try: try:
loop = asyncio.get_running_loop() asyncio.get_running_loop()
# If we're already in an async context we can't use run_until_complete except RuntimeError:
pass # No running loop — this is the expected path for sync usage
else:
raise RuntimeError( raise RuntimeError(
"Cannot use sync API from an async context. " "Cannot use sync API from an async context. "
"Use the async Session.start()/connect() instead." "Use the async Session.start()/connect() instead."
) )
except RuntimeError:
pass
try: try:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
if loop.is_closed(): if loop.is_closed():

View File

@ -51,7 +51,7 @@ class SVDManager:
Raises: Raises:
SVDError: If the file is missing or unparseable. SVDError: If the file is missing or unparseable.
""" """
self._parser.load(svd_path) await asyncio.to_thread(self._parser.load, svd_path)
def list_peripherals(self) -> list[str]: def list_peripherals(self) -> list[str]:
"""Return sorted peripheral names from the loaded SVD. """Return sorted peripheral names from the loaded SVD.

View File

@ -127,7 +127,7 @@ class BitField:
description: str description: str
@dataclass @dataclass(frozen=True)
class DecodedRegister: class DecodedRegister:
"""A register value decoded into named bitfields via SVD.""" """A register value decoded into named bitfields via SVD."""

626
tests/test_error_paths.py Normal file
View File

@ -0,0 +1,626 @@
"""Error-path tests for openocd-python.
Exercises every error condition and exception branch across the
connection, target, memory, register, flash, breakpoint, session,
and process subsystems.
Each test configures a mock server to return error responses (or
misbehave at the protocol level) and asserts that the correct
exception type is raised with a meaningful message.
"""
from __future__ import annotations
import asyncio
import pytest
from openocd.breakpoints import BreakpointError, BreakpointManager
from openocd.connection.tcl_rpc import TclRpcConnection
from openocd.errors import (
ConnectionError,
FlashError,
ProcessError,
TargetError,
TargetNotHaltedError,
TimeoutError,
)
from openocd.flash import Flash
from openocd.memory import Memory
from openocd.process import OpenOCDProcess
from openocd.registers import Registers
from openocd.session import Session, _get_or_create_loop
from openocd.target import Target
from tests.mock_server import MockOpenOCDServer
# ======================================================================
# Helpers: error-returning mock servers
# ======================================================================
@pytest.fixture
async def error_server():
"""A MockOpenOCDServer pre-wired to return error strings.
The default response table is left intact so that "targets" and
similar plumbing commands still work. Individual tests prepend
error-producing routes via ``server.add_response()``.
"""
server = MockOpenOCDServer()
await server.start()
yield server
await server.stop()
@pytest.fixture
async def error_conn(error_server):
"""A TclRpcConnection wired to the error mock server."""
host, port = error_server.address
conn = TclRpcConnection(timeout=5.0)
await conn.connect(host, port)
yield conn
await conn.close()
# ======================================================================
# 1. Connection error paths
# ======================================================================
class TestConnectionErrors:
"""Errors at the transport / framing layer."""
async def test_send_on_closed_connection(self):
"""send() after close() raises ConnectionError."""
server = MockOpenOCDServer()
await server.start()
host, port = server.address
conn = TclRpcConnection(timeout=5.0)
await conn.connect(host, port)
await conn.close()
with pytest.raises(ConnectionError, match="Not connected"):
await conn.send("targets")
await server.stop()
async def test_send_before_connect(self):
"""send() without a prior connect() raises ConnectionError."""
conn = TclRpcConnection()
with pytest.raises(ConnectionError, match="Not connected"):
await conn.send("targets")
async def test_timeout_when_server_never_responds(self):
"""A server that reads but never sends \\x1a triggers TimeoutError."""
hang_event = asyncio.Event()
async def _black_hole(reader, writer):
await reader.read(4096)
# never send a response -- wait until test signals us to stop
try:
await hang_event.wait()
except asyncio.CancelledError:
pass
finally:
writer.close()
srv = await asyncio.start_server(_black_hole, "127.0.0.1", 0)
await srv.start_serving()
host, port = srv.sockets[0].getsockname()[:2]
conn = TclRpcConnection(timeout=0.3)
await conn.connect(host, port)
with pytest.raises(TimeoutError):
await conn.send("targets")
await conn.close()
hang_event.set()
srv.close()
await srv.wait_closed()
async def test_server_closes_connection_mid_stream(self):
"""Server closing the socket without a separator raises ConnectionError."""
async def _close_immediately(reader, writer):
await reader.read(4096)
# send partial data with no separator then close
writer.write(b"partial response")
await writer.drain()
writer.close()
await writer.wait_closed()
srv = await asyncio.start_server(_close_immediately, "127.0.0.1", 0)
await srv.start_serving()
host, port = srv.sockets[0].getsockname()[:2]
conn = TclRpcConnection(timeout=2.0)
await conn.connect(host, port)
with pytest.raises(ConnectionError, match="closed the connection"):
await conn.send("targets")
await conn.close()
srv.close()
await srv.wait_closed()
async def test_bounded_read_rejects_oversized_response(self):
"""Response exceeding MAX_RESPONSE_SIZE without separator raises ConnectionError."""
async def _flood(reader, writer):
await reader.read(4096)
# send a lot of data with no separator
chunk = b"A" * 65536
try:
while True:
writer.write(chunk)
await writer.drain()
except (BrokenPipeError, ConnectionResetError):
pass
srv = await asyncio.start_server(_flood, "127.0.0.1", 0)
await srv.start_serving()
host, port = srv.sockets[0].getsockname()[:2]
conn = TclRpcConnection(timeout=10.0)
await conn.connect(host, port)
with pytest.raises(ConnectionError, match="exceeded"):
await conn.send("targets")
await conn.close()
srv.close()
await srv.wait_closed()
async def test_connect_refused(self):
"""Connecting to a port with nothing listening raises ConnectionError."""
conn = TclRpcConnection(timeout=1.0)
with pytest.raises(ConnectionError):
await conn.connect("127.0.0.1", 1)
# ======================================================================
# 2. Target error paths
# ======================================================================
class TestTargetErrors:
"""Error conditions from the Target subsystem."""
async def test_halt_error_response(self, error_server, error_conn):
"""halt() with an error response raises TargetError."""
error_server.add_response(r"^halt$", "error: target not responding")
target = Target(error_conn)
with pytest.raises(TargetError, match="halt failed"):
await target.halt()
async def test_halt_already_halted_is_not_error(self, error_server, error_conn):
"""halt() when response says 'already halted' should NOT raise."""
error_server.add_response(r"^halt$", "error: target already halted")
target = Target(error_conn)
# "already halted" is a benign condition -- halt() checks for it
state = await target.halt()
assert state.state in ("halted", "unknown")
async def test_resume_error_response(self, error_server, error_conn):
"""resume() with an error response raises TargetError."""
error_server.add_response(r"^resume", "error: cannot resume target")
target = Target(error_conn)
with pytest.raises(TargetError, match="resume failed"):
await target.resume()
async def test_step_error_response(self, error_server, error_conn):
"""step() with an error response raises TargetError."""
error_server.add_response(r"^step", "error: step failed on target")
target = Target(error_conn)
with pytest.raises(TargetError, match="step failed"):
await target.step()
async def test_wait_halt_timeout(self, error_server, error_conn):
"""wait_halt receiving 'timed out' raises TimeoutError."""
error_server.add_response(r"^wait_halt", "timed out while waiting for target")
target = Target(error_conn)
with pytest.raises(TimeoutError, match="did not halt"):
await target.wait_halt(timeout_ms=100)
async def test_wait_halt_time_out_variant(self, error_server, error_conn):
"""wait_halt receiving 'time out' (two words) also raises TimeoutError."""
error_server.add_response(r"^wait_halt", "time out waiting for halt")
target = Target(error_conn)
with pytest.raises(TimeoutError, match="did not halt"):
await target.wait_halt(timeout_ms=100)
async def test_wait_halt_generic_error(self, error_server, error_conn):
"""wait_halt with a generic error (not timeout) raises TargetError."""
error_server.add_response(r"^wait_halt", "error: target communication failure")
target = Target(error_conn)
with pytest.raises(TargetError, match="wait_halt failed"):
await target.wait_halt(timeout_ms=100)
async def test_state_unexpected_format(self, error_server, error_conn):
"""targets returning garbage still produces a TargetState with 'unknown'."""
error_server.add_response(r"^targets$", "this is not valid target output")
# Also need to suppress the reg pc call that _parse_state makes
error_server.add_response(r"^reg\s+pc$", "no such register")
target = Target(error_conn)
state = await target.state()
assert state.name == "unknown"
assert state.state == "unknown"
assert state.current_pc is None
async def test_state_unrecognized_state_string(self, error_server, error_conn):
"""A target row with a bizarre state string normalizes to 'unknown'."""
weird_table = (
" TargetName Type Endian TapName State\n"
"-- ------------------ ---------- ------ ------------------ ------------\n"
" 0* stm32f1x.cpu cortex_m little stm32f1x.cpu exploding"
)
error_server.add_response(r"^targets$", weird_table)
target = Target(error_conn)
state = await target.state()
assert state.name == "stm32f1x.cpu"
assert state.state == "unknown"
assert state.current_pc is None
async def test_reset_error_response(self, error_server, error_conn):
"""reset() with an error response raises TargetError."""
error_server.add_response(r"^reset\s+", "error: reset failed, adapter not found")
target = Target(error_conn)
with pytest.raises(TargetError, match="reset failed"):
await target.reset("halt")
# ======================================================================
# 3. Memory error paths
# ======================================================================
class TestMemoryErrors:
"""Error conditions from the Memory subsystem."""
async def test_read_u32_target_not_halted(self, error_server, error_conn):
"""read_u32 with 'target not halted' in response raises TargetError."""
error_server.add_response(
r"^read_memory\s+", "error: target not halted"
)
mem = Memory(error_conn)
with pytest.raises(TargetError, match="read_memory failed"):
await mem.read_u32(0x20000000, 1)
async def test_write_u32_error_response(self, error_server, error_conn):
"""write_u32 with an error response raises TargetError."""
error_server.add_response(
r"^write_memory\s+", "error: target not halted"
)
mem = Memory(error_conn)
with pytest.raises(TargetError, match="write_memory failed"):
await mem.write_u32(0x20000000, 0xDEADBEEF)
async def test_read_u32_non_hex_tokens(self, error_server, error_conn):
"""read_memory returning non-hex garbage raises TargetError."""
error_server.add_response(
r"^read_memory\s+", "not_a_hex_value xyz !!!"
)
mem = Memory(error_conn)
with pytest.raises(TargetError, match="Cannot parse read_memory"):
await mem.read_u32(0x20000000, 1)
async def test_read_u8_error_response(self, error_server, error_conn):
"""read_u8 with an error response raises TargetError."""
error_server.add_response(
r"^read_memory\s+", "error: bus fault during memory read"
)
mem = Memory(error_conn)
with pytest.raises(TargetError, match="read_memory failed"):
await mem.read_u8(0xFFFFFFFF, 4)
async def test_write_bytes_error_response(self, error_server, error_conn):
"""write_bytes with an error response raises TargetError."""
error_server.add_response(
r"^write_memory\s+", "error: write access violation"
)
mem = Memory(error_conn)
with pytest.raises(TargetError, match="write_memory failed"):
await mem.write_bytes(0x00000000, b"\x01\x02\x03")
async def test_read_u16_error_response(self, error_server, error_conn):
"""read_u16 with an error response raises TargetError."""
error_server.add_response(
r"^read_memory\s+", "error: alignment fault"
)
mem = Memory(error_conn)
with pytest.raises(TargetError, match="read_memory failed"):
await mem.read_u16(0x20000001, 1)
# ======================================================================
# 4. Register error paths
# ======================================================================
class TestRegisterErrors:
"""Error conditions from the Registers subsystem."""
async def test_read_not_halted(self, error_server, error_conn):
"""read('pc') when target is not halted raises TargetNotHaltedError."""
error_server.add_response(
r"^reg\s+pc$", "target not halted"
)
regs = Registers(error_conn)
with pytest.raises(TargetNotHaltedError, match="halted"):
await regs.read("pc")
async def test_read_nonexistent_register(self, error_server, error_conn):
"""read('nonexistent') with unparseable response raises TargetError."""
error_server.add_response(
r"^reg\s+nonexistent$", "invalid command name \"nonexistent\""
)
regs = Registers(error_conn)
with pytest.raises(TargetError, match="Cannot parse register"):
await regs.read("nonexistent")
async def test_write_not_halted(self, error_server, error_conn):
"""write('pc', val) when target is not halted raises TargetNotHaltedError."""
error_server.add_response(
r"^reg\s+pc\s+0x", "target not halted"
)
regs = Registers(error_conn)
with pytest.raises(TargetNotHaltedError, match="halted"):
await regs.write("pc", 0x1234)
async def test_write_generic_error(self, error_server, error_conn):
"""write() with a non-halted-related error raises TargetError."""
error_server.add_response(
r"^reg\s+r0\s+0x", "error: register write failed"
)
regs = Registers(error_conn)
with pytest.raises(TargetError, match="reg write failed"):
await regs.write("r0", 0xDEAD)
async def test_read_all_not_halted(self, error_server, error_conn):
"""read_all() when target is not halted raises TargetNotHaltedError."""
error_server.add_response(r"^reg$", "target not halted")
regs = Registers(error_conn)
with pytest.raises(TargetNotHaltedError, match="halted"):
await regs.read_all()
async def test_read_many_partial_failure(self, error_server, error_conn):
"""read_many() should propagate the first register read failure."""
# pc succeeds, but sp returns not-halted
error_server.add_response(r"^reg\s+sp$", "target not halted")
regs = Registers(error_conn)
with pytest.raises(TargetNotHaltedError):
await regs.read_many(["pc", "sp"])
# ======================================================================
# 5. Flash error paths
# ======================================================================
class TestFlashErrors:
"""Error conditions from the Flash subsystem."""
async def test_banks_error(self, error_server, error_conn):
"""flash.banks() with an error response raises FlashError."""
error_server.add_response(r"^flash banks$", "error: no flash banks configured")
flash = Flash(error_conn)
with pytest.raises(FlashError, match="flash banks"):
await flash.banks()
async def test_info_error(self, error_server, error_conn):
"""flash.info() with an error response raises FlashError."""
error_server.add_response(r"^flash info\s+", "error: invalid bank number")
flash = Flash(error_conn)
with pytest.raises(FlashError, match="flash info"):
await flash.info(bank=99)
async def test_erase_sector_invalid_range(self, error_server, error_conn):
"""erase_sector with first > last raises FlashError locally."""
flash = Flash(error_conn)
with pytest.raises(FlashError, match="Invalid sector range"):
await flash.erase_sector(bank=0, first=10, last=5)
async def test_erase_sector_error_response(self, error_server, error_conn):
"""erase_sector with error from server raises FlashError."""
error_server.add_response(r"^flash erase_sector\s+", "error: erase failed")
flash = Flash(error_conn)
with pytest.raises(FlashError, match="flash erase_sector"):
await flash.erase_sector(bank=0, first=0, last=3)
async def test_write_image_error(self, error_server, error_conn):
"""write_image with error from server raises FlashError."""
error_server.add_response(
r"^flash write_image\s+", "error: flash write failed"
)
flash = Flash(error_conn)
with pytest.raises(FlashError, match="flash write_image"):
from pathlib import Path
await flash.write_image(Path("/tmp/fake_firmware.bin"), verify=False)
async def test_protect_error(self, error_server, error_conn):
"""flash.protect() with error response raises FlashError."""
error_server.add_response(
r"^flash protect\s+", "error: protection change not supported"
)
flash = Flash(error_conn)
with pytest.raises(FlashError, match="flash protect"):
await flash.protect(bank=0, first=0, last=3, on=True)
# ======================================================================
# 6. Breakpoint error paths
# ======================================================================
class TestBreakpointErrors:
"""Error conditions from the BreakpointManager subsystem."""
async def test_add_breakpoint_error(self, error_server, error_conn):
"""add() with error response raises BreakpointError."""
error_server.add_response(
r"^bp\s+0x", "error: can not add breakpoint, resource not available"
)
bp = BreakpointManager(error_conn)
with pytest.raises(BreakpointError, match="bp 0x"):
await bp.add(0x08001234)
async def test_remove_breakpoint_error(self, error_server, error_conn):
"""remove() with error response raises BreakpointError."""
error_server.add_response(
r"^rbp\s+", "error: no breakpoint at address"
)
bp = BreakpointManager(error_conn)
with pytest.raises(BreakpointError, match="rbp 0x"):
await bp.remove(0x08001234)
async def test_add_watchpoint_error(self, error_server, error_conn):
"""add_watchpoint() with error response raises BreakpointError."""
error_server.add_response(
r"^wp\s+0x", "error: no free watchpoint comparator"
)
bp = BreakpointManager(error_conn)
with pytest.raises(BreakpointError, match="wp 0x"):
await bp.add_watchpoint(0x20000000, 4)
async def test_remove_watchpoint_error(self, error_server, error_conn):
"""remove_watchpoint() with error response raises BreakpointError."""
error_server.add_response(
r"^rwp\s+", "error: no watchpoint at address"
)
bp = BreakpointManager(error_conn)
with pytest.raises(BreakpointError, match="rwp 0x"):
await bp.remove_watchpoint(0x20000000)
# ======================================================================
# 7. Session error paths
# ======================================================================
class TestSessionErrors:
"""Error conditions from the Session layer."""
async def test_get_or_create_loop_from_async_context(self):
"""_get_or_create_loop() inside a running loop raises RuntimeError."""
with pytest.raises(RuntimeError, match="Cannot use sync API"):
_get_or_create_loop()
async def test_command_on_closed_session(self, mock_ocd):
"""session.command() after close() raises ConnectionError."""
host, port, _server = mock_ocd
sess = await Session.connect(host, port, timeout=5.0)
await sess.close()
with pytest.raises(ConnectionError, match="Not connected"):
await sess.command("targets")
async def test_connect_to_nonexistent_host(self):
"""Session.connect() to a bogus address raises ConnectionError."""
with pytest.raises(ConnectionError):
await Session.connect("127.0.0.1", 1, timeout=1.0)
async def test_double_close_is_safe(self, mock_ocd):
"""Calling close() twice on a session should not raise."""
host, port, _server = mock_ocd
sess = await Session.connect(host, port, timeout=5.0)
await sess.close()
await sess.close() # should be a no-op
# ======================================================================
# 8. Process error paths
# ======================================================================
class TestProcessErrors:
"""Error conditions from the OpenOCDProcess manager."""
async def test_start_empty_config(self):
"""start() with an empty config string raises ProcessError."""
proc = OpenOCDProcess()
with pytest.raises(ProcessError, match="[Ee]mpty config"):
# Use /bin/true as a stand-in binary so we reach config validation
await proc.start("", openocd_bin="/bin/true")
async def test_start_dangling_flag(self):
"""start() with a trailing -f and no argument raises ProcessError."""
proc = OpenOCDProcess()
with pytest.raises(ProcessError, match="requires an argument"):
await proc.start("-f", openocd_bin="/bin/true")
async def test_start_dangling_c_flag(self):
"""start() with a trailing -c and no argument raises ProcessError."""
proc = OpenOCDProcess()
with pytest.raises(ProcessError, match="requires an argument"):
await proc.start("-c", openocd_bin="/bin/true")
async def test_start_nonexistent_binary(self):
"""start() with a nonexistent binary path raises ProcessError."""
proc = OpenOCDProcess()
with pytest.raises(ProcessError):
await proc.start(
"interface/cmsis-dap.cfg",
openocd_bin="/nonexistent/path/to/openocd",
)
async def test_pid_is_none_before_start(self):
"""pid property is None before start()."""
proc = OpenOCDProcess()
assert proc.pid is None
async def test_running_is_false_before_start(self):
"""running property is False before start()."""
proc = OpenOCDProcess()
assert proc.running is False
async def test_stop_before_start_is_safe(self):
"""stop() before start() should not raise."""
proc = OpenOCDProcess()
await proc.stop() # no-op, no exception
# ======================================================================
# 9. Notification connection error paths
# ======================================================================
class TestNotificationErrors:
"""Error conditions for the notification subsystem."""
async def test_enable_notifications_before_connect(self):
"""enable_notifications() before connect() raises ConnectionError."""
conn = TclRpcConnection(timeout=1.0)
with pytest.raises(ConnectionError, match="Not connected"):
await conn.enable_notifications()