Add SWD/DAP subsystem with DP/AP register access and AP enumeration

10th subsystem: session.swd provides DAP discovery, DP/AP register
read/write, AP enumeration, and convenience methods (dpidr, target_id).
Includes SWDError, DAPInfo, APInfo types, input validation, ADIv5/v6
AP classification, and 24 mock-only tests covering happy and error paths.
This commit is contained in:
Ryan Malloy 2026-02-15 16:36:25 -07:00
parent 7a893cb328
commit d17037f2a1
10 changed files with 763 additions and 13 deletions

View File

@ -7,14 +7,17 @@ from openocd.errors import (
OpenOCDError,
ProcessError,
SVDError,
SWDError,
TargetError,
TargetNotHaltedError,
TimeoutError,
)
from openocd.session import Session, SyncSession
from openocd.types import (
APInfo,
BitField,
Breakpoint,
DAPInfo,
DecodedRegister,
FlashBank,
FlashSector,
@ -32,8 +35,10 @@ __all__ = [
"Session",
"SyncSession",
# Types
"APInfo",
"BitField",
"Breakpoint",
"DAPInfo",
"DecodedRegister",
"FlashBank",
"FlashSector",
@ -51,6 +56,7 @@ __all__ = [
"OpenOCDError",
"ProcessError",
"SVDError",
"SWDError",
"TargetError",
"TargetNotHaltedError",
"TimeoutError",

View File

@ -39,5 +39,9 @@ class SVDError(OpenOCDError):
"""SVD file not found, failed to parse, or lookup error."""
class SWDError(OpenOCDError):
"""Raised when an SWD/DAP operation fails."""
class ProcessError(OpenOCDError):
"""OpenOCD subprocess failed to start or exited unexpectedly."""

View File

@ -23,6 +23,7 @@ if TYPE_CHECKING:
from openocd.registers import Registers, SyncRegisters
from openocd.rtt import RTTManager
from openocd.svd import SVDManager, SyncSVDManager
from openocd.swd import SWDController, SyncSWDController
from openocd.target import SyncTarget, Target
from openocd.transport import Transport
@ -40,6 +41,7 @@ class Session:
self._registers: Registers | None = None
self._flash: Flash | None = None
self._jtag: JTAGController | None = None
self._swd: SWDController | None = None
self._breakpoints: BreakpointManager | None = None
self._rtt: RTTManager | None = None
self._svd: SVDManager | None = None
@ -145,6 +147,7 @@ class Session:
def target(self) -> Target:
if self._target is None:
from openocd.target import Target
self._target = Target(self._conn)
return self._target
@ -152,6 +155,7 @@ class Session:
def memory(self) -> Memory:
if self._memory is None:
from openocd.memory import Memory
self._memory = Memory(self._conn)
return self._memory
@ -159,6 +163,7 @@ class Session:
def registers(self) -> Registers:
if self._registers is None:
from openocd.registers import Registers
self._registers = Registers(self._conn)
return self._registers
@ -166,6 +171,7 @@ class Session:
def flash(self) -> Flash:
if self._flash is None:
from openocd.flash import Flash
self._flash = Flash(self._conn)
return self._flash
@ -173,13 +179,23 @@ class Session:
def jtag(self) -> JTAGController:
if self._jtag is None:
from openocd.jtag import JTAGController
self._jtag = JTAGController(self._conn)
return self._jtag
@property
def swd(self) -> SWDController:
if self._swd is None:
from openocd.swd import SWDController
self._swd = SWDController(self._conn)
return self._swd
@property
def breakpoints(self) -> BreakpointManager:
if self._breakpoints is None:
from openocd.breakpoints import BreakpointManager
self._breakpoints = BreakpointManager(self._conn)
return self._breakpoints
@ -187,6 +203,7 @@ class Session:
def rtt(self) -> RTTManager:
if self._rtt is None:
from openocd.rtt import RTTManager
self._rtt = RTTManager(self._conn)
return self._rtt
@ -194,6 +211,7 @@ class Session:
def svd(self) -> SVDManager:
if self._svd is None:
from openocd.svd import SVDManager
self._svd = SVDManager(self._conn, self.memory)
return self._svd
@ -201,6 +219,7 @@ class Session:
def transport(self) -> Transport:
if self._transport is None:
from openocd.transport import Transport
self._transport = Transport(self._conn)
return self._transport
@ -210,16 +229,20 @@ class Session:
def on_halt(self, callback: Callable[[str], None]) -> None:
"""Register a callback for target halt events."""
def _filter(msg: str) -> None:
if "halted" in msg.lower():
callback(msg)
self._conn.on_notification(_filter)
def on_reset(self, callback: Callable[[str], None]) -> None:
"""Register a callback for target reset events."""
def _filter(msg: str) -> None:
if "reset" in msg.lower():
callback(msg)
self._conn.on_notification(_filter)
@ -227,6 +250,7 @@ class Session:
# Sync wrapper
# ======================================================================
class SyncSession:
"""Wraps an async Session for synchronous use."""
@ -238,6 +262,7 @@ class SyncSession:
self._registers: SyncRegisters | None = None
self._flash: SyncFlash | None = None
self._jtag: SyncJTAGController | None = None
self._swd: SyncSWDController | None = None
self._breakpoints: SyncBreakpointManager | None = None
self._svd: SyncSVDManager | None = None
@ -254,6 +279,7 @@ class SyncSession:
def target(self) -> SyncTarget:
if self._target is None:
from openocd.target import SyncTarget
self._target = SyncTarget(self._session.target, self._loop)
return self._target
@ -261,6 +287,7 @@ class SyncSession:
def memory(self) -> SyncMemory:
if self._memory is None:
from openocd.memory import SyncMemory
self._memory = SyncMemory(self._session.memory, self._loop)
return self._memory
@ -268,6 +295,7 @@ class SyncSession:
def registers(self) -> SyncRegisters:
if self._registers is None:
from openocd.registers import SyncRegisters
self._registers = SyncRegisters(self._session.registers, self._loop)
return self._registers
@ -275,6 +303,7 @@ class SyncSession:
def flash(self) -> SyncFlash:
if self._flash is None:
from openocd.flash import SyncFlash
self._flash = SyncFlash(self._session.flash, self._loop)
return self._flash
@ -282,13 +311,23 @@ class SyncSession:
def jtag(self) -> SyncJTAGController:
if self._jtag is None:
from openocd.jtag import SyncJTAGController
self._jtag = SyncJTAGController(self._session.jtag, self._loop)
return self._jtag
@property
def swd(self) -> SyncSWDController:
if self._swd is None:
from openocd.swd import SyncSWDController
self._swd = SyncSWDController(self._session.swd, self._loop)
return self._swd
@property
def breakpoints(self) -> SyncBreakpointManager:
if self._breakpoints is None:
from openocd.breakpoints import SyncBreakpointManager
self._breakpoints = SyncBreakpointManager(self._session.breakpoints, self._loop)
return self._breakpoints
@ -296,6 +335,7 @@ class SyncSession:
def svd(self) -> SyncSVDManager:
if self._svd is None:
from openocd.svd import SyncSVDManager
self._svd = SyncSVDManager(self._session.svd, self._loop)
return self._svd
@ -304,6 +344,7 @@ class SyncSession:
# Helpers
# ======================================================================
def _get_or_create_loop() -> asyncio.AbstractEventLoop:
"""Get or create an event loop for synchronous usage.

View File

@ -0,0 +1,5 @@
"""SWD/DAP operations: DP/AP register access and DAP discovery."""
from openocd.swd.controller import SWDController, SyncSWDController
__all__ = ["SWDController", "SyncSWDController"]

View File

@ -0,0 +1,148 @@
"""SWDController — unified facade for SWD/DAP operations."""
from __future__ import annotations
import asyncio
import logging
from openocd.connection.base import Connection
from openocd.errors import SWDError
from openocd.swd import dap as _dap
from openocd.types import APInfo, DAPInfo
log = logging.getLogger(__name__)
class SWDController:
"""High-level async interface to SWD/DAP operations.
Most boards have a single DAP. When *dap* is ``None``, the controller
auto-discovers via ``dap names`` and uses the first (or only) DAP.
Multi-DAP boards (e.g. STM32H7 dual-core) pass the DAP name explicitly.
"""
def __init__(self, conn: Connection) -> None:
self._conn = conn
self._cached_dap: str | None = None
# -- DAP name resolution -----------------------------------------------
async def _resolve_dap(self, dap: str | None) -> str:
"""Return the DAP name to use: explicit or auto-discovered.
When *dap* is ``None``, uses the auto-discovered DAP (first result
from ``dap names``). Once resolved, the name is cached for the
lifetime of this controller unless :meth:`invalidate_cache` is called.
"""
if dap is not None:
return dap
if self._cached_dap is not None:
return self._cached_dap
names = await _dap.dap_names(self._conn)
if not names:
raise SWDError("No DAP instances found (is the transport set to SWD?)")
self._cached_dap = names[0]
log.debug("Auto-resolved DAP: %s", self._cached_dap)
return self._cached_dap
def invalidate_cache(self) -> None:
"""Clear the cached DAP name.
Call after transport changes, probe reconnection, or target
reconfiguration that may change which DAPs are available.
"""
self._cached_dap = None
log.debug("DAP cache invalidated")
# -- DAP discovery -----------------------------------------------------
async def info(self, dap: str | None = None) -> DAPInfo:
"""Query DAP information."""
name = await self._resolve_dap(dap)
return await _dap.dap_info(self._conn, name)
async def list_aps(self, dap: str | None = None) -> list[APInfo]:
"""Enumerate Access Ports on the DAP."""
name = await self._resolve_dap(dap)
return await _dap.enumerate_aps(self._conn, name)
# -- DP register access ------------------------------------------------
async def dpreg(self, address: int, value: int | None = None, *, dap: str | None = None) -> int:
"""Read or write a DP register.
When *value* is ``None``, performs a read and returns the value.
When *value* is provided, performs a write and returns the written value.
"""
name = await self._resolve_dap(dap)
if value is None:
return await _dap.dpreg_read(self._conn, name, address)
await _dap.dpreg_write(self._conn, name, address, value)
return value
# -- AP register access ------------------------------------------------
async def apreg(
self, ap: int, address: int, value: int | None = None, *, dap: str | None = None
) -> int:
"""Read or write an AP register.
When *value* is ``None``, performs a read and returns the value.
When *value* is provided, performs a write and returns the written value.
"""
name = await self._resolve_dap(dap)
if value is None:
return await _dap.apreg_read(self._conn, name, ap, address)
await _dap.apreg_write(self._conn, name, ap, address, value)
return value
# -- Convenience: well-known DP registers ------------------------------
async def dpidr(self, dap: str | None = None) -> int:
"""Read the DP IDR (address 0x0) — identifies the debug port."""
return await self.dpreg(0x0, dap=dap)
async def target_id(self, dap: str | None = None) -> int:
"""Read the TARGETID register (DP address 0x24, DPv2+)."""
return await self.dpreg(0x24, dap=dap)
# ======================================================================
# SyncSWDController — blocking wrappers
# ======================================================================
class SyncSWDController:
"""Synchronous wrapper around :class:`SWDController`.
Every async method is exposed with the same signature but runs
through ``loop.run_until_complete``.
"""
def __init__(self, ctrl: SWDController, loop: asyncio.AbstractEventLoop) -> None:
self._ctrl = ctrl
self._loop = loop
def info(self, dap: str | None = None) -> DAPInfo:
return self._loop.run_until_complete(self._ctrl.info(dap))
def list_aps(self, dap: str | None = None) -> list[APInfo]:
return self._loop.run_until_complete(self._ctrl.list_aps(dap))
def dpreg(self, address: int, value: int | None = None, *, dap: str | None = None) -> int:
return self._loop.run_until_complete(self._ctrl.dpreg(address, value, dap=dap))
def apreg(
self, ap: int, address: int, value: int | None = None, *, dap: str | None = None
) -> int:
return self._loop.run_until_complete(self._ctrl.apreg(ap, address, value, dap=dap))
def dpidr(self, dap: str | None = None) -> int:
return self._loop.run_until_complete(self._ctrl.dpidr(dap))
def target_id(self, dap: str | None = None) -> int:
return self._loop.run_until_complete(self._ctrl.target_id(dap))
def invalidate_cache(self) -> None:
self._ctrl.invalidate_cache()

199
src/openocd/swd/dap.py Normal file
View File

@ -0,0 +1,199 @@
"""Low-level DAP functions for SWD/DAP register access.
All functions take a connection and a DAP name, then issue the
corresponding OpenOCD ``<dap>`` sub-commands. Parsing is defensive
because OpenOCD output varies between versions.
"""
from __future__ import annotations
import logging
import re
from openocd.connection.base import Connection
from openocd.errors import SWDError
from openocd.types import APInfo, DAPInfo
log = logging.getLogger(__name__)
# Match a hex value anywhere in the response (OpenOCD returns "0x2ba01477\n")
_HEX_RE = re.compile(r"0x([0-9a-fA-F]+)")
# Count APs in dap info output — looks for "AP # <n>" lines
_AP_NUM_RE = re.compile(r"AP\s*#?\s*(\d+)")
# DPIDR line in dap info output
_DPIDR_RE = re.compile(r"DPIDR\s*[:=]?\s*(0x[0-9a-fA-F]+)", re.IGNORECASE)
# OpenOCD error patterns: match the structure of actual error responses,
# not arbitrary English words. Avoids false positives on output like
# "error detection enabled" or register descriptions containing "invalid".
_ERROR_RE = re.compile(
r"^Error:|^invalid command|^invalid|command not found",
re.IGNORECASE | re.MULTILINE,
)
_U32_MAX = 0xFFFFFFFF
_AP_MAX = 255
def _validate_u32(value: int, name: str) -> None:
"""Ensure value is a valid unsigned 32-bit integer."""
if not isinstance(value, int) or value < 0 or value > _U32_MAX:
raise SWDError(f"{name} must be 0..0xFFFFFFFF, got {value!r}")
def _validate_ap_num(ap_num: int) -> None:
"""Ensure AP number is in the valid range (0-255 per ARM ADI spec)."""
if not isinstance(ap_num, int) or ap_num < 0 or ap_num > _AP_MAX:
raise SWDError(f"AP number must be 0..255, got {ap_num!r}")
def _parse_hex(resp: str, context: str) -> int:
"""Extract the first hex value from an OpenOCD response string."""
m = _HEX_RE.search(resp)
if m is None:
raise SWDError(f"{context}: no hex value in response: {resp.strip()!r}")
return int(m.group(1), 16)
def _check_error(resp: str, context: str) -> None:
"""Raise SWDError if the response indicates a failure.
Matches OpenOCD's actual error response patterns (``Error:``,
``invalid command``) rather than naive substring matching, to avoid
false positives on legitimate output containing words like "error".
"""
if _ERROR_RE.search(resp):
raise SWDError(f"{context}: {resp.strip()}")
async def dap_names(conn: Connection) -> list[str]:
"""Return the list of DAP instance names known to OpenOCD."""
resp = await conn.send("dap names")
_check_error(resp, "dap names")
names = [n.strip() for n in resp.strip().splitlines() if n.strip()]
return names
async def dap_info(conn: Connection, dap_name: str) -> DAPInfo:
"""Query full DAP info and return a structured DAPInfo."""
resp = await conn.send(f"{dap_name} info")
_check_error(resp, f"{dap_name} info")
# Extract DPIDR
dpidr = 0
m = _DPIDR_RE.search(resp)
if m:
dpidr = int(m.group(1), 16)
else:
log.warning(
"Could not parse DPIDR from '%s info' output — "
"OpenOCD format may have changed. Raw: %.200s",
dap_name,
resp,
)
# Count APs mentioned
ap_indices = set(_AP_NUM_RE.findall(resp))
ap_count = len(ap_indices)
return DAPInfo(
name=dap_name,
dpidr=dpidr,
ap_count=ap_count,
raw_info=resp.strip(),
)
async def dpreg_read(conn: Connection, dap_name: str, address: int) -> int:
"""Read a DP register at *address* via ``<dap> dpreg <addr>``."""
_validate_u32(address, "DP register address")
cmd = f"{dap_name} dpreg {address:#x}"
resp = await conn.send(cmd)
_check_error(resp, cmd)
return _parse_hex(resp, cmd)
async def dpreg_write(conn: Connection, dap_name: str, address: int, value: int) -> None:
"""Write *value* to DP register at *address*."""
_validate_u32(address, "DP register address")
_validate_u32(value, "DP register value")
cmd = f"{dap_name} dpreg {address:#x} {value:#x}"
resp = await conn.send(cmd)
_check_error(resp, cmd)
async def apreg_read(conn: Connection, dap_name: str, ap_num: int, address: int) -> int:
"""Read an AP register: ``<dap> apreg <ap> <addr>``."""
_validate_ap_num(ap_num)
_validate_u32(address, "AP register address")
cmd = f"{dap_name} apreg {ap_num} {address:#x}"
resp = await conn.send(cmd)
_check_error(resp, cmd)
return _parse_hex(resp, cmd)
async def apreg_write(
conn: Connection, dap_name: str, ap_num: int, address: int, value: int
) -> None:
"""Write *value* to AP register: ``<dap> apreg <ap> <addr> <val>``."""
_validate_ap_num(ap_num)
_validate_u32(address, "AP register address")
_validate_u32(value, "AP register value")
cmd = f"{dap_name} apreg {ap_num} {address:#x} {value:#x}"
resp = await conn.send(cmd)
_check_error(resp, cmd)
def _classify_ap(idr: int) -> str:
"""Classify an AP by its IDR value.
The AP IDR Class field (bits 16:13) indicates the AP type per ARM ADI:
0x0 = no AP / reserved
0x1 = COM-AP (deprecated MEM-AP variant, ADIv5)
0x8 = MEM-AP (ADIv5)
0x9 = MEM-AP (ADIv6)
The Type field (bits 3:0) further distinguishes variants.
"""
if idr == 0:
return "unknown"
class_field = (idr >> 13) & 0xF
if class_field in (0x1, 0x8, 0x9):
return "MEM-AP"
type_field = idr & 0xF
if type_field == 0x0:
return "JTAG-AP"
return "unknown"
async def enumerate_aps(conn: Connection, dap_name: str, max_aps: int = 256) -> list[APInfo]:
"""Probe APs by reading IDR (offset 0xFC) until we get 0 or hit *max_aps*.
Each AP with a non-zero IDR is included. We also read the BASE register
(offset 0xF8) to capture the ROM table address.
"""
aps: list[APInfo] = []
for idx in range(max_aps):
try:
idr = await apreg_read(conn, dap_name, idx, 0xFC)
except SWDError as exc:
log.warning("AP enumeration stopped at index %d due to error: %s", idx, exc)
break
if idr == 0:
break
try:
base = await apreg_read(conn, dap_name, idx, 0xF8)
except SWDError:
base = 0
aps.append(
APInfo(
index=idx,
idr=idr,
base=base,
ap_type=_classify_ap(idr),
)
)
return aps

View File

@ -10,6 +10,7 @@ from typing import Literal
# Target
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class TargetState:
"""Snapshot of target execution state."""
@ -23,6 +24,7 @@ class TargetState:
# Registers
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class Register:
"""A single CPU register."""
@ -38,6 +40,7 @@ class Register:
# Flash
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class FlashSector:
"""One sector inside a flash bank."""
@ -66,6 +69,7 @@ class FlashBank:
# JTAG
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class TAPInfo:
"""One TAP discovered on the JTAG chain."""
@ -103,6 +107,7 @@ class JTAGState(str, Enum):
# Memory
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class MemoryRegion:
"""A chunk of memory read from the target."""
@ -116,6 +121,7 @@ class MemoryRegion:
# SVD
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class BitField:
"""One decoded bitfield inside a register."""
@ -150,6 +156,7 @@ class DecodedRegister:
# Breakpoints
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class Breakpoint:
"""An active breakpoint."""
@ -175,6 +182,7 @@ class Watchpoint:
# RTT
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class RTTChannel:
"""An RTT channel descriptor."""
@ -183,3 +191,28 @@ class RTTChannel:
name: str
size: int
direction: Literal["up", "down"]
# ---------------------------------------------------------------------------
# SWD / DAP
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class DAPInfo:
"""Debug Access Port information returned by ``dap info``."""
name: str # DAP instance name (e.g. "stm32f1x.dap")
dpidr: int # DP ID Register value
ap_count: int # Number of access ports discovered
raw_info: str # Full ``dap info`` output for detailed parsing
@dataclass(frozen=True)
class APInfo:
"""Access Port descriptor discovered during AP enumeration."""
index: int # AP number (0, 1, 2...)
idr: int # AP ID Register (from apreg <n> 0xfc)
base: int # ROM table base address (from apreg <n> 0xf8)
ap_type: str # "MEM-AP", "JTAG-AP", "CTRL-AP", or "unknown"

View File

@ -7,6 +7,7 @@ An asyncio TCP server that speaks the OpenOCD TCL RPC framing protocol:
Supports exact-match and regex-based command routing with pre-loaded
responses that mirror real OpenOCD output.
"""
from __future__ import annotations
import asyncio
@ -58,8 +59,7 @@ REG_ALL_RESPONSE = """\
READ_MEMORY_RESPONSE = "20005000 080001a1 080001ab 080001ad"
FLASH_BANKS_RESPONSE = (
"#0 : stm32f1x.flash (stm32f1x) at 0x08000000,"
" size 0x00020000, buswidth 0, chipwidth 0"
"#0 : stm32f1x.flash (stm32f1x) at 0x08000000, size 0x00020000, buswidth 0, chipwidth 0"
)
SCAN_CHAIN_RESPONSE = """\
@ -82,6 +82,26 @@ TRANSPORT_SELECT_RESPONSE = "swd"
TRANSPORT_LIST_RESPONSE = "jtag swd"
ADAPTER_SPEED_RESPONSE = "4000"
# -- SWD/DAP ---------------------------------------------------------------
DAP_NAMES_RESPONSE = "stm32f1x.dap"
DAP_INFO_RESPONSE = """\
AP # 0
AP ID register 0x04770031
Type is MEM-AP AHB3
MEM-AP BASE 0xe00ff003
Valid ROM table present
Component base address 0xe00ff000
Peripheral ID 0x04c0010471
Designer is 0x4bb, ST Microelectronics
DPIDR: 0x2ba01477"""
DPREG_0_RESPONSE = "0x2ba01477"
DPREG_24_RESPONSE = "0x00000477"
APREG_0_FC_RESPONSE = "0x04770031"
APREG_0_F8_RESPONSE = "0xe00ff003"
APREG_1_FC_RESPONSE = "0x00000000"
def _build_default_responses() -> list[tuple[re.Pattern[str], str | Callable[[str], str]]]:
"""Build the default command-to-response routing table.
@ -97,7 +117,6 @@ def _build_default_responses() -> list[tuple[re.Pattern[str], str | Callable[[st
(re.compile(r"^step"), ""),
(re.compile(r"^reset\s+"), ""),
(re.compile(r"^wait_halt"), ""),
# individual register reads (must come before bare "reg")
(re.compile(r"^reg\s+pc$"), REG_PC_RESPONSE),
(re.compile(r"^reg\s+sp$"), REG_SP_RESPONSE),
@ -107,24 +126,30 @@ def _build_default_responses() -> list[tuple[re.Pattern[str], str | Callable[[st
(re.compile(r"^reg\s+\S+\s+0x"), ""),
# bare "reg" -> full listing
(re.compile(r"^reg$"), REG_ALL_RESPONSE),
# memory
(re.compile(r"^read_memory\s+0x8000000\s+32\s+4$"), READ_MEMORY_RESPONSE),
# generic read_memory -- return zeros for widths/counts we haven't mapped
(re.compile(r"^read_memory\s+"), _generic_read_memory),
(re.compile(r"^write_memory\s+"), ""),
# flash
(re.compile(r"^flash banks$"), FLASH_BANKS_RESPONSE),
(re.compile(r"^flash\s+"), ""),
# SWD/DAP
(re.compile(r"^dap names$"), DAP_NAMES_RESPONSE),
(re.compile(r"^stm32f1x\.dap info$"), DAP_INFO_RESPONSE),
(re.compile(r"^stm32f1x\.dap dpreg 0x0$"), DPREG_0_RESPONSE),
(re.compile(r"^stm32f1x\.dap dpreg 0x24$"), DPREG_24_RESPONSE),
(re.compile(r"^stm32f1x\.dap dpreg 0x0 0x"), ""),
(re.compile(r"^stm32f1x\.dap apreg 0 0xfc$"), APREG_0_FC_RESPONSE),
(re.compile(r"^stm32f1x\.dap apreg 0 0xf8$"), APREG_0_F8_RESPONSE),
(re.compile(r"^stm32f1x\.dap apreg 1 0xfc$"), APREG_1_FC_RESPONSE),
(re.compile(r"^stm32f1x\.dap apreg 0 0x0 0x"), ""),
# JTAG
(re.compile(r"^scan_chain$"), SCAN_CHAIN_RESPONSE),
(re.compile(r"^irscan\s+"), "0x01"),
(re.compile(r"^drscan\s+"), "0xDEADBEEF"),
(re.compile(r"^runtest\s+"), ""),
(re.compile(r"^pathmove\s+"), ""),
# breakpoints
(re.compile(r"^bp\s+0x"), ""),
(re.compile(r"^bp$"), BP_LIST_RESPONSE),
@ -132,14 +157,12 @@ def _build_default_responses() -> list[tuple[re.Pattern[str], str | Callable[[st
(re.compile(r"^wp\s+0x"), ""),
(re.compile(r"^wp$"), ""),
(re.compile(r"^rwp\s+"), ""),
# transport / adapter
(re.compile(r"^transport\s+select$"), TRANSPORT_SELECT_RESPONSE),
(re.compile(r"^transport\s+list$"), TRANSPORT_LIST_RESPONSE),
(re.compile(r"^adapter\s+speed$"), ADAPTER_SPEED_RESPONSE),
(re.compile(r"^adapter\s+speed\s+\d+"), ADAPTER_SPEED_RESPONSE),
(re.compile(r"^adapter\s+name$"), "cmsis-dap"),
# RTT
(re.compile(r"^rtt\s+channels$"), RTT_CHANNELS_RESPONSE),
(re.compile(r"^rtt\s+setup\s+"), ""),
@ -147,7 +170,6 @@ def _build_default_responses() -> list[tuple[re.Pattern[str], str | Callable[[st
(re.compile(r"^rtt\s+stop$"), ""),
(re.compile(r"^rtt\s+channelread\s+"), "hello from target"),
(re.compile(r"^rtt\s+channelwrite\s+"), ""),
# notifications
(re.compile(r"^tcl_notifications\s+"), ""),
]
@ -201,9 +223,7 @@ class MockOpenOCDServer:
self._routes.insert(0, (re.compile(pattern), response))
async def start(self) -> None:
self._server = await asyncio.start_server(
self._handle_client, self._host, self._port
)
self._server = await asyncio.start_server(self._handle_client, self._host, self._port)
await self._server.start_serving()
async def stop(self) -> None:

View File

@ -1,4 +1,5 @@
"""Tests for the Session class."""
from __future__ import annotations
import pytest
@ -11,6 +12,7 @@ from openocd.registers import Registers
from openocd.rtt import RTTManager
from openocd.session import Session
from openocd.svd import SVDManager
from openocd.swd import SWDController
from openocd.target import Target
from openocd.transport import Transport
@ -38,6 +40,7 @@ async def test_context_manager(mock_ocd):
# After exiting the context, the connection is closed.
# Attempting to send should raise.
from openocd.errors import ConnectionError
with pytest.raises(ConnectionError):
await sess.command("targets")
@ -67,6 +70,11 @@ async def test_subsystem_jtag_type(session):
assert isinstance(session.jtag, JTAGController)
async def test_subsystem_swd_type(session):
"""session.swd should return an SWDController instance."""
assert isinstance(session.swd, SWDController)
async def test_subsystem_breakpoints_type(session):
"""session.breakpoints should return a BreakpointManager instance."""
assert isinstance(session.breakpoints, BreakpointManager)

286
tests/test_swd.py Normal file
View File

@ -0,0 +1,286 @@
"""Tests for the SWD/DAP subsystem."""
from __future__ import annotations
import pytest
from openocd.errors import SWDError
from openocd.types import APInfo, DAPInfo
async def test_dap_info(session):
"""info() should return a DAPInfo with parsed DPIDR and AP count."""
info = await session.swd.info()
assert isinstance(info, DAPInfo)
assert info.name == "stm32f1x.dap"
assert info.dpidr == 0x2BA01477
assert info.ap_count == 1
assert "MEM-AP" in info.raw_info
async def test_dap_info_frozen(session):
"""DAPInfo should be immutable (frozen dataclass)."""
info = await session.swd.info()
with pytest.raises(AttributeError):
info.name = "something_else" # type: ignore[misc]
async def test_dpreg_read(session):
"""dpreg() without a value should read and return a DP register."""
result = await session.swd.dpreg(0x0)
assert isinstance(result, int)
assert result == 0x2BA01477
async def test_dpreg_write(session, mock_ocd):
"""dpreg() with a value should write and return the written value."""
result = await session.swd.dpreg(0x0, value=0x12345678)
assert result == 0x12345678
# Verify the mock received the write command
_, _, server = mock_ocd
write_cmds = [c for c in server.received_commands if "dpreg 0x0 0x" in c]
assert len(write_cmds) >= 1
async def test_apreg_read(session):
"""apreg() without a value should read an AP register."""
result = await session.swd.apreg(0, 0xFC)
assert isinstance(result, int)
assert result == 0x04770031
async def test_apreg_write(session, mock_ocd):
"""apreg() with a value should write and return the written value."""
result = await session.swd.apreg(0, 0x0, value=0xAABBCCDD)
assert result == 0xAABBCCDD
_, _, server = mock_ocd
write_cmds = [c for c in server.received_commands if "apreg 0 0x0 0x" in c]
assert len(write_cmds) >= 1
async def test_enumerate_aps(session):
"""list_aps() should discover APs by probing IDR until zero."""
aps = await session.swd.list_aps()
assert isinstance(aps, list)
assert len(aps) == 1
ap = aps[0]
assert isinstance(ap, APInfo)
assert ap.index == 0
assert ap.idr == 0x04770031
assert ap.base == 0xE00FF003
assert ap.ap_type == "MEM-AP"
async def test_ap_info_frozen(session):
"""APInfo should be immutable (frozen dataclass)."""
aps = await session.swd.list_aps()
with pytest.raises(AttributeError):
aps[0].index = 99 # type: ignore[misc]
async def test_dpidr_convenience(session):
"""dpidr() should read DP address 0x0."""
result = await session.swd.dpidr()
assert result == 0x2BA01477
async def test_target_id(session):
"""target_id() should read DP address 0x24."""
result = await session.swd.target_id()
assert result == 0x00000477
async def test_auto_resolve_dap(session, mock_ocd):
"""With no explicit dap name, the controller should auto-discover."""
# First call triggers dap names lookup
await session.swd.dpidr()
_, _, server = mock_ocd
assert "dap names" in server.received_commands
# Second call should use the cached name (no extra dap names)
count_before = server.received_commands.count("dap names")
await session.swd.dpidr()
count_after = server.received_commands.count("dap names")
assert count_after == count_before
async def test_explicit_dap_name(session, mock_ocd):
"""Passing dap= explicitly should skip auto-discovery."""
result = await session.swd.dpreg(0x0, dap="stm32f1x.dap")
assert result == 0x2BA01477
# Should NOT have called "dap names"
_, _, server = mock_ocd
assert "dap names" not in server.received_commands
async def test_swd_error_on_bad_response(mock_ocd):
"""SWDError should be raised when response matches OpenOCD error patterns."""
from openocd.swd.dap import _check_error
with pytest.raises(SWDError):
_check_error("Error: invalid DAP", "test")
with pytest.raises(SWDError):
_check_error("invalid command name", "test")
with pytest.raises(SWDError):
_check_error("command not found", "test")
# Clean responses should not raise
_check_error("0x2ba01477", "test")
_check_error("", "test")
# Legitimate output containing "error" as a substring should NOT raise.
# This is the false-positive prevention fix (C1 from code review).
_check_error("error detection enabled in CTRL register", "test")
_check_error("AP ID register 0x04770031", "test")
async def test_swd_error_no_hex_value(mock_ocd):
"""SWDError should be raised when no hex value found in read response."""
from openocd.swd.dap import _parse_hex
with pytest.raises(SWDError, match="no hex value"):
_parse_hex("no numbers here", "test read")
def test_sync_wrapper():
"""SyncSWDController should expose the same API synchronously.
The sync API blocks with run_until_complete, so the mock server must
run on a separate thread to accept connections concurrently.
"""
import asyncio
import threading
from openocd.session import Session
from tests.mock_server import MockOpenOCDServer
# Run mock server in a background thread with its own event loop.
bg_loop = asyncio.new_event_loop()
server = MockOpenOCDServer()
bg_loop.run_until_complete(server.start())
host, port = server.address
thread = threading.Thread(target=bg_loop.run_forever, daemon=True)
thread.start()
try:
with Session.connect_sync(host, port, timeout=5.0) as sync_sess:
result = sync_sess.swd.dpidr()
assert result == 0x2BA01477
info = sync_sess.swd.info()
assert isinstance(info, DAPInfo)
assert info.name == "stm32f1x.dap"
aps = sync_sess.swd.list_aps()
assert len(aps) == 1
finally:
bg_loop.call_soon_threadsafe(bg_loop.stop)
thread.join(timeout=5)
bg_loop.run_until_complete(server.stop())
bg_loop.close()
def test_classify_ap():
"""AP classification should identify MEM-AP, JTAG-AP, and unknown types."""
from openocd.swd.dap import _classify_ap
# MEM-AP ADIv5 (class field 0x8)
assert _classify_ap(0x04770031) == "MEM-AP"
# Zero IDR = unknown
assert _classify_ap(0x00000000) == "unknown"
# Class field 0x1 (COM-AP / legacy MEM-AP)
assert _classify_ap(0x00002000) == "MEM-AP"
# MEM-AP ADIv6 (class field 0x9)
assert _classify_ap(0x00012000) == "MEM-AP"
# JTAG-AP: non-zero IDR, class not MEM-AP, type field 0x0
assert _classify_ap(0x00000010) == "JTAG-AP" # bits[3:0]=0x0, class=0
# Unknown: non-zero IDR, class not MEM-AP, type field != 0
assert _classify_ap(0x00000001) == "unknown" # bits[3:0]=0x1, class=0
# ======================================================================
# Error-path tests (from code review findings I6)
# ======================================================================
async def test_no_dap_found(mock_ocd):
"""SWDError should be raised when dap names returns empty."""
from openocd.session import Session
host, port, server = mock_ocd
# Override dap names to return empty
server.add_response(r"^dap names$", "")
sess = await Session.connect(host, port, timeout=5.0)
try:
with pytest.raises(SWDError, match="No DAP instances found"):
await sess.swd.dpidr()
finally:
await sess.close()
async def test_invalidate_cache(session, mock_ocd):
"""invalidate_cache() should force re-discovery on next call."""
_, _, server = mock_ocd
# First call populates the cache
await session.swd.dpidr()
count_after_first = server.received_commands.count("dap names")
assert count_after_first == 1
# Invalidate and call again
session.swd.invalidate_cache()
await session.swd.dpidr()
count_after_invalidate = server.received_commands.count("dap names")
assert count_after_invalidate == 2
async def test_dpreg_negative_address_rejected(session):
"""Negative addresses should be rejected before reaching OpenOCD."""
with pytest.raises(SWDError, match="must be 0"):
await session.swd.dpreg(-1)
async def test_dpreg_overflow_address_rejected(session):
"""Addresses > 0xFFFFFFFF should be rejected."""
with pytest.raises(SWDError, match="must be 0"):
await session.swd.dpreg(0x1_0000_0000)
async def test_apreg_negative_ap_rejected(session):
"""Negative AP numbers should be rejected."""
with pytest.raises(SWDError, match="AP number must be"):
await session.swd.apreg(-1, 0xFC)
async def test_apreg_ap_over_255_rejected(session):
"""AP numbers > 255 should be rejected per ARM ADI spec."""
with pytest.raises(SWDError, match="AP number must be"):
await session.swd.apreg(256, 0xFC)
async def test_dpreg_write_negative_value_rejected(session):
"""Negative values should be rejected for DP register writes."""
with pytest.raises(SWDError, match="must be 0"):
await session.swd.dpreg(0x0, value=-1)
async def test_dap_info_unparseable_dpidr(mock_ocd):
"""When dap info output has no DPIDR line, dpidr should be 0 with warning."""
from openocd.session import Session
host, port, server = mock_ocd
# Override dap info to return output with no DPIDR line
server.add_response(r"^stm32f1x\.dap info$", "AP # 0\n Some AP info\n No DPIDR here")
sess = await Session.connect(host, port, timeout=5.0)
try:
info = await sess.swd.info()
assert info.dpidr == 0 # Falls back to 0 with a logged warning
assert info.ap_count == 1 # AP # 0 is still counted
finally:
await sess.close()