Add openocd-python: typed async-first Python bindings for OpenOCD
Standalone PyPI package providing structured access to the full OpenOCD command surface via the TCL RPC protocol (port 6666). Async-first API with sync wrappers for every method. Subsystems: target control, memory read/write, CPU registers, flash programming, JTAG chain/scan/boundary, breakpoints/watchpoints, SVD peripheral decoding, RTT channels, transport/adapter config. 79 tests passing against a mock TCL RPC server.
This commit is contained in:
commit
7e1eac5e2d
8
.gitignore
vendored
Normal file
8
.gitignore
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.venv/
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
.pytest_cache/
|
||||
.ruff_cache/
|
||||
46
README.md
Normal file
46
README.md
Normal file
@ -0,0 +1,46 @@
|
||||
# openocd-python
|
||||
|
||||
Typed, async-first Python bindings for the full OpenOCD command surface.
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
pip install openocd-python
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from openocd import Session
|
||||
|
||||
# Connect to a running OpenOCD instance
|
||||
async with Session.connect() as ocd:
|
||||
state = await ocd.target.halt()
|
||||
pc = await ocd.registers.pc()
|
||||
mem = await ocd.memory.read_u32(0x08000000, 4)
|
||||
print(f"PC: {pc:#x}")
|
||||
|
||||
# Or spawn OpenOCD and connect
|
||||
async with Session.start("interface/cmsis-dap.cfg -f target/stm32f1x.cfg") as ocd:
|
||||
await ocd.target.halt()
|
||||
regs = await ocd.registers.read_all()
|
||||
|
||||
# Synchronous API available too
|
||||
with Session.start_sync("interface/cmsis-dap.cfg") as ocd:
|
||||
ocd.target.halt()
|
||||
print(f"PC: {ocd.registers.pc():#x}")
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
- **Async-first** with sync wrappers for every method
|
||||
- **Typed returns** — dataclasses, not raw strings
|
||||
- **Full OpenOCD surface**: target control, memory, registers, flash, JTAG, breakpoints, RTT
|
||||
- **SVD decoding** — read a peripheral register and get named bitfields
|
||||
- **Process management** — spawn and manage OpenOCD subprocesses
|
||||
- **Dual transport** — TCL RPC (primary) and telnet (fallback)
|
||||
|
||||
## Requirements
|
||||
|
||||
- Python 3.11+
|
||||
- OpenOCD installed and on PATH (or pass `openocd_bin=`)
|
||||
61
pyproject.toml
Normal file
61
pyproject.toml
Normal file
@ -0,0 +1,61 @@
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "openocd-python"
|
||||
version = "2025.02.12"
|
||||
description = "Typed, async-first Python bindings for the full OpenOCD command surface"
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
requires-python = ">=3.11"
|
||||
authors = [
|
||||
{name = "Ryan Malloy", email = "ryan@supported.systems"},
|
||||
]
|
||||
keywords = ["openocd", "jtag", "swd", "embedded", "debugging", "hardware"]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Topic :: Software Development :: Debuggers",
|
||||
"Topic :: Software Development :: Embedded Systems",
|
||||
"Topic :: System :: Hardware",
|
||||
"Typing :: Typed",
|
||||
]
|
||||
dependencies = [
|
||||
"cmsis-svd>=0.4",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-asyncio>=0.24",
|
||||
"ruff>=0.8",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
openocd-python = "openocd.cli:main"
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/ryanmalloy/openocd-python"
|
||||
Issues = "https://github.com/ryanmalloy/openocd-python/issues"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/openocd"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py311"
|
||||
line-length = 100
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "UP", "B", "SIM"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
markers = [
|
||||
"hardware: requires physical DAP-Link hardware (deselect with '-m not hardware')",
|
||||
]
|
||||
64
src/openocd/__init__.py
Normal file
64
src/openocd/__init__.py
Normal file
@ -0,0 +1,64 @@
|
||||
"""openocd-python — typed, async-first Python bindings for OpenOCD."""
|
||||
|
||||
from openocd.errors import (
|
||||
ConnectionError,
|
||||
FlashError,
|
||||
JTAGError,
|
||||
OpenOCDError,
|
||||
ProcessError,
|
||||
SVDError,
|
||||
TargetError,
|
||||
TargetNotHaltedError,
|
||||
TimeoutError,
|
||||
)
|
||||
from openocd.session import Session, SyncSession
|
||||
from openocd.types import (
|
||||
BitField,
|
||||
Breakpoint,
|
||||
DecodedRegister,
|
||||
FlashBank,
|
||||
FlashSector,
|
||||
JTAGState,
|
||||
MemoryRegion,
|
||||
Register,
|
||||
RTTChannel,
|
||||
TAPInfo,
|
||||
TargetState,
|
||||
Watchpoint,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Session
|
||||
"Session",
|
||||
"SyncSession",
|
||||
# Types
|
||||
"BitField",
|
||||
"Breakpoint",
|
||||
"DecodedRegister",
|
||||
"FlashBank",
|
||||
"FlashSector",
|
||||
"JTAGState",
|
||||
"MemoryRegion",
|
||||
"RTTChannel",
|
||||
"Register",
|
||||
"TAPInfo",
|
||||
"TargetState",
|
||||
"Watchpoint",
|
||||
# Errors
|
||||
"ConnectionError",
|
||||
"FlashError",
|
||||
"JTAGError",
|
||||
"OpenOCDError",
|
||||
"ProcessError",
|
||||
"SVDError",
|
||||
"TargetError",
|
||||
"TargetNotHaltedError",
|
||||
"TimeoutError",
|
||||
]
|
||||
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
__version__ = version("openocd-python")
|
||||
except Exception:
|
||||
__version__ = "0.0.0"
|
||||
234
src/openocd/breakpoints.py
Normal file
234
src/openocd/breakpoints.py
Normal file
@ -0,0 +1,234 @@
|
||||
"""Breakpoint and watchpoint management.
|
||||
|
||||
Wraps OpenOCD's ``bp``, ``rbp``, ``wp``, and ``rwp`` commands for
|
||||
setting, removing, and listing hardware/software breakpoints and
|
||||
data watchpoints.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from openocd.connection.base import Connection
|
||||
from openocd.errors import OpenOCDError
|
||||
from openocd.types import Breakpoint, Watchpoint
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BreakpointError(OpenOCDError):
|
||||
"""A breakpoint or watchpoint operation failed."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parsers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Breakpoint(IVA): 0x08001234, 0x2, 1 (hw=1) or 0 (sw)
|
||||
_BP_RE = re.compile(
|
||||
r"Breakpoint\([^)]*\):\s*(?P<addr>0x[0-9a-fA-F]+),\s*"
|
||||
r"(?P<len>0x[0-9a-fA-F]+),\s*(?P<hw>\d+)"
|
||||
)
|
||||
|
||||
# Watchpoint output varies across OpenOCD versions. Common formats:
|
||||
# address: 0x20000000, len: 0x4, r/w/a: 2 (access), value: ...
|
||||
# Watchpoint(DWT): 0x20000000, 0x4, 2
|
||||
_WP_RE = re.compile(
|
||||
r"(?:address:\s*(?P<addr1>0x[0-9a-fA-F]+).*?len:\s*(?P<len1>0x[0-9a-fA-F]+).*?r/w/a:\s*(?P<rwa1>\d+))"
|
||||
r"|"
|
||||
r"(?:Watchpoint\([^)]*\):\s*(?P<addr2>0x[0-9a-fA-F]+),\s*(?P<len2>0x[0-9a-fA-F]+),\s*(?P<rwa2>\d+))"
|
||||
)
|
||||
|
||||
# OpenOCD watchpoint access type encoding
|
||||
_WP_ACCESS_MAP = {0: "r", 1: "w", 2: "rw"}
|
||||
_WP_ACCESS_CMD = {"r": "r", "w": "w", "rw": "a"}
|
||||
|
||||
|
||||
def _check_error(response: str, context: str) -> None:
|
||||
"""Raise BreakpointError if the response indicates failure."""
|
||||
if "error" in response.lower():
|
||||
raise BreakpointError(f"{context}: {response.strip()}")
|
||||
|
||||
|
||||
def _parse_breakpoint_list(text: str) -> list[Breakpoint]:
|
||||
"""Parse the output of ``bp`` (no arguments) into Breakpoint objects."""
|
||||
breakpoints: list[Breakpoint] = []
|
||||
for idx, m in enumerate(_BP_RE.finditer(text)):
|
||||
hw_flag = int(m.group("hw"))
|
||||
breakpoints.append(
|
||||
Breakpoint(
|
||||
number=idx,
|
||||
type="hw" if hw_flag else "sw",
|
||||
address=int(m.group("addr"), 16),
|
||||
length=int(m.group("len"), 16),
|
||||
enabled=True,
|
||||
)
|
||||
)
|
||||
return breakpoints
|
||||
|
||||
|
||||
def _parse_watchpoint_list(text: str) -> list[Watchpoint]:
|
||||
"""Parse watchpoint listing output."""
|
||||
watchpoints: list[Watchpoint] = []
|
||||
for idx, m in enumerate(_WP_RE.finditer(text)):
|
||||
# Match could come from either alternative in the regex
|
||||
if m.group("addr1") is not None:
|
||||
addr = int(m.group("addr1"), 16)
|
||||
length = int(m.group("len1"), 16)
|
||||
rwa = int(m.group("rwa1"))
|
||||
else:
|
||||
addr = int(m.group("addr2"), 16)
|
||||
length = int(m.group("len2"), 16)
|
||||
rwa = int(m.group("rwa2"))
|
||||
|
||||
watchpoints.append(
|
||||
Watchpoint(
|
||||
number=idx,
|
||||
address=addr,
|
||||
length=length,
|
||||
access=_WP_ACCESS_MAP.get(rwa, "rw"),
|
||||
)
|
||||
)
|
||||
return watchpoints
|
||||
|
||||
|
||||
class BreakpointManager:
|
||||
"""Manage breakpoints and watchpoints via OpenOCD.
|
||||
|
||||
Breakpoints can be either software (patching the instruction) or
|
||||
hardware (using on-chip comparators). Watchpoints trigger on data
|
||||
access to a given address range.
|
||||
"""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Breakpoints
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def add(self, address: int, length: int = 2, hw: bool = False) -> None:
|
||||
"""Set a breakpoint at the given address.
|
||||
|
||||
Args:
|
||||
address: Instruction address for the breakpoint.
|
||||
length: Breakpoint length in bytes (2 for Thumb, 4 for ARM).
|
||||
hw: Request a hardware breakpoint. If False, OpenOCD uses a
|
||||
software breakpoint when possible.
|
||||
"""
|
||||
cmd = f"bp 0x{address:08X} {length}"
|
||||
if hw:
|
||||
cmd += " hw"
|
||||
resp = await self._conn.send(cmd)
|
||||
_check_error(resp, f"bp 0x{address:08X}")
|
||||
log.info("Breakpoint set at 0x%08X (len=%d, hw=%s)", address, length, hw)
|
||||
|
||||
async def remove(self, address: int) -> None:
|
||||
"""Remove a breakpoint at the given address.
|
||||
|
||||
Args:
|
||||
address: Address of the breakpoint to remove.
|
||||
"""
|
||||
cmd = f"rbp 0x{address:08X}"
|
||||
resp = await self._conn.send(cmd)
|
||||
_check_error(resp, f"rbp 0x{address:08X}")
|
||||
log.info("Breakpoint removed at 0x%08X", address)
|
||||
|
||||
async def list(self) -> list[Breakpoint]:
|
||||
"""List all active breakpoints.
|
||||
|
||||
Returns:
|
||||
A list of Breakpoint objects describing each active breakpoint.
|
||||
"""
|
||||
resp = await self._conn.send("bp")
|
||||
# An empty response or no matches means no breakpoints set
|
||||
if not resp.strip():
|
||||
return []
|
||||
return _parse_breakpoint_list(resp)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Watchpoints
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def add_watchpoint(
|
||||
self,
|
||||
address: int,
|
||||
length: int,
|
||||
access: Literal["r", "w", "rw"] = "rw",
|
||||
) -> None:
|
||||
"""Set a data watchpoint.
|
||||
|
||||
Args:
|
||||
address: Memory address to watch.
|
||||
length: Number of bytes to watch (must be power of 2 on most targets).
|
||||
access: Access type -- ``"r"`` for read, ``"w"`` for write,
|
||||
``"rw"`` for read/write (access).
|
||||
"""
|
||||
access_flag = _WP_ACCESS_CMD.get(access, "a")
|
||||
cmd = f"wp 0x{address:08X} {length} {access_flag}"
|
||||
resp = await self._conn.send(cmd)
|
||||
_check_error(resp, f"wp 0x{address:08X}")
|
||||
log.info("Watchpoint set at 0x%08X (len=%d, access=%s)", address, length, access)
|
||||
|
||||
async def remove_watchpoint(self, address: int) -> None:
|
||||
"""Remove a watchpoint at the given address.
|
||||
|
||||
Args:
|
||||
address: Address of the watchpoint to remove.
|
||||
"""
|
||||
cmd = f"rwp 0x{address:08X}"
|
||||
resp = await self._conn.send(cmd)
|
||||
_check_error(resp, f"rwp 0x{address:08X}")
|
||||
log.info("Watchpoint removed at 0x%08X", address)
|
||||
|
||||
async def list_watchpoints(self) -> list[Watchpoint]:
|
||||
"""List all active watchpoints.
|
||||
|
||||
Returns:
|
||||
A list of Watchpoint objects describing each active watchpoint.
|
||||
"""
|
||||
# OpenOCD doesn't have a dedicated "list watchpoints" command
|
||||
# but 'wp' with no arguments on some builds returns the list.
|
||||
# The more reliable approach is using the TCL command.
|
||||
resp = await self._conn.send("wp")
|
||||
if not resp.strip():
|
||||
return []
|
||||
return _parse_watchpoint_list(resp)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# Sync wrapper
|
||||
# ======================================================================
|
||||
|
||||
class SyncBreakpointManager:
|
||||
"""Synchronous wrapper around BreakpointManager."""
|
||||
|
||||
def __init__(self, bp_manager: BreakpointManager, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._bp = bp_manager
|
||||
self._loop = loop
|
||||
|
||||
def add(self, address: int, length: int = 2, hw: bool = False) -> None:
|
||||
self._loop.run_until_complete(self._bp.add(address, length=length, hw=hw))
|
||||
|
||||
def remove(self, address: int) -> None:
|
||||
self._loop.run_until_complete(self._bp.remove(address))
|
||||
|
||||
def list(self) -> list[Breakpoint]:
|
||||
return self._loop.run_until_complete(self._bp.list())
|
||||
|
||||
def add_watchpoint(
|
||||
self,
|
||||
address: int,
|
||||
length: int,
|
||||
access: Literal["r", "w", "rw"] = "rw",
|
||||
) -> None:
|
||||
self._loop.run_until_complete(self._bp.add_watchpoint(address, length, access=access))
|
||||
|
||||
def remove_watchpoint(self, address: int) -> None:
|
||||
self._loop.run_until_complete(self._bp.remove_watchpoint(address))
|
||||
|
||||
def list_watchpoints(self) -> list[Watchpoint]:
|
||||
return self._loop.run_until_complete(self._bp.list_watchpoints())
|
||||
161
src/openocd/cli.py
Normal file
161
src/openocd/cli.py
Normal file
@ -0,0 +1,161 @@
|
||||
"""CLI entry point for openocd-python.
|
||||
|
||||
Provides quick diagnostics and a REPL for interactive use:
|
||||
|
||||
$ openocd-python --help
|
||||
$ openocd-python info # probe detection + target info
|
||||
$ openocd-python repl # interactive command REPL
|
||||
$ openocd-python read 0x08000000 16 # quick memory read
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
|
||||
def main() -> None:
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
|
||||
pkg_version = version("openocd-python")
|
||||
except Exception:
|
||||
pkg_version = "dev"
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="openocd-python",
|
||||
description=f"OpenOCD Python bindings v{pkg_version}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--version", action="version", version=f"openocd-python {pkg_version}"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host", default="localhost", help="OpenOCD host (default: localhost)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=6666, help="OpenOCD TCL RPC port (default: 6666)"
|
||||
)
|
||||
|
||||
sub = parser.add_subparsers(dest="command")
|
||||
|
||||
sub.add_parser("info", help="Show target and adapter information")
|
||||
|
||||
repl_parser = sub.add_parser("repl", help="Interactive OpenOCD command REPL")
|
||||
repl_parser.add_argument(
|
||||
"--timeout", type=float, default=10.0, help="Command timeout in seconds"
|
||||
)
|
||||
|
||||
read_parser = sub.add_parser("read", help="Read memory and display as hexdump")
|
||||
read_parser.add_argument("address", help="Start address (hex, e.g. 0x08000000)")
|
||||
read_parser.add_argument(
|
||||
"size", type=int, nargs="?", default=64, help="Bytes to read (default: 64)"
|
||||
)
|
||||
|
||||
sub.add_parser("scan", help="Scan the JTAG chain")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command is None:
|
||||
parser.print_help()
|
||||
sys.exit(0)
|
||||
|
||||
asyncio.run(_dispatch(args))
|
||||
|
||||
|
||||
async def _dispatch(args: argparse.Namespace) -> None:
|
||||
from openocd.session import Session
|
||||
|
||||
async with Session.connect(host=args.host, port=args.port) as ocd:
|
||||
if args.command == "info":
|
||||
await _cmd_info(ocd)
|
||||
elif args.command == "repl":
|
||||
await _cmd_repl(ocd, timeout=args.timeout)
|
||||
elif args.command == "read":
|
||||
await _cmd_read(ocd, args.address, args.size)
|
||||
elif args.command == "scan":
|
||||
await _cmd_scan(ocd)
|
||||
|
||||
|
||||
async def _cmd_info(ocd) -> None:
|
||||
"""Display target state and adapter information."""
|
||||
from openocd.errors import OpenOCDError
|
||||
|
||||
print("=== OpenOCD Target Info ===\n")
|
||||
|
||||
try:
|
||||
state = await ocd.target.state()
|
||||
print(f" Target: {state.name}")
|
||||
print(f" State: {state.state}")
|
||||
if state.current_pc is not None:
|
||||
print(f" PC: 0x{state.current_pc:08X}")
|
||||
except OpenOCDError as e:
|
||||
print(f" Target: (error: {e})")
|
||||
|
||||
print()
|
||||
|
||||
try:
|
||||
transport_name = await ocd.transport.select()
|
||||
print(f" Transport: {transport_name}")
|
||||
except OpenOCDError:
|
||||
pass
|
||||
|
||||
try:
|
||||
adapter = await ocd.transport.adapter_info()
|
||||
print(f" Adapter: {adapter}")
|
||||
except OpenOCDError:
|
||||
pass
|
||||
|
||||
try:
|
||||
speed = await ocd.transport.adapter_speed()
|
||||
print(f" Speed: {speed} kHz")
|
||||
except OpenOCDError:
|
||||
pass
|
||||
|
||||
|
||||
async def _cmd_repl(ocd, timeout: float = 10.0) -> None:
|
||||
"""Interactive command REPL."""
|
||||
print("OpenOCD REPL (type 'quit' or Ctrl-D to exit)\n")
|
||||
while True:
|
||||
try:
|
||||
cmd = input("ocd> ")
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print()
|
||||
break
|
||||
if cmd.strip().lower() in ("quit", "exit", "q"):
|
||||
break
|
||||
if not cmd.strip():
|
||||
continue
|
||||
try:
|
||||
result = await ocd.command(cmd)
|
||||
if result.strip():
|
||||
print(result)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
|
||||
async def _cmd_read(ocd, address_str: str, size: int) -> None:
|
||||
"""Read memory and display as hexdump."""
|
||||
addr = int(address_str, 0)
|
||||
dump = await ocd.memory.hexdump(addr, size)
|
||||
print(dump)
|
||||
|
||||
|
||||
async def _cmd_scan(ocd) -> None:
|
||||
"""Scan the JTAG chain."""
|
||||
taps = await ocd.jtag.scan_chain()
|
||||
if not taps:
|
||||
print("No TAPs found on the JTAG chain.")
|
||||
return
|
||||
|
||||
print(f"{'TAP Name':<25s} {'IDCODE':>10s} IR Enabled")
|
||||
print("-" * 50)
|
||||
for tap in taps:
|
||||
print(
|
||||
f"{tap.name:<25s} 0x{tap.idcode:08X} {tap.ir_length:>2d} "
|
||||
f"{'yes' if tap.enabled else 'no'}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
7
src/openocd/connection/__init__.py
Normal file
7
src/openocd/connection/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
"""Connection backends for communicating with OpenOCD."""
|
||||
|
||||
from openocd.connection.base import Connection
|
||||
from openocd.connection.tcl_rpc import TclRpcConnection
|
||||
from openocd.connection.telnet import TelnetConnection
|
||||
|
||||
__all__ = ["Connection", "TclRpcConnection", "TelnetConnection"]
|
||||
30
src/openocd/connection/base.py
Normal file
30
src/openocd/connection/base.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""Abstract base class for OpenOCD connection backends."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
class Connection(ABC):
|
||||
"""Protocol-agnostic interface to an OpenOCD instance."""
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self, host: str, port: int) -> None:
|
||||
"""Open a connection to the given host and port."""
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, command: str) -> str:
|
||||
"""Send a command string and return the response."""
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
|
||||
@abstractmethod
|
||||
async def enable_notifications(self) -> None:
|
||||
"""Enable asynchronous event notifications from OpenOCD."""
|
||||
|
||||
@abstractmethod
|
||||
def on_notification(self, callback: Callable[[str], None]) -> None:
|
||||
"""Register a callback for incoming notifications."""
|
||||
163
src/openocd/connection/tcl_rpc.py
Normal file
163
src/openocd/connection/tcl_rpc.py
Normal file
@ -0,0 +1,163 @@
|
||||
"""TCL RPC client for OpenOCD (port 6666).
|
||||
|
||||
OpenOCD's TCL RPC uses a simple framing protocol:
|
||||
- Client sends: command_string + \\x1a
|
||||
- Server replies: response_string + \\x1a
|
||||
The \\x1a (ASCII SUB / Ctrl-Z) byte acts as an unambiguous delimiter.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
from openocd.connection.base import Connection
|
||||
from openocd.errors import ConnectionError
|
||||
from openocd.errors import TimeoutError as OcdTimeoutError
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
SEPARATOR = b"\x1a"
|
||||
DEFAULT_TIMEOUT = 10.0
|
||||
|
||||
|
||||
class TclRpcConnection(Connection):
|
||||
"""Async TCP client speaking OpenOCD's TCL RPC protocol."""
|
||||
|
||||
def __init__(self, timeout: float = DEFAULT_TIMEOUT) -> None:
|
||||
self._reader: asyncio.StreamReader | None = None
|
||||
self._writer: asyncio.StreamWriter | None = None
|
||||
self._timeout = timeout
|
||||
self._notification_callbacks: list[Callable[[str], None]] = []
|
||||
self._notification_task: asyncio.Task[None] | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
self._host: str = ""
|
||||
self._port: int = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connection lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def connect(self, host: str = "localhost", port: int = 6666) -> None:
|
||||
self._host = host
|
||||
self._port = port
|
||||
try:
|
||||
self._reader, self._writer = await asyncio.wait_for(
|
||||
asyncio.open_connection(host, port),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except OSError as exc:
|
||||
raise ConnectionError(
|
||||
f"Cannot connect to OpenOCD TCL RPC at {host}:{port}: {exc}"
|
||||
) from exc
|
||||
except TimeoutError as exc:
|
||||
raise OcdTimeoutError(
|
||||
f"Timed out connecting to OpenOCD TCL RPC at {host}:{port}"
|
||||
) from exc
|
||||
log.debug("Connected to OpenOCD TCL RPC at %s:%d", host, port)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._notification_task and not self._notification_task.done():
|
||||
self._notification_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self._notification_task
|
||||
self._notification_task = None
|
||||
|
||||
if self._writer:
|
||||
self._writer.close()
|
||||
with contextlib.suppress(OSError):
|
||||
await self._writer.wait_closed()
|
||||
self._writer = None
|
||||
self._reader = None
|
||||
log.debug("TCL RPC connection closed")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Command send/receive
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send(self, command: str) -> str:
|
||||
"""Send a command and return the response string.
|
||||
|
||||
The protocol appends \\x1a after the command and reads until
|
||||
\\x1a appears in the response stream.
|
||||
"""
|
||||
if not self._writer or not self._reader:
|
||||
raise ConnectionError("Not connected — call connect() first")
|
||||
|
||||
async with self._lock:
|
||||
payload = command.encode("utf-8") + SEPARATOR
|
||||
self._writer.write(payload)
|
||||
await self._writer.drain()
|
||||
log.debug("TX: %s", command)
|
||||
|
||||
try:
|
||||
raw = await asyncio.wait_for(
|
||||
self._read_until_separator(),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except TimeoutError as exc:
|
||||
raise OcdTimeoutError(
|
||||
f"Timed out waiting for response to: {command}"
|
||||
) from exc
|
||||
|
||||
response = raw.decode("utf-8", errors="replace")
|
||||
log.debug("RX: %s", response[:200])
|
||||
return response
|
||||
|
||||
async def _read_until_separator(self) -> bytes:
|
||||
"""Read from the stream until the \\x1a separator is found."""
|
||||
assert self._reader is not None
|
||||
buf = bytearray()
|
||||
while True:
|
||||
chunk = await self._reader.read(4096)
|
||||
if not chunk:
|
||||
raise ConnectionError("OpenOCD closed the connection")
|
||||
buf.extend(chunk)
|
||||
idx = buf.find(SEPARATOR)
|
||||
if idx != -1:
|
||||
return bytes(buf[:idx])
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Notifications (async events from OpenOCD)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def enable_notifications(self) -> None:
|
||||
"""Enable TCL event notifications and start the listener loop.
|
||||
|
||||
Sends ``tcl_notifications on`` which causes OpenOCD to push
|
||||
target-state-change events over the same socket.
|
||||
"""
|
||||
await self.send("tcl_notifications on")
|
||||
self._notification_task = asyncio.create_task(self._notification_loop())
|
||||
|
||||
async def _notification_loop(self) -> None:
|
||||
"""Background task that reads unsolicited notifications."""
|
||||
assert self._reader is not None
|
||||
buf = bytearray()
|
||||
try:
|
||||
while True:
|
||||
chunk = await self._reader.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
buf.extend(chunk)
|
||||
while True:
|
||||
idx = buf.find(SEPARATOR)
|
||||
if idx == -1:
|
||||
break
|
||||
msg = buf[:idx].decode("utf-8", errors="replace")
|
||||
buf = buf[idx + 1 :]
|
||||
log.debug("Notification: %s", msg)
|
||||
for cb in self._notification_callbacks:
|
||||
try:
|
||||
cb(msg)
|
||||
except Exception:
|
||||
log.exception("Notification callback error")
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception:
|
||||
log.exception("Notification loop crashed")
|
||||
|
||||
def on_notification(self, callback: Callable[[str], None]) -> None:
|
||||
self._notification_callbacks.append(callback)
|
||||
100
src/openocd/connection/telnet.py
Normal file
100
src/openocd/connection/telnet.py
Normal file
@ -0,0 +1,100 @@
|
||||
"""Telnet connection to OpenOCD (port 4444) — fallback transport.
|
||||
|
||||
The telnet interface is human-oriented and its output formatting varies
|
||||
between OpenOCD versions. Prefer TclRpcConnection where possible.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
from openocd.connection.base import Connection
|
||||
from openocd.errors import ConnectionError
|
||||
from openocd.errors import TimeoutError as OcdTimeoutError
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
PROMPT = b"> "
|
||||
DEFAULT_TIMEOUT = 10.0
|
||||
|
||||
|
||||
class TelnetConnection(Connection):
|
||||
"""Async telnet client for OpenOCD port 4444."""
|
||||
|
||||
def __init__(self, timeout: float = DEFAULT_TIMEOUT) -> None:
|
||||
self._reader: asyncio.StreamReader | None = None
|
||||
self._writer: asyncio.StreamWriter | None = None
|
||||
self._timeout = timeout
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def connect(self, host: str = "localhost", port: int = 4444) -> None:
|
||||
try:
|
||||
self._reader, self._writer = await asyncio.wait_for(
|
||||
asyncio.open_connection(host, port),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except OSError as exc:
|
||||
raise ConnectionError(
|
||||
f"Cannot connect to OpenOCD telnet at {host}:{port}: {exc}"
|
||||
) from exc
|
||||
except TimeoutError as exc:
|
||||
raise OcdTimeoutError(
|
||||
f"Timed out connecting to OpenOCD telnet at {host}:{port}"
|
||||
) from exc
|
||||
|
||||
# Consume the initial banner / prompt
|
||||
with contextlib.suppress(TimeoutError):
|
||||
await asyncio.wait_for(self._read_until_prompt(), timeout=self._timeout)
|
||||
log.debug("Connected to OpenOCD telnet at %s:%d", host, port)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._writer:
|
||||
self._writer.close()
|
||||
with contextlib.suppress(OSError):
|
||||
await self._writer.wait_closed()
|
||||
self._writer = None
|
||||
self._reader = None
|
||||
|
||||
async def send(self, command: str) -> str:
|
||||
if not self._writer or not self._reader:
|
||||
raise ConnectionError("Not connected")
|
||||
|
||||
async with self._lock:
|
||||
self._writer.write((command + "\n").encode("utf-8"))
|
||||
await self._writer.drain()
|
||||
|
||||
try:
|
||||
raw = await asyncio.wait_for(
|
||||
self._read_until_prompt(),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except TimeoutError as exc:
|
||||
raise OcdTimeoutError(f"Timed out waiting for: {command}") from exc
|
||||
|
||||
# Strip the echoed command and trailing prompt
|
||||
text = raw.decode("utf-8", errors="replace")
|
||||
lines = text.splitlines()
|
||||
# First line is often the echoed command, last line is the prompt
|
||||
if lines and lines[0].strip() == command.strip():
|
||||
lines = lines[1:]
|
||||
return "\n".join(lines).strip()
|
||||
|
||||
async def _read_until_prompt(self) -> bytes:
|
||||
assert self._reader is not None
|
||||
buf = bytearray()
|
||||
while True:
|
||||
chunk = await self._reader.read(4096)
|
||||
if not chunk:
|
||||
raise ConnectionError("OpenOCD closed the connection")
|
||||
buf.extend(chunk)
|
||||
if buf.endswith(PROMPT):
|
||||
return bytes(buf[: -len(PROMPT)])
|
||||
|
||||
async def enable_notifications(self) -> None:
|
||||
log.warning("Telnet transport does not support async notifications")
|
||||
|
||||
def on_notification(self, callback: Callable[[str], None]) -> None:
|
||||
log.warning("Telnet transport does not support notifications")
|
||||
43
src/openocd/errors.py
Normal file
43
src/openocd/errors.py
Normal file
@ -0,0 +1,43 @@
|
||||
"""Exception hierarchy for openocd-python.
|
||||
|
||||
All exceptions inherit from OpenOCDError so callers can catch broadly
|
||||
or narrowly as needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class OpenOCDError(Exception):
|
||||
"""Base exception for all openocd-python errors."""
|
||||
|
||||
|
||||
class ConnectionError(OpenOCDError):
|
||||
"""Cannot connect to the OpenOCD TCL RPC or telnet interface."""
|
||||
|
||||
|
||||
class TimeoutError(OpenOCDError):
|
||||
"""A command or wait operation exceeded its deadline."""
|
||||
|
||||
|
||||
class TargetError(OpenOCDError):
|
||||
"""The target is not responding or returned an error."""
|
||||
|
||||
|
||||
class TargetNotHaltedError(TargetError):
|
||||
"""Operation requires a halted target but it is currently running."""
|
||||
|
||||
|
||||
class FlashError(OpenOCDError):
|
||||
"""A flash read/write/erase/verify operation failed."""
|
||||
|
||||
|
||||
class JTAGError(OpenOCDError):
|
||||
"""JTAG communication or chain error."""
|
||||
|
||||
|
||||
class SVDError(OpenOCDError):
|
||||
"""SVD file not found, failed to parse, or lookup error."""
|
||||
|
||||
|
||||
class ProcessError(OpenOCDError):
|
||||
"""OpenOCD subprocess failed to start or exited unexpectedly."""
|
||||
112
src/openocd/events.py
Normal file
112
src/openocd/events.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""Async event system for OpenOCD target state changes.
|
||||
|
||||
OpenOCD can push TCL notifications when target state changes occur
|
||||
(halt, resume, reset, etc.). This module provides a typed callback
|
||||
interface on top of the raw notification stream.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openocd.connection.base import Connection
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Known event types emitted by OpenOCD's TCL notification system
|
||||
EVENT_HALTED = "halted"
|
||||
EVENT_RESUMED = "resumed"
|
||||
EVENT_RESET = "reset"
|
||||
EVENT_GDB_ATTACHED = "gdb-attached"
|
||||
EVENT_GDB_DETACHED = "gdb-detached"
|
||||
|
||||
|
||||
class EventManager:
|
||||
"""Manages callbacks for target state change events.
|
||||
|
||||
Usage::
|
||||
|
||||
events = EventManager(conn)
|
||||
await events.enable()
|
||||
|
||||
events.on("halted", lambda msg: print(f"Target halted: {msg}"))
|
||||
events.on("reset", handle_reset)
|
||||
|
||||
# Later...
|
||||
events.off("halted", that_callback)
|
||||
"""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
self._callbacks: dict[str, list[Callable[[str], None]]] = {}
|
||||
self._enabled = False
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
"""Whether TCL notifications have been turned on."""
|
||||
return self._enabled
|
||||
|
||||
async def enable(self) -> None:
|
||||
"""Enable TCL notifications and start event dispatch.
|
||||
|
||||
Sends ``tcl_notifications on`` to OpenOCD and registers an
|
||||
internal handler that routes incoming messages to typed callbacks.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If the connection is not open.
|
||||
"""
|
||||
if self._enabled:
|
||||
return
|
||||
|
||||
await self._conn.enable_notifications()
|
||||
self._conn.on_notification(self._dispatch)
|
||||
self._enabled = True
|
||||
log.info("Event notifications enabled")
|
||||
|
||||
def on(self, event_type: str, callback: Callable[[str], None]) -> None:
|
||||
"""Register a callback for a specific event type.
|
||||
|
||||
Args:
|
||||
event_type: Event name to match (e.g. "halted", "reset", "resumed").
|
||||
Matching is case-insensitive substring — a notification
|
||||
containing "halted" anywhere triggers "halted" callbacks.
|
||||
callback: Function to call with the full notification message.
|
||||
"""
|
||||
key = event_type.lower()
|
||||
if key not in self._callbacks:
|
||||
self._callbacks[key] = []
|
||||
if callback not in self._callbacks[key]:
|
||||
self._callbacks[key].append(callback)
|
||||
log.debug("Registered callback for event '%s'", key)
|
||||
|
||||
def off(self, event_type: str, callback: Callable[[str], None]) -> None:
|
||||
"""Unregister a callback.
|
||||
|
||||
Args:
|
||||
event_type: The event type the callback was registered under.
|
||||
callback: The callback to remove.
|
||||
"""
|
||||
key = event_type.lower()
|
||||
handlers = self._callbacks.get(key, [])
|
||||
try:
|
||||
handlers.remove(callback)
|
||||
log.debug("Unregistered callback for event '%s'", key)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def _dispatch(self, message: str) -> None:
|
||||
"""Route an incoming notification to matching callbacks."""
|
||||
msg_lower = message.lower()
|
||||
for event_type, handlers in self._callbacks.items():
|
||||
if event_type in msg_lower:
|
||||
for handler in handlers:
|
||||
try:
|
||||
handler(message)
|
||||
except Exception:
|
||||
log.exception(
|
||||
"Error in event callback for '%s'",
|
||||
event_type,
|
||||
)
|
||||
381
src/openocd/flash.py
Normal file
381
src/openocd/flash.py
Normal file
@ -0,0 +1,381 @@
|
||||
"""Flash memory programming operations.
|
||||
|
||||
Wraps OpenOCD's ``flash`` command family for reading, writing, erasing,
|
||||
and verifying on-chip flash memory banks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from openocd.connection.base import Connection
|
||||
from openocd.errors import FlashError
|
||||
from openocd.types import FlashBank, FlashSector
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regex patterns for parsing OpenOCD flash output
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# flash banks: #0 : stm32f1x.flash (stm32f1x) at 0x08000000, size 0x00020000, ...
|
||||
_BANK_LIST_RE = re.compile(
|
||||
r"#(?P<idx>\d+)\s*:\s*(?P<name>\S+)\s+\((?P<driver>\S+)\)\s+"
|
||||
r"at\s+(?P<base>0x[0-9a-fA-F]+),\s*size\s+(?P<size>0x[0-9a-fA-F]+),\s*"
|
||||
r"buswidth\s+(?P<bw>\d+),\s*chipwidth\s+(?P<cw>\d+)"
|
||||
)
|
||||
|
||||
# flash info header: #0 : stm32f1x at 0x08000000, size 0x00020000, ...
|
||||
_INFO_HEADER_RE = re.compile(
|
||||
r"#(?P<idx>\d+)\s*:\s*(?P<name>\S+)\s+at\s+(?P<base>0x[0-9a-fA-F]+),\s*"
|
||||
r"size\s+(?P<size>0x[0-9a-fA-F]+),\s*"
|
||||
r"buswidth\s+(?P<bw>\d+),\s*chipwidth\s+(?P<cw>\d+)"
|
||||
)
|
||||
|
||||
# flash info sector: # 0: 0x00000000 (0x400 1kB) not protected
|
||||
_SECTOR_RE = re.compile(
|
||||
r"#\s*(?P<idx>\d+):\s*(?P<offset>0x[0-9a-fA-F]+)\s+"
|
||||
r"\((?P<size>0x[0-9a-fA-F]+)\s+[^)]*\)\s+"
|
||||
r"(?P<prot>protected|not protected)"
|
||||
)
|
||||
|
||||
|
||||
def _check_error(response: str, context: str) -> None:
|
||||
"""Raise FlashError if the response indicates failure."""
|
||||
if "error" in response.lower():
|
||||
raise FlashError(f"{context}: {response.strip()}")
|
||||
|
||||
|
||||
def _parse_bank_list(text: str) -> list[FlashBank]:
|
||||
"""Parse the output of ``flash banks``."""
|
||||
banks: list[FlashBank] = []
|
||||
for m in _BANK_LIST_RE.finditer(text):
|
||||
banks.append(
|
||||
FlashBank(
|
||||
index=int(m.group("idx")),
|
||||
name=m.group("name"),
|
||||
base=int(m.group("base"), 16),
|
||||
size=int(m.group("size"), 16),
|
||||
bus_width=int(m.group("bw")),
|
||||
chip_width=int(m.group("cw")),
|
||||
target=m.group("driver"),
|
||||
)
|
||||
)
|
||||
return banks
|
||||
|
||||
|
||||
def _parse_bank_info(text: str) -> FlashBank:
|
||||
"""Parse the output of ``flash info <bank>`` into a FlashBank with sectors."""
|
||||
header = _INFO_HEADER_RE.search(text)
|
||||
if not header:
|
||||
raise FlashError(f"Cannot parse flash info output: {text[:200]}")
|
||||
|
||||
sectors: list[FlashSector] = []
|
||||
for m in _SECTOR_RE.finditer(text):
|
||||
sectors.append(
|
||||
FlashSector(
|
||||
index=int(m.group("idx")),
|
||||
offset=int(m.group("offset"), 16),
|
||||
size=int(m.group("size"), 16),
|
||||
protected=m.group("prot") == "protected",
|
||||
)
|
||||
)
|
||||
|
||||
return FlashBank(
|
||||
index=int(header.group("idx")),
|
||||
name=header.group("name"),
|
||||
base=int(header.group("base"), 16),
|
||||
size=int(header.group("size"), 16),
|
||||
bus_width=int(header.group("bw")),
|
||||
chip_width=int(header.group("cw")),
|
||||
target=(
|
||||
header.group("name").split(".")[0]
|
||||
if "." in header.group("name")
|
||||
else header.group("name")
|
||||
),
|
||||
sectors=sectors,
|
||||
)
|
||||
|
||||
|
||||
class Flash:
|
||||
"""Flash memory programming via OpenOCD.
|
||||
|
||||
All methods are async and use the underlying TCL RPC connection to
|
||||
issue ``flash`` commands.
|
||||
"""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Bank enumeration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def banks(self) -> list[FlashBank]:
|
||||
"""List all configured flash banks.
|
||||
|
||||
Returns:
|
||||
A list of FlashBank descriptors (without detailed sector info).
|
||||
"""
|
||||
resp = await self._conn.send("flash banks")
|
||||
_check_error(resp, "flash banks")
|
||||
return _parse_bank_list(resp)
|
||||
|
||||
async def info(self, bank: int = 0) -> FlashBank:
|
||||
"""Get detailed information about a flash bank, including sectors.
|
||||
|
||||
Args:
|
||||
bank: Flash bank number (default 0).
|
||||
|
||||
Returns:
|
||||
A FlashBank with populated ``sectors`` list.
|
||||
"""
|
||||
resp = await self._conn.send(f"flash info {bank}")
|
||||
_check_error(resp, f"flash info {bank}")
|
||||
return _parse_bank_info(resp)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Read operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def read(self, bank: int, offset: int, size: int) -> bytes:
|
||||
"""Read raw bytes from a flash bank.
|
||||
|
||||
Uses a temporary file as an intermediary since OpenOCD's
|
||||
``flash read_bank`` writes to a file, then reads the file content
|
||||
back via TCL.
|
||||
|
||||
Args:
|
||||
bank: Flash bank number.
|
||||
offset: Byte offset within the bank.
|
||||
size: Number of bytes to read.
|
||||
|
||||
Returns:
|
||||
The raw flash contents as bytes.
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as tmp:
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
cmd = f"flash read_bank {bank} {tmp_path} {offset} {size}"
|
||||
resp = await self._conn.send(cmd)
|
||||
_check_error(resp, f"flash read_bank {bank}")
|
||||
|
||||
# Read the file back through TCL to handle remote OpenOCD instances.
|
||||
# Use ocd_find + binary read if available, otherwise fall back to
|
||||
# reading the local file.
|
||||
tcl_read = (
|
||||
f"set fp [open {tmp_path} rb]; "
|
||||
f"set data [read $fp]; "
|
||||
f"close $fp; "
|
||||
f"set data"
|
||||
)
|
||||
try:
|
||||
raw = await self._conn.send(tcl_read)
|
||||
# TCL returns binary as string; try base64 approach if garbled
|
||||
return raw.encode("latin-1")
|
||||
except Exception:
|
||||
# Fallback: read the local file directly
|
||||
return Path(tmp_path).read_bytes()
|
||||
finally:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
async def read_to_file(self, bank: int, path: Path) -> None:
|
||||
"""Read an entire flash bank to a local file.
|
||||
|
||||
Args:
|
||||
bank: Flash bank number.
|
||||
path: Destination file path.
|
||||
"""
|
||||
cmd = f"flash read_bank {bank} {path}"
|
||||
resp = await self._conn.send(cmd)
|
||||
_check_error(resp, f"flash read_bank {bank} to {path}")
|
||||
log.info("Flash bank %d read to %s", bank, path)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Write operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def write(self, bank: int, offset: int, data: bytes) -> None:
|
||||
"""Write raw bytes to a flash bank at the given offset.
|
||||
|
||||
Writes data through a temporary file since OpenOCD's
|
||||
``flash write_bank`` reads from a file.
|
||||
|
||||
Args:
|
||||
bank: Flash bank number.
|
||||
offset: Byte offset within the bank.
|
||||
data: Bytes to write.
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as tmp:
|
||||
tmp.write(data)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
cmd = f"flash write_bank {bank} {tmp_path} {offset}"
|
||||
resp = await self._conn.send(cmd)
|
||||
_check_error(resp, f"flash write_bank {bank}")
|
||||
log.info("Wrote %d bytes to flash bank %d at offset 0x%X", len(data), bank, offset)
|
||||
finally:
|
||||
Path(tmp_path).unlink(missing_ok=True)
|
||||
|
||||
async def write_image(
|
||||
self,
|
||||
path: Path,
|
||||
erase: bool = True,
|
||||
verify: bool = True,
|
||||
) -> None:
|
||||
"""Program a firmware image into flash.
|
||||
|
||||
This is the high-level "flash and go" command that handles erase,
|
||||
write, and optional verification in a single operation.
|
||||
|
||||
Args:
|
||||
path: Path to the firmware image (.bin, .hex, .elf, etc.).
|
||||
erase: Erase affected sectors before writing (default True).
|
||||
verify: Verify flash contents after writing (default True).
|
||||
"""
|
||||
parts = ["flash", "write_image"]
|
||||
if erase:
|
||||
parts.append("erase")
|
||||
parts.append(str(path))
|
||||
|
||||
cmd = " ".join(parts)
|
||||
resp = await self._conn.send(cmd)
|
||||
_check_error(resp, f"flash write_image {path}")
|
||||
log.info("Flash image written: %s (erase=%s)", path, erase)
|
||||
|
||||
if verify:
|
||||
verify_cmd = f"verify_image {path}"
|
||||
vresp = await self._conn.send(verify_cmd)
|
||||
if "error" in vresp.lower() or "mismatch" in vresp.lower():
|
||||
raise FlashError(f"Post-write verification failed: {vresp.strip()}")
|
||||
log.info("Flash verification passed for %s", path)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Erase operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def erase_sector(self, bank: int, first: int, last: int) -> None:
|
||||
"""Erase a range of sectors within a flash bank.
|
||||
|
||||
Args:
|
||||
bank: Flash bank number.
|
||||
first: First sector number to erase (inclusive).
|
||||
last: Last sector number to erase (inclusive).
|
||||
"""
|
||||
if first > last:
|
||||
raise FlashError(f"Invalid sector range: first ({first}) > last ({last})")
|
||||
|
||||
cmd = f"flash erase_sector {bank} {first} {last}"
|
||||
resp = await self._conn.send(cmd)
|
||||
_check_error(resp, f"flash erase_sector {bank} {first}-{last}")
|
||||
log.info("Erased sectors %d-%d in flash bank %d", first, last, bank)
|
||||
|
||||
async def erase_all(self, bank: int = 0) -> None:
|
||||
"""Erase all sectors in a flash bank.
|
||||
|
||||
Queries the bank info to determine the last sector, then erases
|
||||
the full range.
|
||||
|
||||
Args:
|
||||
bank: Flash bank number (default 0).
|
||||
"""
|
||||
bank_info = await self.info(bank)
|
||||
if not bank_info.sectors:
|
||||
# Fall back to erasing sector 0 through "last" using the count
|
||||
# OpenOCD also accepts "last" as a keyword
|
||||
cmd = f"flash erase_sector {bank} 0 last"
|
||||
resp = await self._conn.send(cmd)
|
||||
_check_error(resp, f"flash erase_all bank {bank}")
|
||||
else:
|
||||
last_sector = bank_info.sectors[-1].index
|
||||
await self.erase_sector(bank, 0, last_sector)
|
||||
log.info("Erased all sectors in flash bank %d", bank)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Protection
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def protect(self, bank: int, first: int, last: int, on: bool) -> None:
|
||||
"""Set or clear write protection on a range of sectors.
|
||||
|
||||
Args:
|
||||
bank: Flash bank number.
|
||||
first: First sector number (inclusive).
|
||||
last: Last sector number (inclusive).
|
||||
on: True to enable protection, False to disable.
|
||||
"""
|
||||
flag = "on" if on else "off"
|
||||
cmd = f"flash protect {bank} {first} {last} {flag}"
|
||||
resp = await self._conn.send(cmd)
|
||||
_check_error(resp, f"flash protect {bank} {first}-{last} {flag}")
|
||||
log.info("Flash bank %d sectors %d-%d protection: %s", bank, first, last, flag)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Verify
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def verify(self, bank: int, path: Path) -> bool:
|
||||
"""Verify flash bank contents against a file.
|
||||
|
||||
Args:
|
||||
bank: Flash bank number.
|
||||
path: Path to the reference binary file.
|
||||
|
||||
Returns:
|
||||
True if the flash contents match the file, False otherwise.
|
||||
"""
|
||||
cmd = f"flash verify_bank {bank} {path}"
|
||||
resp = await self._conn.send(cmd)
|
||||
if "error" in resp.lower() or "mismatch" in resp.lower():
|
||||
log.warning("Flash verify failed for bank %d against %s: %s", bank, path, resp.strip())
|
||||
return False
|
||||
log.info("Flash bank %d verified against %s", bank, path)
|
||||
return True
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# Sync wrapper
|
||||
# ======================================================================
|
||||
|
||||
class SyncFlash:
|
||||
"""Synchronous wrapper around Flash for use outside async contexts."""
|
||||
|
||||
def __init__(self, flash: Flash, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._flash = flash
|
||||
self._loop = loop
|
||||
|
||||
def banks(self) -> list[FlashBank]:
|
||||
return self._loop.run_until_complete(self._flash.banks())
|
||||
|
||||
def info(self, bank: int = 0) -> FlashBank:
|
||||
return self._loop.run_until_complete(self._flash.info(bank))
|
||||
|
||||
def read(self, bank: int, offset: int, size: int) -> bytes:
|
||||
return self._loop.run_until_complete(self._flash.read(bank, offset, size))
|
||||
|
||||
def read_to_file(self, bank: int, path: Path) -> None:
|
||||
self._loop.run_until_complete(self._flash.read_to_file(bank, path))
|
||||
|
||||
def write(self, bank: int, offset: int, data: bytes) -> None:
|
||||
self._loop.run_until_complete(self._flash.write(bank, offset, data))
|
||||
|
||||
def write_image(self, path: Path, erase: bool = True, verify: bool = True) -> None:
|
||||
self._loop.run_until_complete(self._flash.write_image(path, erase=erase, verify=verify))
|
||||
|
||||
def erase_sector(self, bank: int, first: int, last: int) -> None:
|
||||
self._loop.run_until_complete(self._flash.erase_sector(bank, first, last))
|
||||
|
||||
def erase_all(self, bank: int = 0) -> None:
|
||||
self._loop.run_until_complete(self._flash.erase_all(bank))
|
||||
|
||||
def protect(self, bank: int, first: int, last: int, on: bool) -> None:
|
||||
self._loop.run_until_complete(self._flash.protect(bank, first, last, on=on))
|
||||
|
||||
def verify(self, bank: int, path: Path) -> bool:
|
||||
return self._loop.run_until_complete(self._flash.verify(bank, path))
|
||||
5
src/openocd/jtag/__init__.py
Normal file
5
src/openocd/jtag/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""JTAG chain control, scan operations, and boundary scan."""
|
||||
|
||||
from openocd.jtag.chain import JTAGController, SyncJTAGController
|
||||
|
||||
__all__ = ["JTAGController", "SyncJTAGController"]
|
||||
52
src/openocd/jtag/boundary.py
Normal file
52
src/openocd/jtag/boundary.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""SVF and XSVF boundary-scan file execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from openocd.connection.base import Connection
|
||||
from openocd.errors import JTAGError
|
||||
|
||||
|
||||
async def svf(
|
||||
conn: Connection,
|
||||
path: Path,
|
||||
*,
|
||||
tap: str | None = None,
|
||||
quiet: bool = False,
|
||||
progress: bool = True,
|
||||
) -> None:
|
||||
"""Execute an SVF (Serial Vector Format) file.
|
||||
|
||||
Args:
|
||||
path: Path to the ``.svf`` file.
|
||||
tap: Restrict operations to this TAP. When ``None``, OpenOCD
|
||||
applies vectors to whatever TAP is appropriate.
|
||||
quiet: Suppress per-statement logging inside OpenOCD.
|
||||
progress: Show a progress indicator (default on).
|
||||
"""
|
||||
parts = ["svf", str(path)]
|
||||
if tap is not None:
|
||||
parts.extend(["-tap", tap])
|
||||
if quiet:
|
||||
parts.append("quiet")
|
||||
if progress:
|
||||
parts.append("progress")
|
||||
resp = await conn.send(" ".join(parts))
|
||||
_check_error(resp, "svf")
|
||||
|
||||
|
||||
async def xsvf(conn: Connection, tap: str, path: Path) -> None:
|
||||
"""Execute an XSVF file against the given TAP.
|
||||
|
||||
Args:
|
||||
tap: TAP to target (e.g. ``stm32f1x.cpu``).
|
||||
path: Path to the ``.xsvf`` file.
|
||||
"""
|
||||
resp = await conn.send(f"xsvf {tap} {path}")
|
||||
_check_error(resp, "xsvf")
|
||||
|
||||
|
||||
def _check_error(response: str, command: str) -> None:
|
||||
if "Error" in response or "error" in response.split("\n")[0]:
|
||||
raise JTAGError(f"{command} failed: {response.strip()}")
|
||||
218
src/openocd/jtag/chain.py
Normal file
218
src/openocd/jtag/chain.py
Normal file
@ -0,0 +1,218 @@
|
||||
"""JTAG chain enumeration and the main JTAGController facade."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from openocd.connection.base import Connection
|
||||
from openocd.errors import JTAGError
|
||||
from openocd.jtag import boundary as _boundary
|
||||
from openocd.jtag import scan as _scan
|
||||
from openocd.jtag import state as _state
|
||||
from openocd.types import JTAGState, TAPInfo
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Regex for one row of ``scan_chain`` output.
|
||||
# Example line:
|
||||
# 0 stm32f1x.cpu Y 0x3ba00477 0x3ba00477 4 0x01 0x0f
|
||||
_CHAIN_ROW_RE = re.compile(
|
||||
r"^\s*\d+\s+" # index
|
||||
r"(\S+)\s+" # tap name (chip.tap)
|
||||
r"([YN])\s+" # enabled
|
||||
r"(0x[0-9a-fA-F]+)\s+" # idcode
|
||||
r"(0x[0-9a-fA-F]+)\s+" # expected
|
||||
r"(\d+)", # ir_length
|
||||
)
|
||||
|
||||
|
||||
async def scan_chain(conn: Connection) -> list[TAPInfo]:
|
||||
"""Query the JTAG scan chain and return a list of discovered TAPs."""
|
||||
resp = await conn.send("scan_chain")
|
||||
if "Error" in resp:
|
||||
raise JTAGError(f"scan_chain failed: {resp.strip()}")
|
||||
return _parse_scan_chain(resp)
|
||||
|
||||
|
||||
async def new_tap(
|
||||
conn: Connection,
|
||||
chip: str,
|
||||
tap: str,
|
||||
ir_len: int,
|
||||
expected_id: int | None = None,
|
||||
) -> None:
|
||||
"""Declare a new TAP on the JTAG chain.
|
||||
|
||||
Args:
|
||||
chip: Chip name (first part of the dotted TAP name).
|
||||
tap: TAP name (second part, e.g. ``cpu``, ``bs``).
|
||||
ir_len: Instruction register length in bits.
|
||||
expected_id: Expected IDCODE. When ``None``, OpenOCD skips
|
||||
the IDCODE verification.
|
||||
"""
|
||||
parts = ["jtag", "newtap", chip, tap, "-irlen", str(ir_len)]
|
||||
if expected_id is not None:
|
||||
parts.extend(["-expected-id", f"0x{expected_id:08x}"])
|
||||
resp = await conn.send(" ".join(parts))
|
||||
if "Error" in resp:
|
||||
raise JTAGError(f"newtap failed: {resp.strip()}")
|
||||
|
||||
|
||||
def _parse_scan_chain(raw: str) -> list[TAPInfo]:
|
||||
"""Parse the tabular output of ``scan_chain``."""
|
||||
taps: list[TAPInfo] = []
|
||||
for line in raw.splitlines():
|
||||
m = _CHAIN_ROW_RE.match(line)
|
||||
if not m:
|
||||
continue
|
||||
full_name = m.group(1)
|
||||
# Split "chip.tap" into components
|
||||
dot = full_name.find(".")
|
||||
if dot == -1:
|
||||
chip_part, tap_part = full_name, ""
|
||||
else:
|
||||
chip_part, tap_part = full_name[:dot], full_name[dot + 1 :]
|
||||
|
||||
taps.append(
|
||||
TAPInfo(
|
||||
name=full_name,
|
||||
chip=chip_part,
|
||||
tap_name=tap_part,
|
||||
idcode=int(m.group(3), 16),
|
||||
ir_length=int(m.group(5)),
|
||||
enabled=m.group(2) == "Y",
|
||||
)
|
||||
)
|
||||
return taps
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# JTAGController — unified facade
|
||||
# ======================================================================
|
||||
|
||||
class JTAGController:
|
||||
"""High-level async interface to all JTAG operations.
|
||||
|
||||
Delegates to helper functions in the ``scan``, ``state``, and
|
||||
``boundary`` sub-modules so each method stays concise.
|
||||
"""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
# -- Chain enumeration -------------------------------------------------
|
||||
|
||||
async def scan_chain(self) -> list[TAPInfo]:
|
||||
"""Return every TAP discovered on the JTAG chain."""
|
||||
return await scan_chain(self._conn)
|
||||
|
||||
async def new_tap(
|
||||
self,
|
||||
chip: str,
|
||||
tap: str,
|
||||
ir_len: int,
|
||||
expected_id: int | None = None,
|
||||
) -> None:
|
||||
"""Declare a new TAP on the chain."""
|
||||
await new_tap(self._conn, chip, tap, ir_len, expected_id)
|
||||
|
||||
# -- Scan operations ---------------------------------------------------
|
||||
|
||||
async def irscan(self, tap: str, instruction: int) -> int:
|
||||
"""Shift *instruction* into the TAP instruction register."""
|
||||
return await _scan.irscan(self._conn, tap, instruction)
|
||||
|
||||
async def drscan(self, tap: str, bits: int, value: int) -> int:
|
||||
"""Shift *value* (of width *bits*) through the data register."""
|
||||
return await _scan.drscan(self._conn, tap, bits, value)
|
||||
|
||||
async def runtest(self, cycles: int) -> None:
|
||||
"""Clock *cycles* TCK pulses in the Run-Test/Idle state."""
|
||||
await _scan.runtest(self._conn, cycles)
|
||||
|
||||
# -- TAP state machine -------------------------------------------------
|
||||
|
||||
async def pathmove(self, states: list[JTAGState]) -> None:
|
||||
"""Walk the TAP controller through an explicit state sequence."""
|
||||
await _state.pathmove(self._conn, states)
|
||||
|
||||
# -- Boundary scan (SVF / XSVF) ----------------------------------------
|
||||
|
||||
async def svf(
|
||||
self,
|
||||
path: Path,
|
||||
tap: str | None = None,
|
||||
*,
|
||||
quiet: bool = False,
|
||||
progress: bool = True,
|
||||
) -> None:
|
||||
"""Execute an SVF boundary-scan file."""
|
||||
await _boundary.svf(self._conn, path, tap=tap, quiet=quiet, progress=progress)
|
||||
|
||||
async def xsvf(self, tap: str, path: Path) -> None:
|
||||
"""Execute an XSVF boundary-scan file against *tap*."""
|
||||
await _boundary.xsvf(self._conn, tap, path)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# SyncJTAGController — blocking wrappers
|
||||
# ======================================================================
|
||||
|
||||
class SyncJTAGController:
|
||||
"""Synchronous wrapper around :class:`JTAGController`.
|
||||
|
||||
Every async method is exposed with the same signature but runs
|
||||
through ``loop.run_until_complete``.
|
||||
"""
|
||||
|
||||
def __init__(self, ctrl: JTAGController, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._ctrl = ctrl
|
||||
self._loop = loop
|
||||
|
||||
# -- Chain -------------------------------------------------------------
|
||||
|
||||
def scan_chain(self) -> list[TAPInfo]:
|
||||
return self._loop.run_until_complete(self._ctrl.scan_chain())
|
||||
|
||||
def new_tap(
|
||||
self,
|
||||
chip: str,
|
||||
tap: str,
|
||||
ir_len: int,
|
||||
expected_id: int | None = None,
|
||||
) -> None:
|
||||
self._loop.run_until_complete(self._ctrl.new_tap(chip, tap, ir_len, expected_id))
|
||||
|
||||
# -- Scan --------------------------------------------------------------
|
||||
|
||||
def irscan(self, tap: str, instruction: int) -> int:
|
||||
return self._loop.run_until_complete(self._ctrl.irscan(tap, instruction))
|
||||
|
||||
def drscan(self, tap: str, bits: int, value: int) -> int:
|
||||
return self._loop.run_until_complete(self._ctrl.drscan(tap, bits, value))
|
||||
|
||||
def runtest(self, cycles: int) -> None:
|
||||
self._loop.run_until_complete(self._ctrl.runtest(cycles))
|
||||
|
||||
# -- State -------------------------------------------------------------
|
||||
|
||||
def pathmove(self, states: list[JTAGState]) -> None:
|
||||
self._loop.run_until_complete(self._ctrl.pathmove(states))
|
||||
|
||||
# -- Boundary ----------------------------------------------------------
|
||||
|
||||
def svf(
|
||||
self,
|
||||
path: Path,
|
||||
tap: str | None = None,
|
||||
*,
|
||||
quiet: bool = False,
|
||||
progress: bool = True,
|
||||
) -> None:
|
||||
self._loop.run_until_complete(self._ctrl.svf(path, tap, quiet=quiet, progress=progress))
|
||||
|
||||
def xsvf(self, tap: str, path: Path) -> None:
|
||||
self._loop.run_until_complete(self._ctrl.xsvf(tap, path))
|
||||
58
src/openocd/jtag/scan.py
Normal file
58
src/openocd/jtag/scan.py
Normal file
@ -0,0 +1,58 @@
|
||||
"""IR/DR scan and TCK run-test operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from openocd.connection.base import Connection
|
||||
from openocd.errors import JTAGError
|
||||
|
||||
|
||||
async def irscan(conn: Connection, tap: str, instruction: int) -> int:
|
||||
"""Shift an instruction into the TAP instruction register.
|
||||
|
||||
Returns the value shifted out of the IR during the operation.
|
||||
"""
|
||||
resp = await conn.send(f"irscan {tap} 0x{instruction:x}")
|
||||
_check_error(resp, "irscan")
|
||||
# OpenOCD returns the shifted-out value as a hex string
|
||||
cleaned = resp.strip()
|
||||
if not cleaned:
|
||||
return 0
|
||||
try:
|
||||
return int(cleaned, 16)
|
||||
except ValueError:
|
||||
return 0
|
||||
|
||||
|
||||
async def drscan(conn: Connection, tap: str, bits: int, value: int) -> int:
|
||||
"""Shift data through the TAP data register.
|
||||
|
||||
Args:
|
||||
tap: TAP name (e.g. ``stm32f1x.cpu``).
|
||||
bits: Number of bits to shift.
|
||||
value: Value to shift in.
|
||||
|
||||
Returns the value shifted out of the DR.
|
||||
"""
|
||||
resp = await conn.send(f"drscan {tap} {bits} 0x{value:x}")
|
||||
_check_error(resp, "drscan")
|
||||
cleaned = resp.strip()
|
||||
if not cleaned:
|
||||
return 0
|
||||
try:
|
||||
return int(cleaned, 16)
|
||||
except ValueError:
|
||||
return 0
|
||||
|
||||
|
||||
async def runtest(conn: Connection, cycles: int) -> None:
|
||||
"""Execute *cycles* TCK clocks in the Run-Test/Idle state."""
|
||||
if cycles < 0:
|
||||
raise JTAGError(f"runtest cycles must be non-negative, got {cycles}")
|
||||
resp = await conn.send(f"runtest {cycles}")
|
||||
_check_error(resp, "runtest")
|
||||
|
||||
|
||||
def _check_error(response: str, command: str) -> None:
|
||||
"""Raise JTAGError if OpenOCD reported an error."""
|
||||
if "Error" in response or "error" in response.split("\n")[0]:
|
||||
raise JTAGError(f"{command} failed: {response.strip()}")
|
||||
26
src/openocd/jtag/state.py
Normal file
26
src/openocd/jtag/state.py
Normal file
@ -0,0 +1,26 @@
|
||||
"""TAP state-machine path movement."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from openocd.connection.base import Connection
|
||||
from openocd.errors import JTAGError
|
||||
from openocd.types import JTAGState
|
||||
|
||||
|
||||
async def pathmove(conn: Connection, states: list[JTAGState]) -> None:
|
||||
"""Move the TAP controller through a sequence of states.
|
||||
|
||||
Each state must be a legal single-step transition from the previous one
|
||||
in the IEEE 1149.1 state machine. OpenOCD validates the path and will
|
||||
report an error for illegal transitions.
|
||||
"""
|
||||
if not states:
|
||||
raise JTAGError("pathmove requires at least one target state")
|
||||
state_names = " ".join(s.value for s in states)
|
||||
resp = await conn.send(f"pathmove {state_names}")
|
||||
_check_error(resp, "pathmove")
|
||||
|
||||
|
||||
def _check_error(response: str, command: str) -> None:
|
||||
if "Error" in response or "error" in response.split("\n")[0]:
|
||||
raise JTAGError(f"{command} failed: {response.strip()}")
|
||||
243
src/openocd/memory.py
Normal file
243
src/openocd/memory.py
Normal file
@ -0,0 +1,243 @@
|
||||
"""Memory read/write operations via OpenOCD TCL API.
|
||||
|
||||
Uses the ``read_memory`` and ``write_memory`` TCL commands for reliable
|
||||
structured I/O, falling back to ``mdb``/``mdw`` style commands only
|
||||
where the TCL API is unavailable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from openocd.connection.tcl_rpc import TclRpcConnection
|
||||
from openocd.errors import TargetError
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Width constants for read_memory / write_memory
|
||||
_WIDTH_8 = 8
|
||||
_WIDTH_16 = 16
|
||||
_WIDTH_32 = 32
|
||||
_WIDTH_64 = 64
|
||||
|
||||
# Hexdump formatting
|
||||
_HEXDUMP_BYTES_PER_LINE = 16
|
||||
|
||||
|
||||
class Memory:
|
||||
"""Read and write target memory."""
|
||||
|
||||
def __init__(self, conn: TclRpcConnection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Typed reads
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def read_u8(self, addr: int, count: int = 1) -> list[int]:
|
||||
"""Read 8-bit values starting at *addr*."""
|
||||
return await self._read(addr, _WIDTH_8, count)
|
||||
|
||||
async def read_u16(self, addr: int, count: int = 1) -> list[int]:
|
||||
"""Read 16-bit values starting at *addr*."""
|
||||
return await self._read(addr, _WIDTH_16, count)
|
||||
|
||||
async def read_u32(self, addr: int, count: int = 1) -> list[int]:
|
||||
"""Read 32-bit values starting at *addr*."""
|
||||
return await self._read(addr, _WIDTH_32, count)
|
||||
|
||||
async def read_u64(self, addr: int, count: int = 1) -> list[int]:
|
||||
"""Read 64-bit values starting at *addr*."""
|
||||
return await self._read(addr, _WIDTH_64, count)
|
||||
|
||||
async def read_bytes(self, addr: int, size: int) -> bytes:
|
||||
"""Read *size* bytes starting at *addr* and return as a bytes object."""
|
||||
values = await self._read(addr, _WIDTH_8, size)
|
||||
return bytes(values)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Typed writes
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def write_u8(self, addr: int, values: int | list[int]) -> None:
|
||||
"""Write one or more 8-bit values starting at *addr*."""
|
||||
await self._write(addr, _WIDTH_8, values)
|
||||
|
||||
async def write_u16(self, addr: int, values: int | list[int]) -> None:
|
||||
"""Write one or more 16-bit values starting at *addr*."""
|
||||
await self._write(addr, _WIDTH_16, values)
|
||||
|
||||
async def write_u32(self, addr: int, values: int | list[int]) -> None:
|
||||
"""Write one or more 32-bit values starting at *addr*."""
|
||||
await self._write(addr, _WIDTH_32, values)
|
||||
|
||||
async def write_bytes(self, addr: int, data: bytes) -> None:
|
||||
"""Write raw bytes to memory starting at *addr*."""
|
||||
await self._write(addr, _WIDTH_8, list(data))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Utilities
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def search(self, pattern: bytes, start: int, end: int) -> list[int]:
|
||||
"""Search for *pattern* in memory between *start* and *end*.
|
||||
|
||||
Reads the region in chunks and returns a list of addresses where
|
||||
the pattern was found. This is done client-side since OpenOCD
|
||||
has no native memory-search command.
|
||||
"""
|
||||
if not pattern:
|
||||
return []
|
||||
|
||||
region_size = end - start
|
||||
if region_size <= 0:
|
||||
return []
|
||||
|
||||
chunk_size = 4096
|
||||
overlap = len(pattern) - 1
|
||||
results: list[int] = []
|
||||
offset = 0
|
||||
|
||||
while offset < region_size:
|
||||
read_len = min(chunk_size + overlap, region_size - offset)
|
||||
data = await self.read_bytes(start + offset, read_len)
|
||||
|
||||
# Scan for the pattern within this chunk
|
||||
search_start = 0
|
||||
while True:
|
||||
idx = data.find(pattern, search_start)
|
||||
if idx == -1:
|
||||
break
|
||||
results.append(start + offset + idx)
|
||||
search_start = idx + 1
|
||||
|
||||
# Advance past the non-overlapping portion
|
||||
offset += chunk_size
|
||||
|
||||
return results
|
||||
|
||||
async def dump(self, addr: int, size: int, path: Path) -> None:
|
||||
"""Read *size* bytes from *addr* and write them to a file."""
|
||||
data = await self.read_bytes(addr, size)
|
||||
path.write_bytes(data)
|
||||
log.info("Dumped %d bytes from 0x%08X to %s", size, addr, path)
|
||||
|
||||
async def hexdump(self, addr: int, size: int) -> str:
|
||||
"""Read *size* bytes and return a formatted hex+ASCII dump.
|
||||
|
||||
Output format (16 bytes per line)::
|
||||
|
||||
08000000: 00 50 00 20 A1 01 00 08 AB 01 00 08 AD 01 00 08 |.P. ............|
|
||||
"""
|
||||
data = await self.read_bytes(addr, size)
|
||||
lines: list[str] = []
|
||||
|
||||
for offset in range(0, len(data), _HEXDUMP_BYTES_PER_LINE):
|
||||
chunk = data[offset : offset + _HEXDUMP_BYTES_PER_LINE]
|
||||
line_addr = addr + offset
|
||||
|
||||
# Hex portion — two groups of 8 bytes separated by an extra space
|
||||
hex_parts: list[str] = []
|
||||
for i, b in enumerate(chunk):
|
||||
hex_parts.append(f"{b:02X}")
|
||||
if i == 7:
|
||||
hex_parts.append("") # extra gap between byte 7 and 8
|
||||
hex_str = " ".join(hex_parts)
|
||||
|
||||
# Pad to consistent width (3 chars * 16 bytes + 1 extra gap = 49 chars)
|
||||
# 16 hex pairs = 16*2=32 hex chars, 15 spaces + 1 gap space = 16 = 49
|
||||
hex_str = hex_str.ljust(49)
|
||||
|
||||
# ASCII portion
|
||||
ascii_str = "".join(
|
||||
chr(b) if 0x20 <= b < 0x7F else "." for b in chunk
|
||||
)
|
||||
|
||||
lines.append(f"{line_addr:08X}: {hex_str} |{ascii_str}|")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _read(self, addr: int, width: int, count: int) -> list[int]:
|
||||
"""Read *count* values of *width* bits using the TCL ``read_memory`` API.
|
||||
|
||||
Command: ``read_memory <addr> <width> <count>``
|
||||
Response: space-separated hex values.
|
||||
"""
|
||||
cmd = f"read_memory 0x{addr:x} {width} {count}"
|
||||
resp = await self._conn.send(cmd)
|
||||
|
||||
if "error" in resp.lower():
|
||||
raise TargetError(f"read_memory failed: {resp}")
|
||||
|
||||
tokens = resp.strip().split()
|
||||
try:
|
||||
return [int(t, 16) for t in tokens]
|
||||
except ValueError as exc:
|
||||
raise TargetError(
|
||||
f"Cannot parse read_memory response: {resp!r}"
|
||||
) from exc
|
||||
|
||||
async def _write(self, addr: int, width: int, values: int | list[int]) -> None:
|
||||
"""Write values of *width* bits using the TCL ``write_memory`` API.
|
||||
|
||||
Command: ``write_memory <addr> <width> {val1 val2 ...}``
|
||||
"""
|
||||
if isinstance(values, int):
|
||||
values = [values]
|
||||
|
||||
val_str = " ".join(f"0x{v:x}" for v in values)
|
||||
cmd = f"write_memory 0x{addr:x} {width} {{{val_str}}}"
|
||||
resp = await self._conn.send(cmd)
|
||||
|
||||
if "error" in resp.lower():
|
||||
raise TargetError(f"write_memory failed: {resp}")
|
||||
|
||||
|
||||
class SyncMemory:
|
||||
"""Synchronous wrapper around Memory."""
|
||||
|
||||
def __init__(self, memory: Memory, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._memory = memory
|
||||
self._loop = loop
|
||||
|
||||
def read_u8(self, addr: int, count: int = 1) -> list[int]:
|
||||
return self._loop.run_until_complete(self._memory.read_u8(addr, count))
|
||||
|
||||
def read_u16(self, addr: int, count: int = 1) -> list[int]:
|
||||
return self._loop.run_until_complete(self._memory.read_u16(addr, count))
|
||||
|
||||
def read_u32(self, addr: int, count: int = 1) -> list[int]:
|
||||
return self._loop.run_until_complete(self._memory.read_u32(addr, count))
|
||||
|
||||
def read_u64(self, addr: int, count: int = 1) -> list[int]:
|
||||
return self._loop.run_until_complete(self._memory.read_u64(addr, count))
|
||||
|
||||
def read_bytes(self, addr: int, size: int) -> bytes:
|
||||
return self._loop.run_until_complete(self._memory.read_bytes(addr, size))
|
||||
|
||||
def write_u8(self, addr: int, values: int | list[int]) -> None:
|
||||
self._loop.run_until_complete(self._memory.write_u8(addr, values))
|
||||
|
||||
def write_u16(self, addr: int, values: int | list[int]) -> None:
|
||||
self._loop.run_until_complete(self._memory.write_u16(addr, values))
|
||||
|
||||
def write_u32(self, addr: int, values: int | list[int]) -> None:
|
||||
self._loop.run_until_complete(self._memory.write_u32(addr, values))
|
||||
|
||||
def write_bytes(self, addr: int, data: bytes) -> None:
|
||||
self._loop.run_until_complete(self._memory.write_bytes(addr, data))
|
||||
|
||||
def search(self, pattern: bytes, start: int, end: int) -> list[int]:
|
||||
return self._loop.run_until_complete(self._memory.search(pattern, start, end))
|
||||
|
||||
def dump(self, addr: int, size: int, path: Path) -> None:
|
||||
self._loop.run_until_complete(self._memory.dump(addr, size, path))
|
||||
|
||||
def hexdump(self, addr: int, size: int) -> str:
|
||||
return self._loop.run_until_complete(self._memory.hexdump(addr, size))
|
||||
137
src/openocd/process.py
Normal file
137
src/openocd/process.py
Normal file
@ -0,0 +1,137 @@
|
||||
"""OpenOCD subprocess management.
|
||||
|
||||
Spawns an OpenOCD process, waits for the TCL RPC port to become
|
||||
available, and provides clean shutdown.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import shutil
|
||||
|
||||
from openocd.errors import ProcessError
|
||||
from openocd.errors import TimeoutError as OpenOCDTimeoutError
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_TCL_PORT = 6666
|
||||
READY_POLL_INTERVAL = 0.25
|
||||
|
||||
|
||||
class OpenOCDProcess:
|
||||
"""Spawn and manage an OpenOCD subprocess."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._proc: asyncio.subprocess.Process | None = None
|
||||
self._tcl_port: int = DEFAULT_TCL_PORT
|
||||
|
||||
@property
|
||||
def pid(self) -> int | None:
|
||||
return self._proc.pid if self._proc else None
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
return self._proc is not None and self._proc.returncode is None
|
||||
|
||||
@property
|
||||
def tcl_port(self) -> int:
|
||||
return self._tcl_port
|
||||
|
||||
async def start(
|
||||
self,
|
||||
config: str,
|
||||
extra_args: list[str] | None = None,
|
||||
tcl_port: int = DEFAULT_TCL_PORT,
|
||||
openocd_bin: str | None = None,
|
||||
) -> None:
|
||||
"""Start OpenOCD with the given configuration.
|
||||
|
||||
Args:
|
||||
config: Config file path or inline ``-f`` / ``-c`` arguments.
|
||||
Multiple files can be separated by spaces with ``-f`` prefixes,
|
||||
e.g. ``"interface/cmsis-dap.cfg -f target/stm32f1x.cfg"``.
|
||||
extra_args: Additional CLI arguments.
|
||||
tcl_port: TCL RPC port (default 6666).
|
||||
openocd_bin: Path to OpenOCD binary (auto-detected if None).
|
||||
"""
|
||||
self._tcl_port = tcl_port
|
||||
binary = openocd_bin or shutil.which("openocd")
|
||||
if not binary:
|
||||
raise ProcessError(
|
||||
"OpenOCD binary not found. Install it or pass openocd_bin="
|
||||
)
|
||||
|
||||
args = [binary]
|
||||
# Parse the config string — support both bare paths and -f/-c flags
|
||||
config_parts = config.split()
|
||||
i = 0
|
||||
while i < len(config_parts):
|
||||
part = config_parts[i]
|
||||
if part in ("-f", "-c"):
|
||||
args.extend([part, config_parts[i + 1]])
|
||||
i += 2
|
||||
else:
|
||||
args.extend(["-f", part])
|
||||
i += 1
|
||||
|
||||
args.extend(["-c", f"tcl_port {tcl_port}"])
|
||||
|
||||
if extra_args:
|
||||
args.extend(extra_args)
|
||||
|
||||
log.info("Starting OpenOCD: %s", " ".join(args))
|
||||
try:
|
||||
self._proc = await asyncio.create_subprocess_exec(
|
||||
*args,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
except FileNotFoundError as exc:
|
||||
raise ProcessError(f"OpenOCD binary not found at {binary}") from exc
|
||||
except OSError as exc:
|
||||
raise ProcessError(f"Failed to start OpenOCD: {exc}") from exc
|
||||
|
||||
async def wait_ready(self, timeout: float = 10.0) -> None:
|
||||
"""Poll until the TCL RPC port is accepting connections."""
|
||||
deadline = asyncio.get_event_loop().time() + timeout
|
||||
while asyncio.get_event_loop().time() < deadline:
|
||||
# Check if the process died
|
||||
if self._proc and self._proc.returncode is not None:
|
||||
stderr = ""
|
||||
if self._proc.stderr:
|
||||
raw = await self._proc.stderr.read()
|
||||
stderr = raw.decode("utf-8", errors="replace")
|
||||
raise ProcessError(
|
||||
f"OpenOCD exited with code {self._proc.returncode}: {stderr[-500:]}"
|
||||
)
|
||||
|
||||
try:
|
||||
_, writer = await asyncio.wait_for(
|
||||
asyncio.open_connection("localhost", self._tcl_port),
|
||||
timeout=1.0,
|
||||
)
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
log.info("OpenOCD ready on TCL port %d", self._tcl_port)
|
||||
return
|
||||
except (OSError, TimeoutError):
|
||||
await asyncio.sleep(READY_POLL_INTERVAL)
|
||||
|
||||
raise OpenOCDTimeoutError(
|
||||
f"OpenOCD did not become ready within {timeout}s"
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Terminate the OpenOCD process."""
|
||||
if not self._proc:
|
||||
return
|
||||
if self._proc.returncode is None:
|
||||
self._proc.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(self._proc.wait(), timeout=5.0)
|
||||
except TimeoutError:
|
||||
self._proc.kill()
|
||||
await self._proc.wait()
|
||||
log.info("OpenOCD process stopped (pid=%d)", self._proc.pid)
|
||||
self._proc = None
|
||||
186
src/openocd/registers.py
Normal file
186
src/openocd/registers.py
Normal file
@ -0,0 +1,186 @@
|
||||
"""CPU register access via OpenOCD.
|
||||
|
||||
Wraps the ``reg`` command family to read and write individual registers,
|
||||
list all registers, and provide ARM Cortex-M convenience accessors.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
|
||||
from openocd.connection.tcl_rpc import TclRpcConnection
|
||||
from openocd.errors import TargetError, TargetNotHaltedError
|
||||
from openocd.types import Register
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Matches "reg <name>" output: "pc (/32): 0x08001234"
|
||||
_REG_VALUE_RE = re.compile(
|
||||
r"(\S+)\s+\(/(\d+)\):\s*(0x[0-9a-fA-F]+)"
|
||||
)
|
||||
|
||||
# Matches a row in "reg" (list all) output.
|
||||
# Typical formats:
|
||||
# "(0) r0 (/32): 0x00000000"
|
||||
# "(123) xPSR (/32): 0x61000000 (dirty)"
|
||||
_REG_LIST_RE = re.compile(
|
||||
r"\((\d+)\)\s+" # register number
|
||||
r"(\S+)\s+" # register name
|
||||
r"\(/(\d+)\):\s*" # bit width
|
||||
r"(0x[0-9a-fA-F]+)" # value
|
||||
r"(?:\s+\(dirty\))?" # optional dirty flag
|
||||
)
|
||||
|
||||
|
||||
class Registers:
|
||||
"""Read and write CPU registers."""
|
||||
|
||||
def __init__(self, conn: TclRpcConnection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
async def read(self, name: str) -> int:
|
||||
"""Read a single register by name and return its value.
|
||||
|
||||
Args:
|
||||
name: Register name, e.g. ``"pc"``, ``"r0"``, ``"xPSR"``.
|
||||
|
||||
Returns:
|
||||
The register value as an integer.
|
||||
|
||||
Raises:
|
||||
TargetNotHaltedError: Target must be halted for register access.
|
||||
TargetError: Register not found or command failed.
|
||||
"""
|
||||
resp = await self._conn.send(f"reg {name}")
|
||||
self._check_halted(resp)
|
||||
|
||||
m = _REG_VALUE_RE.search(resp)
|
||||
if not m:
|
||||
raise TargetError(f"Cannot parse register '{name}' from: {resp}")
|
||||
|
||||
return int(m.group(3), 16)
|
||||
|
||||
async def write(self, name: str, value: int) -> None:
|
||||
"""Write a value to a register.
|
||||
|
||||
Args:
|
||||
name: Register name.
|
||||
value: Value to write.
|
||||
"""
|
||||
resp = await self._conn.send(f"reg {name} 0x{value:x}")
|
||||
self._check_halted(resp)
|
||||
|
||||
if "error" in resp.lower() and "not halted" not in resp.lower():
|
||||
raise TargetError(f"reg write failed: {resp}")
|
||||
|
||||
async def read_all(self) -> dict[str, Register]:
|
||||
"""Read all registers and return them as a dict keyed by name.
|
||||
|
||||
Returns:
|
||||
Mapping of register name to Register dataclass.
|
||||
"""
|
||||
resp = await self._conn.send("reg")
|
||||
self._check_halted(resp)
|
||||
|
||||
registers: dict[str, Register] = {}
|
||||
for line in resp.splitlines():
|
||||
m = _REG_LIST_RE.search(line)
|
||||
if m:
|
||||
number = int(m.group(1))
|
||||
name = m.group(2)
|
||||
size = int(m.group(3))
|
||||
value = int(m.group(4), 16)
|
||||
dirty = "(dirty)" in line
|
||||
|
||||
registers[name] = Register(
|
||||
name=name,
|
||||
number=number,
|
||||
value=value,
|
||||
size=size,
|
||||
dirty=dirty,
|
||||
)
|
||||
|
||||
return registers
|
||||
|
||||
async def read_many(self, names: list[str]) -> dict[str, int]:
|
||||
"""Read several registers by name.
|
||||
|
||||
Args:
|
||||
names: List of register names.
|
||||
|
||||
Returns:
|
||||
Mapping of register name to value.
|
||||
"""
|
||||
results: dict[str, int] = {}
|
||||
for name in names:
|
||||
results[name] = await self.read(name)
|
||||
return results
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# ARM Cortex-M convenience accessors
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def pc(self) -> int:
|
||||
"""Read the program counter."""
|
||||
return await self.read("pc")
|
||||
|
||||
async def sp(self) -> int:
|
||||
"""Read the stack pointer."""
|
||||
return await self.read("sp")
|
||||
|
||||
async def lr(self) -> int:
|
||||
"""Read the link register."""
|
||||
return await self.read("lr")
|
||||
|
||||
async def xpsr(self) -> int:
|
||||
"""Read the xPSR (combined program status register)."""
|
||||
return await self.read("xPSR")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _check_halted(resp: str) -> None:
|
||||
"""Raise TargetNotHaltedError if the response indicates the target
|
||||
is not halted (register access requires a halted target).
|
||||
"""
|
||||
lower = resp.lower()
|
||||
if "not halted" in lower or "target not halted" in lower:
|
||||
raise TargetNotHaltedError(
|
||||
"Target must be halted to access registers"
|
||||
)
|
||||
|
||||
|
||||
class SyncRegisters:
|
||||
"""Synchronous wrapper around Registers."""
|
||||
|
||||
def __init__(self, registers: Registers, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._registers = registers
|
||||
self._loop = loop
|
||||
|
||||
def read(self, name: str) -> int:
|
||||
return self._loop.run_until_complete(self._registers.read(name))
|
||||
|
||||
def write(self, name: str, value: int) -> None:
|
||||
self._loop.run_until_complete(self._registers.write(name, value))
|
||||
|
||||
def read_all(self) -> dict[str, Register]:
|
||||
return self._loop.run_until_complete(self._registers.read_all())
|
||||
|
||||
def read_many(self, names: list[str]) -> dict[str, int]:
|
||||
return self._loop.run_until_complete(self._registers.read_many(names))
|
||||
|
||||
def pc(self) -> int:
|
||||
return self._loop.run_until_complete(self._registers.pc())
|
||||
|
||||
def sp(self) -> int:
|
||||
return self._loop.run_until_complete(self._registers.sp())
|
||||
|
||||
def lr(self) -> int:
|
||||
return self._loop.run_until_complete(self._registers.lr())
|
||||
|
||||
def xpsr(self) -> int:
|
||||
return self._loop.run_until_complete(self._registers.xpsr())
|
||||
226
src/openocd/rtt.py
Normal file
226
src/openocd/rtt.py
Normal file
@ -0,0 +1,226 @@
|
||||
"""Real-Time Transfer (RTT) support via OpenOCD.
|
||||
|
||||
SEGGER RTT provides high-speed bidirectional communication between
|
||||
a debug host and an embedded target using shared memory in RAM.
|
||||
OpenOCD exposes RTT through its ``rtt`` command family.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from openocd.errors import OpenOCDError
|
||||
from openocd.types import RTTChannel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openocd.connection.base import Connection
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RTTManager:
|
||||
"""Control and use SEGGER RTT channels via OpenOCD.
|
||||
|
||||
Typical flow::
|
||||
|
||||
rtt = RTTManager(conn)
|
||||
await rtt.setup(address=0x20000000, size=0x1000)
|
||||
await rtt.start()
|
||||
channels = await rtt.channels()
|
||||
data = await rtt.read(0)
|
||||
await rtt.write(0, "hello\\n")
|
||||
await rtt.stop()
|
||||
"""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
async def setup(
|
||||
self,
|
||||
address: int,
|
||||
size: int,
|
||||
id_string: str = "SEGGER RTT",
|
||||
) -> None:
|
||||
"""Configure RTT control block search parameters.
|
||||
|
||||
Args:
|
||||
address: Start address of the RAM region to search.
|
||||
size: Size of the search region in bytes.
|
||||
id_string: RTT control block identifier (default "SEGGER RTT").
|
||||
|
||||
Raises:
|
||||
OpenOCDError: If the setup command fails.
|
||||
"""
|
||||
cmd = f'rtt setup 0x{address:X} 0x{size:X} "{id_string}"'
|
||||
response = await self._conn.send(cmd)
|
||||
_check_rtt_response(response, cmd)
|
||||
log.info(
|
||||
"RTT setup: search 0x%08X +0x%X id=%r",
|
||||
address,
|
||||
size,
|
||||
id_string,
|
||||
)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start RTT — searches for the control block and activates channels.
|
||||
|
||||
Raises:
|
||||
OpenOCDError: If the control block is not found or start fails.
|
||||
"""
|
||||
response = await self._conn.send("rtt start")
|
||||
_check_rtt_response(response, "rtt start")
|
||||
log.info("RTT started")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop RTT communication.
|
||||
|
||||
Raises:
|
||||
OpenOCDError: If the stop command fails.
|
||||
"""
|
||||
response = await self._conn.send("rtt stop")
|
||||
_check_rtt_response(response, "rtt stop")
|
||||
log.info("RTT stopped")
|
||||
|
||||
async def channels(self) -> list[RTTChannel]:
|
||||
"""List available RTT channels.
|
||||
|
||||
Returns:
|
||||
List of RTTChannel descriptors (index, name, size, direction).
|
||||
|
||||
Raises:
|
||||
OpenOCDError: If the channels command fails.
|
||||
"""
|
||||
response = await self._conn.send("rtt channels")
|
||||
_check_rtt_response(response, "rtt channels")
|
||||
return _parse_channels(response)
|
||||
|
||||
async def read(self, channel: int) -> str:
|
||||
"""Read pending data from an RTT up-channel.
|
||||
|
||||
Args:
|
||||
channel: Channel index (typically 0 for Terminal).
|
||||
|
||||
Returns:
|
||||
The data read as a string (may be empty if nothing pending).
|
||||
|
||||
Raises:
|
||||
OpenOCDError: If the read command fails.
|
||||
"""
|
||||
cmd = f"rtt channelread {channel}"
|
||||
response = await self._conn.send(cmd)
|
||||
_check_rtt_response(response, cmd)
|
||||
return response
|
||||
|
||||
async def write(self, channel: int, data: str) -> None:
|
||||
"""Write data to an RTT down-channel.
|
||||
|
||||
Args:
|
||||
channel: Channel index (typically 0 for Terminal).
|
||||
data: String data to send to the target.
|
||||
|
||||
Raises:
|
||||
OpenOCDError: If the write command fails.
|
||||
"""
|
||||
cmd = f'rtt channelwrite {channel} "{data}"'
|
||||
response = await self._conn.send(cmd)
|
||||
_check_rtt_response(response, cmd)
|
||||
|
||||
|
||||
class SyncRTTManager:
|
||||
"""Synchronous wrapper around RTTManager."""
|
||||
|
||||
def __init__(self, manager: RTTManager, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._manager = manager
|
||||
self._loop = loop
|
||||
|
||||
def setup(
|
||||
self,
|
||||
address: int,
|
||||
size: int,
|
||||
id_string: str = "SEGGER RTT",
|
||||
) -> None:
|
||||
self._loop.run_until_complete(
|
||||
self._manager.setup(address, size, id_string)
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
self._loop.run_until_complete(self._manager.start())
|
||||
|
||||
def stop(self) -> None:
|
||||
self._loop.run_until_complete(self._manager.stop())
|
||||
|
||||
def channels(self) -> list[RTTChannel]:
|
||||
return self._loop.run_until_complete(self._manager.channels())
|
||||
|
||||
def read(self, channel: int) -> str:
|
||||
return self._loop.run_until_complete(self._manager.read(channel))
|
||||
|
||||
def write(self, channel: int, data: str) -> None:
|
||||
self._loop.run_until_complete(self._manager.write(channel, data))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _check_rtt_response(response: str, command: str) -> None:
|
||||
"""Raise on error responses from RTT commands."""
|
||||
if response and "error" in response.lower():
|
||||
raise OpenOCDError(f"RTT command failed ({command}): {response}")
|
||||
|
||||
|
||||
def _parse_channels(response: str) -> list[RTTChannel]:
|
||||
"""Parse the output of ``rtt channels`` into RTTChannel objects.
|
||||
|
||||
OpenOCD typically outputs lines like::
|
||||
|
||||
Up-channels:
|
||||
0: Terminal 1024
|
||||
Down-channels:
|
||||
0: Terminal 16
|
||||
|
||||
The exact format may vary by OpenOCD version; this parser is
|
||||
intentionally lenient.
|
||||
"""
|
||||
channels: list[RTTChannel] = []
|
||||
direction = "up"
|
||||
|
||||
for line in response.splitlines():
|
||||
stripped = line.strip()
|
||||
lower = stripped.lower()
|
||||
|
||||
if "up-channel" in lower or lower.startswith("up"):
|
||||
direction = "up"
|
||||
continue
|
||||
if "down-channel" in lower or lower.startswith("down"):
|
||||
direction = "down"
|
||||
continue
|
||||
|
||||
# Try to parse lines like "0: Terminal 1024"
|
||||
if ":" in stripped and stripped[0].isdigit():
|
||||
parts = stripped.split(":", 1)
|
||||
try:
|
||||
index = int(parts[0].strip())
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
rest = parts[1].strip().split()
|
||||
name = rest[0] if rest else f"channel_{index}"
|
||||
size = 0
|
||||
if len(rest) >= 2:
|
||||
with contextlib.suppress(ValueError):
|
||||
size = int(rest[-1])
|
||||
|
||||
channels.append(
|
||||
RTTChannel(
|
||||
index=index,
|
||||
name=name,
|
||||
size=size,
|
||||
direction=direction,
|
||||
)
|
||||
)
|
||||
|
||||
return channels
|
||||
301
src/openocd/session.py
Normal file
301
src/openocd/session.py
Normal file
@ -0,0 +1,301 @@
|
||||
"""Session — the main entry point for openocd-python.
|
||||
|
||||
Manages the connection lifecycle and provides access to all subsystems
|
||||
(target, memory, registers, flash, JTAG, SVD, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from openocd.connection.tcl_rpc import TclRpcConnection
|
||||
from openocd.process import OpenOCDProcess
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openocd.breakpoints import BreakpointManager, SyncBreakpointManager
|
||||
from openocd.flash import Flash, SyncFlash
|
||||
from openocd.jtag import JTAGController, SyncJTAGController
|
||||
from openocd.memory import Memory, SyncMemory
|
||||
from openocd.registers import Registers, SyncRegisters
|
||||
from openocd.rtt import RTTManager
|
||||
from openocd.svd import SVDManager, SyncSVDManager
|
||||
from openocd.target import SyncTarget, Target
|
||||
from openocd.transport import Transport
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Session:
|
||||
"""Main entry point. Manages connection and provides access to subsystems."""
|
||||
|
||||
def __init__(self, connection: TclRpcConnection, process: OpenOCDProcess | None = None) -> None:
|
||||
self._conn = connection
|
||||
self._process = process
|
||||
self._target: Target | None = None
|
||||
self._memory: Memory | None = None
|
||||
self._registers: Registers | None = None
|
||||
self._flash: Flash | None = None
|
||||
self._jtag: JTAGController | None = None
|
||||
self._breakpoints: BreakpointManager | None = None
|
||||
self._rtt: RTTManager | None = None
|
||||
self._svd: SVDManager | None = None
|
||||
self._transport: Transport | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Factory methods
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
async def start(
|
||||
cls,
|
||||
config: str | Path,
|
||||
*,
|
||||
tcl_port: int = 6666,
|
||||
openocd_bin: str | None = None,
|
||||
timeout: float = 10.0,
|
||||
extra_args: list[str] | None = None,
|
||||
) -> Session:
|
||||
"""Spawn an OpenOCD process and connect to it.
|
||||
|
||||
Args:
|
||||
config: Config file path or ``-f``/``-c`` flags string.
|
||||
tcl_port: TCL RPC port.
|
||||
openocd_bin: Custom OpenOCD binary path.
|
||||
timeout: Seconds to wait for OpenOCD readiness.
|
||||
extra_args: Additional CLI arguments for OpenOCD.
|
||||
"""
|
||||
proc = OpenOCDProcess()
|
||||
await proc.start(
|
||||
str(config), extra_args=extra_args, tcl_port=tcl_port, openocd_bin=openocd_bin
|
||||
)
|
||||
await proc.wait_ready(timeout=timeout)
|
||||
|
||||
conn = TclRpcConnection(timeout=timeout)
|
||||
await conn.connect("localhost", tcl_port)
|
||||
|
||||
return cls(connection=conn, process=proc)
|
||||
|
||||
@classmethod
|
||||
async def connect(
|
||||
cls,
|
||||
host: str = "localhost",
|
||||
port: int = 6666,
|
||||
timeout: float = 10.0,
|
||||
) -> Session:
|
||||
"""Connect to an already-running OpenOCD instance."""
|
||||
conn = TclRpcConnection(timeout=timeout)
|
||||
await conn.connect(host, port)
|
||||
return cls(connection=conn)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Sync factory wrappers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def start_sync(cls, config: str | Path, **kwargs) -> SyncSession:
|
||||
"""Synchronous version of start(). Returns a SyncSession."""
|
||||
loop = _get_or_create_loop()
|
||||
session = loop.run_until_complete(cls.start(config, **kwargs))
|
||||
return SyncSession(session, loop)
|
||||
|
||||
@classmethod
|
||||
def connect_sync(cls, host: str = "localhost", port: int = 6666, **kwargs) -> SyncSession:
|
||||
"""Synchronous version of connect(). Returns a SyncSession."""
|
||||
loop = _get_or_create_loop()
|
||||
session = loop.run_until_complete(cls.connect(host, port, **kwargs))
|
||||
return SyncSession(session, loop)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Context manager
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def __aenter__(self) -> Session:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc) -> None:
|
||||
await self.close()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the connection and stop the subprocess if we spawned it."""
|
||||
await self._conn.close()
|
||||
if self._process:
|
||||
await self._process.stop()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Raw command escape hatch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def command(self, cmd: str) -> str:
|
||||
"""Send a raw OpenOCD command and return the response string."""
|
||||
return await self._conn.send(cmd)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Subsystem accessors (lazy-initialized)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def target(self) -> Target:
|
||||
if self._target is None:
|
||||
from openocd.target import Target
|
||||
self._target = Target(self._conn)
|
||||
return self._target
|
||||
|
||||
@property
|
||||
def memory(self) -> Memory:
|
||||
if self._memory is None:
|
||||
from openocd.memory import Memory
|
||||
self._memory = Memory(self._conn)
|
||||
return self._memory
|
||||
|
||||
@property
|
||||
def registers(self) -> Registers:
|
||||
if self._registers is None:
|
||||
from openocd.registers import Registers
|
||||
self._registers = Registers(self._conn)
|
||||
return self._registers
|
||||
|
||||
@property
|
||||
def flash(self) -> Flash:
|
||||
if self._flash is None:
|
||||
from openocd.flash import Flash
|
||||
self._flash = Flash(self._conn)
|
||||
return self._flash
|
||||
|
||||
@property
|
||||
def jtag(self) -> JTAGController:
|
||||
if self._jtag is None:
|
||||
from openocd.jtag import JTAGController
|
||||
self._jtag = JTAGController(self._conn)
|
||||
return self._jtag
|
||||
|
||||
@property
|
||||
def breakpoints(self) -> BreakpointManager:
|
||||
if self._breakpoints is None:
|
||||
from openocd.breakpoints import BreakpointManager
|
||||
self._breakpoints = BreakpointManager(self._conn)
|
||||
return self._breakpoints
|
||||
|
||||
@property
|
||||
def rtt(self) -> RTTManager:
|
||||
if self._rtt is None:
|
||||
from openocd.rtt import RTTManager
|
||||
self._rtt = RTTManager(self._conn)
|
||||
return self._rtt
|
||||
|
||||
@property
|
||||
def svd(self) -> SVDManager:
|
||||
if self._svd is None:
|
||||
from openocd.svd import SVDManager
|
||||
self._svd = SVDManager(self._conn, self.memory)
|
||||
return self._svd
|
||||
|
||||
@property
|
||||
def transport(self) -> Transport:
|
||||
if self._transport is None:
|
||||
from openocd.transport import Transport
|
||||
self._transport = Transport(self._conn)
|
||||
return self._transport
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Event shortcuts
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def on_halt(self, callback: Callable[[str], None]) -> None:
|
||||
"""Register a callback for target halt events."""
|
||||
def _filter(msg: str) -> None:
|
||||
if "halted" in msg.lower():
|
||||
callback(msg)
|
||||
self._conn.on_notification(_filter)
|
||||
|
||||
def on_reset(self, callback: Callable[[str], None]) -> None:
|
||||
"""Register a callback for target reset events."""
|
||||
def _filter(msg: str) -> None:
|
||||
if "reset" in msg.lower():
|
||||
callback(msg)
|
||||
self._conn.on_notification(_filter)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# Sync wrapper
|
||||
# ======================================================================
|
||||
|
||||
class SyncSession:
|
||||
"""Wraps an async Session for synchronous use."""
|
||||
|
||||
def __init__(self, session: Session, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._session = session
|
||||
self._loop = loop
|
||||
|
||||
def __enter__(self) -> SyncSession:
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc) -> None:
|
||||
self._loop.run_until_complete(self._session.close())
|
||||
|
||||
def command(self, cmd: str) -> str:
|
||||
return self._loop.run_until_complete(self._session.command(cmd))
|
||||
|
||||
@property
|
||||
def target(self) -> SyncTarget:
|
||||
from openocd.target import SyncTarget
|
||||
return SyncTarget(self._session.target, self._loop)
|
||||
|
||||
@property
|
||||
def memory(self) -> SyncMemory:
|
||||
from openocd.memory import SyncMemory
|
||||
return SyncMemory(self._session.memory, self._loop)
|
||||
|
||||
@property
|
||||
def registers(self) -> SyncRegisters:
|
||||
from openocd.registers import SyncRegisters
|
||||
return SyncRegisters(self._session.registers, self._loop)
|
||||
|
||||
@property
|
||||
def flash(self) -> SyncFlash:
|
||||
from openocd.flash import SyncFlash
|
||||
return SyncFlash(self._session.flash, self._loop)
|
||||
|
||||
@property
|
||||
def jtag(self) -> SyncJTAGController:
|
||||
from openocd.jtag import SyncJTAGController
|
||||
return SyncJTAGController(self._session.jtag, self._loop)
|
||||
|
||||
@property
|
||||
def breakpoints(self) -> SyncBreakpointManager:
|
||||
from openocd.breakpoints import SyncBreakpointManager
|
||||
return SyncBreakpointManager(self._session.breakpoints, self._loop)
|
||||
|
||||
@property
|
||||
def svd(self) -> SyncSVDManager:
|
||||
from openocd.svd import SyncSVDManager
|
||||
return SyncSVDManager(self._session.svd, self._loop)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# Helpers
|
||||
# ======================================================================
|
||||
|
||||
def _get_or_create_loop() -> asyncio.AbstractEventLoop:
|
||||
"""Get the running event loop, or create a new one if there isn't one."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're already in an async context we can't use run_until_complete
|
||||
raise RuntimeError(
|
||||
"Cannot use sync API from an async context. "
|
||||
"Use the async Session.start()/connect() instead."
|
||||
)
|
||||
except RuntimeError:
|
||||
pass
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop
|
||||
5
src/openocd/svd/__init__.py
Normal file
5
src/openocd/svd/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""SVD (System View Description) integration for peripheral/register decoding."""
|
||||
|
||||
from openocd.svd.peripheral import SVDManager, SyncSVDManager
|
||||
|
||||
__all__ = ["SVDManager", "SyncSVDManager"]
|
||||
54
src/openocd/svd/decoder.py
Normal file
54
src/openocd/svd/decoder.py
Normal file
@ -0,0 +1,54 @@
|
||||
"""Register value decoding using SVD metadata.
|
||||
|
||||
Takes a raw integer read from hardware and splits it into named bitfields
|
||||
using the field definitions from a parsed SVD file.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from openocd.types import BitField, DecodedRegister
|
||||
|
||||
|
||||
def decode_register(
|
||||
peripheral_obj: Any,
|
||||
register_obj: Any,
|
||||
raw_value: int,
|
||||
) -> DecodedRegister:
|
||||
"""Decode a raw register value into named bitfields using SVD metadata.
|
||||
|
||||
Args:
|
||||
peripheral_obj: cmsis_svd peripheral (used for base_address and name).
|
||||
register_obj: cmsis_svd register (used for fields, address_offset, name).
|
||||
raw_value: The 32-bit value read from hardware.
|
||||
|
||||
Returns:
|
||||
A DecodedRegister with all fields extracted and annotated.
|
||||
"""
|
||||
address = peripheral_obj.base_address + register_obj.address_offset
|
||||
fields: list[BitField] = []
|
||||
|
||||
for svd_field in register_obj.fields or []:
|
||||
mask = ((1 << svd_field.bit_width) - 1) << svd_field.bit_offset
|
||||
value = (raw_value & mask) >> svd_field.bit_offset
|
||||
fields.append(
|
||||
BitField(
|
||||
name=svd_field.name,
|
||||
offset=svd_field.bit_offset,
|
||||
width=svd_field.bit_width,
|
||||
value=value,
|
||||
description=svd_field.description or "",
|
||||
)
|
||||
)
|
||||
|
||||
# Sort fields by bit offset (low to high) for consistent display
|
||||
fields.sort(key=lambda f: f.offset)
|
||||
|
||||
return DecodedRegister(
|
||||
peripheral=peripheral_obj.name,
|
||||
register=register_obj.name,
|
||||
address=address,
|
||||
raw_value=raw_value,
|
||||
fields=fields,
|
||||
)
|
||||
128
src/openocd/svd/parser.py
Normal file
128
src/openocd/svd/parser.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""SVD file loading and peripheral/register lookup.
|
||||
|
||||
Wraps cmsis_svd to parse CMSIS-SVD XML files and provide indexed access
|
||||
to peripherals and their registers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from cmsis_svd import SVDParser
|
||||
|
||||
from openocd.errors import SVDError
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SVDParserWrapper:
|
||||
"""Load and cache parsed SVD device data."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._device: Any = None
|
||||
self._peripherals: dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def loaded(self) -> bool:
|
||||
"""Whether an SVD file has been parsed."""
|
||||
return self._device is not None
|
||||
|
||||
def load(self, svd_path: Path) -> None:
|
||||
"""Parse an SVD file and index peripherals/registers.
|
||||
|
||||
Args:
|
||||
svd_path: Path to the .svd file on disk.
|
||||
|
||||
Raises:
|
||||
SVDError: If the file cannot be found or parsed.
|
||||
"""
|
||||
path = Path(svd_path)
|
||||
if not path.exists():
|
||||
raise SVDError(f"SVD file not found: {path}")
|
||||
|
||||
try:
|
||||
parser = SVDParser.for_xml_file(str(path))
|
||||
self._device = parser.get_device()
|
||||
except Exception as exc:
|
||||
raise SVDError(f"Failed to parse SVD file {path}: {exc}") from exc
|
||||
|
||||
self._peripherals = {p.name: p for p in self._device.get_peripherals()}
|
||||
log.info(
|
||||
"Loaded SVD for %s — %d peripherals",
|
||||
getattr(self._device, "name", "unknown"),
|
||||
len(self._peripherals),
|
||||
)
|
||||
|
||||
def _require_loaded(self) -> None:
|
||||
if not self.loaded:
|
||||
raise SVDError("No SVD file loaded — call load() first")
|
||||
|
||||
def get_peripheral(self, name: str) -> Any:
|
||||
"""Look up a peripheral by name.
|
||||
|
||||
Args:
|
||||
name: Peripheral name (case-sensitive, e.g. "GPIOA", "USART1").
|
||||
|
||||
Returns:
|
||||
The cmsis_svd peripheral object.
|
||||
|
||||
Raises:
|
||||
SVDError: If no SVD is loaded or the peripheral is not found.
|
||||
"""
|
||||
self._require_loaded()
|
||||
periph = self._peripherals.get(name)
|
||||
if periph is None:
|
||||
raise SVDError(
|
||||
f"Peripheral '{name}' not found. "
|
||||
f"Available: {', '.join(sorted(self._peripherals))}"
|
||||
)
|
||||
return periph
|
||||
|
||||
def get_register(self, peripheral: str, register: str) -> Any:
|
||||
"""Look up a register within a peripheral.
|
||||
|
||||
Args:
|
||||
peripheral: Peripheral name.
|
||||
register: Register name (e.g. "CR1", "SR").
|
||||
|
||||
Returns:
|
||||
The cmsis_svd register object.
|
||||
|
||||
Raises:
|
||||
SVDError: If the peripheral or register is not found.
|
||||
"""
|
||||
periph = self.get_peripheral(peripheral)
|
||||
registers = periph.registers or []
|
||||
for reg in registers:
|
||||
if reg.name == register:
|
||||
return reg
|
||||
|
||||
available = [r.name for r in registers]
|
||||
raise SVDError(
|
||||
f"Register '{register}' not found in {peripheral}. "
|
||||
f"Available: {', '.join(sorted(available))}"
|
||||
)
|
||||
|
||||
def list_peripherals(self) -> list[str]:
|
||||
"""Return sorted names of all peripherals in the SVD.
|
||||
|
||||
Raises:
|
||||
SVDError: If no SVD is loaded.
|
||||
"""
|
||||
self._require_loaded()
|
||||
return sorted(self._peripherals.keys())
|
||||
|
||||
def list_registers(self, peripheral: str) -> list[str]:
|
||||
"""Return sorted register names for a peripheral.
|
||||
|
||||
Args:
|
||||
peripheral: Peripheral name.
|
||||
|
||||
Raises:
|
||||
SVDError: If the peripheral is not found.
|
||||
"""
|
||||
periph = self.get_peripheral(peripheral)
|
||||
registers = periph.registers or []
|
||||
return sorted(r.name for r in registers)
|
||||
186
src/openocd/svd/peripheral.py
Normal file
186
src/openocd/svd/peripheral.py
Normal file
@ -0,0 +1,186 @@
|
||||
"""SVDManager — combines SVD parsing, register decoding, and hardware reads.
|
||||
|
||||
This is the primary interface for SVD-based register inspection. It ties
|
||||
the SVD parser, bitfield decoder, and the Memory subsystem together so
|
||||
callers can do things like:
|
||||
|
||||
decoded = await svd.read_register("GPIOA", "ODR")
|
||||
print(decoded)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from openocd.svd.decoder import decode_register
|
||||
from openocd.svd.parser import SVDParserWrapper
|
||||
from openocd.types import DecodedRegister
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openocd.connection.base import Connection
|
||||
from openocd.memory import Memory
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SVDManager:
|
||||
"""High-level SVD register access: parse, read, decode."""
|
||||
|
||||
def __init__(self, conn: Connection, memory: Memory) -> None:
|
||||
self._conn = conn
|
||||
self._memory = memory
|
||||
self._parser = SVDParserWrapper()
|
||||
|
||||
@property
|
||||
def loaded(self) -> bool:
|
||||
"""Whether an SVD file has been loaded."""
|
||||
return self._parser.loaded
|
||||
|
||||
async def load(self, svd_path: Path) -> None:
|
||||
"""Parse an SVD file and make its peripherals available.
|
||||
|
||||
This is a synchronous file parse wrapped in the async interface
|
||||
for consistency with the rest of the API.
|
||||
|
||||
Args:
|
||||
svd_path: Path to the .svd XML file.
|
||||
|
||||
Raises:
|
||||
SVDError: If the file is missing or unparseable.
|
||||
"""
|
||||
self._parser.load(svd_path)
|
||||
|
||||
def list_peripherals(self) -> list[str]:
|
||||
"""Return sorted peripheral names from the loaded SVD.
|
||||
|
||||
Raises:
|
||||
SVDError: If no SVD is loaded.
|
||||
"""
|
||||
return self._parser.list_peripherals()
|
||||
|
||||
def list_registers(self, peripheral: str) -> list[str]:
|
||||
"""Return sorted register names for a peripheral.
|
||||
|
||||
Args:
|
||||
peripheral: Peripheral name (e.g. "GPIOA").
|
||||
|
||||
Raises:
|
||||
SVDError: If no SVD is loaded or peripheral not found.
|
||||
"""
|
||||
return self._parser.list_registers(peripheral)
|
||||
|
||||
async def read_register(self, peripheral: str, register: str) -> DecodedRegister:
|
||||
"""Read a register from hardware and decode it using SVD metadata.
|
||||
|
||||
This is the primary method: it computes the register's memory-mapped
|
||||
address from the SVD, reads 32 bits from the target, and returns
|
||||
a fully decoded result with named bitfields.
|
||||
|
||||
Args:
|
||||
peripheral: Peripheral name (e.g. "GPIOA").
|
||||
register: Register name (e.g. "ODR").
|
||||
|
||||
Returns:
|
||||
DecodedRegister with address, raw value, and decoded fields.
|
||||
|
||||
Raises:
|
||||
SVDError: If peripheral/register not found.
|
||||
TargetError: If the memory read fails.
|
||||
"""
|
||||
periph_obj = self._parser.get_peripheral(peripheral)
|
||||
reg_obj = self._parser.get_register(peripheral, register)
|
||||
address = periph_obj.base_address + reg_obj.address_offset
|
||||
|
||||
values = await self._memory.read_u32(address)
|
||||
raw = values[0]
|
||||
return decode_register(periph_obj, reg_obj, raw)
|
||||
|
||||
async def read_peripheral(self, peripheral: str) -> dict[str, DecodedRegister]:
|
||||
"""Read and decode every register in a peripheral.
|
||||
|
||||
Args:
|
||||
peripheral: Peripheral name.
|
||||
|
||||
Returns:
|
||||
Dict mapping register name to its DecodedRegister.
|
||||
|
||||
Raises:
|
||||
SVDError: If peripheral not found.
|
||||
TargetError: If any memory read fails.
|
||||
"""
|
||||
periph_obj = self._parser.get_peripheral(peripheral)
|
||||
registers = periph_obj.registers or []
|
||||
result: dict[str, DecodedRegister] = {}
|
||||
|
||||
for reg_obj in registers:
|
||||
address = periph_obj.base_address + reg_obj.address_offset
|
||||
try:
|
||||
values = await self._memory.read_u32(address)
|
||||
raw = values[0]
|
||||
result[reg_obj.name] = decode_register(periph_obj, reg_obj, raw)
|
||||
except Exception as exc:
|
||||
log.warning(
|
||||
"Failed to read %s.%s @ 0x%08X: %s",
|
||||
peripheral,
|
||||
reg_obj.name,
|
||||
address,
|
||||
exc,
|
||||
)
|
||||
# Skip unreadable registers (write-only, reserved, etc.)
|
||||
|
||||
return result
|
||||
|
||||
def decode(self, peripheral: str, register: str, value: int) -> DecodedRegister:
|
||||
"""Decode a raw value without reading hardware.
|
||||
|
||||
Useful when you already have the register value (from a log,
|
||||
a previous read, or a known reset value).
|
||||
|
||||
Args:
|
||||
peripheral: Peripheral name.
|
||||
register: Register name.
|
||||
value: Raw 32-bit register value.
|
||||
|
||||
Returns:
|
||||
DecodedRegister with the decoded bitfields.
|
||||
"""
|
||||
periph_obj = self._parser.get_peripheral(peripheral)
|
||||
reg_obj = self._parser.get_register(peripheral, register)
|
||||
return decode_register(periph_obj, reg_obj, value)
|
||||
|
||||
|
||||
class SyncSVDManager:
|
||||
"""Synchronous wrapper around SVDManager."""
|
||||
|
||||
def __init__(self, manager: SVDManager, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._manager = manager
|
||||
self._loop = loop
|
||||
|
||||
@property
|
||||
def loaded(self) -> bool:
|
||||
return self._manager.loaded
|
||||
|
||||
def load(self, svd_path: Path) -> None:
|
||||
self._loop.run_until_complete(self._manager.load(svd_path))
|
||||
|
||||
def list_peripherals(self) -> list[str]:
|
||||
return self._manager.list_peripherals()
|
||||
|
||||
def list_registers(self, peripheral: str) -> list[str]:
|
||||
return self._manager.list_registers(peripheral)
|
||||
|
||||
def read_register(self, peripheral: str, register: str) -> DecodedRegister:
|
||||
return self._loop.run_until_complete(
|
||||
self._manager.read_register(peripheral, register)
|
||||
)
|
||||
|
||||
def read_peripheral(self, peripheral: str) -> dict[str, DecodedRegister]:
|
||||
return self._loop.run_until_complete(
|
||||
self._manager.read_peripheral(peripheral)
|
||||
)
|
||||
|
||||
def decode(self, peripheral: str, register: str, value: int) -> DecodedRegister:
|
||||
return self._manager.decode(peripheral, register, value)
|
||||
164
src/openocd/target.py
Normal file
164
src/openocd/target.py
Normal file
@ -0,0 +1,164 @@
|
||||
"""Target state control — halt, resume, step, reset, and state queries.
|
||||
|
||||
Wraps the OpenOCD target commands: halt, resume, step, reset,
|
||||
wait_halt, and targets (for state inspection).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from openocd.connection.tcl_rpc import TclRpcConnection
|
||||
from openocd.errors import TargetError, TimeoutError
|
||||
from openocd.types import TargetState
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Matches a target row from "targets" output, e.g.:
|
||||
# " 0* stm32f1x.cpu cortex_m little stm32f1x.cpu halted"
|
||||
_TARGET_ROW_RE = re.compile(
|
||||
r"^\s*\d+\*?\s+" # index, optional current marker
|
||||
r"(\S+)\s+" # target name
|
||||
r"\S+\s+" # type
|
||||
r"\S+\s+" # endian
|
||||
r"\S+\s+" # tap name
|
||||
r"(\S+)" # state
|
||||
)
|
||||
|
||||
|
||||
class Target:
|
||||
"""Target execution control — halt, resume, step, reset."""
|
||||
|
||||
def __init__(self, conn: TclRpcConnection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
async def halt(self) -> TargetState:
|
||||
"""Halt the target and return the resulting state."""
|
||||
resp = await self._conn.send("halt")
|
||||
if "error" in resp.lower() and "already" not in resp.lower():
|
||||
raise TargetError(f"halt failed: {resp}")
|
||||
return await self._parse_state()
|
||||
|
||||
async def resume(self, address: int | None = None) -> None:
|
||||
"""Resume execution, optionally from a specific address."""
|
||||
cmd = "resume"
|
||||
if address is not None:
|
||||
cmd = f"resume 0x{address:x}"
|
||||
resp = await self._conn.send(cmd)
|
||||
if "error" in resp.lower():
|
||||
raise TargetError(f"resume failed: {resp}")
|
||||
|
||||
async def step(self, address: int | None = None) -> TargetState:
|
||||
"""Single-step and return the resulting state."""
|
||||
cmd = "step"
|
||||
if address is not None:
|
||||
cmd = f"step 0x{address:x}"
|
||||
resp = await self._conn.send(cmd)
|
||||
if "error" in resp.lower():
|
||||
raise TargetError(f"step failed: {resp}")
|
||||
return await self._parse_state()
|
||||
|
||||
async def reset(self, mode: Literal["run", "halt", "init"] = "halt") -> None:
|
||||
"""Reset the target.
|
||||
|
||||
Args:
|
||||
mode: Reset mode — "run" resumes after reset, "halt" stops at
|
||||
the reset vector, "init" runs init scripts after reset.
|
||||
"""
|
||||
resp = await self._conn.send(f"reset {mode}")
|
||||
if "error" in resp.lower():
|
||||
raise TargetError(f"reset failed: {resp}")
|
||||
|
||||
async def wait_halt(self, timeout_ms: int = 5000) -> TargetState:
|
||||
"""Block until the target halts or the timeout expires.
|
||||
|
||||
Args:
|
||||
timeout_ms: Maximum wait time in milliseconds.
|
||||
|
||||
Raises:
|
||||
TimeoutError: Target did not halt within the deadline.
|
||||
"""
|
||||
resp = await self._conn.send(f"wait_halt {timeout_ms}")
|
||||
if "timed out" in resp.lower() or "time out" in resp.lower():
|
||||
raise TimeoutError(f"Target did not halt within {timeout_ms}ms")
|
||||
if "error" in resp.lower():
|
||||
raise TargetError(f"wait_halt failed: {resp}")
|
||||
return await self._parse_state()
|
||||
|
||||
async def state(self) -> TargetState:
|
||||
"""Query and return the current target state."""
|
||||
return await self._parse_state()
|
||||
|
||||
async def _parse_state(self) -> TargetState:
|
||||
"""Parse the ``targets`` command output into a TargetState.
|
||||
|
||||
The output looks like::
|
||||
|
||||
TargetName Type Endian TapName State
|
||||
-- ------------------ ---------- ------ ------------------ ------------
|
||||
0* stm32f1x.cpu cortex_m little stm32f1x.cpu halted
|
||||
|
||||
If the target is halted, also reads the program counter via ``reg pc``.
|
||||
"""
|
||||
resp = await self._conn.send("targets")
|
||||
|
||||
name = "unknown"
|
||||
raw_state = "unknown"
|
||||
|
||||
for line in resp.splitlines():
|
||||
m = _TARGET_ROW_RE.match(line)
|
||||
if m:
|
||||
name = m.group(1)
|
||||
raw_state = m.group(2).lower()
|
||||
break
|
||||
|
||||
# Normalize to our known state literals
|
||||
if raw_state not in ("running", "halted", "reset", "debug-running"):
|
||||
raw_state = "unknown"
|
||||
|
||||
pc: int | None = None
|
||||
if raw_state == "halted":
|
||||
try:
|
||||
pc = await self._read_pc()
|
||||
except Exception:
|
||||
log.debug("Could not read PC while halted", exc_info=True)
|
||||
|
||||
return TargetState(name=name, state=raw_state, current_pc=pc)
|
||||
|
||||
async def _read_pc(self) -> int:
|
||||
"""Read the program counter from the halted target."""
|
||||
resp = await self._conn.send("reg pc")
|
||||
# Output: "pc (/32): 0x08001234"
|
||||
m = re.search(r":\s*(0x[0-9a-fA-F]+)", resp)
|
||||
if not m:
|
||||
raise TargetError(f"Cannot parse PC from: {resp}")
|
||||
return int(m.group(1), 16)
|
||||
|
||||
|
||||
class SyncTarget:
|
||||
"""Synchronous wrapper around Target."""
|
||||
|
||||
def __init__(self, target: Target, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._target = target
|
||||
self._loop = loop
|
||||
|
||||
def halt(self) -> TargetState:
|
||||
return self._loop.run_until_complete(self._target.halt())
|
||||
|
||||
def resume(self, address: int | None = None) -> None:
|
||||
self._loop.run_until_complete(self._target.resume(address))
|
||||
|
||||
def step(self, address: int | None = None) -> TargetState:
|
||||
return self._loop.run_until_complete(self._target.step(address))
|
||||
|
||||
def reset(self, mode: Literal["run", "halt", "init"] = "halt") -> None:
|
||||
self._loop.run_until_complete(self._target.reset(mode))
|
||||
|
||||
def wait_halt(self, timeout_ms: int = 5000) -> TargetState:
|
||||
return self._loop.run_until_complete(self._target.wait_halt(timeout_ms))
|
||||
|
||||
def state(self) -> TargetState:
|
||||
return self._loop.run_until_complete(self._target.state())
|
||||
149
src/openocd/transport.py
Normal file
149
src/openocd/transport.py
Normal file
@ -0,0 +1,149 @@
|
||||
"""Transport selection and debug adapter configuration.
|
||||
|
||||
OpenOCD supports multiple debug transports (JTAG, SWD, SWIM, etc.)
|
||||
and various adapter interfaces (CMSIS-DAP, ST-Link, J-Link, etc.).
|
||||
This module provides access to transport and adapter state.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from openocd.errors import OpenOCDError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openocd.connection.base import Connection
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Transport:
|
||||
"""Query and configure the debug transport and adapter.
|
||||
|
||||
Usage::
|
||||
|
||||
transport = Transport(conn)
|
||||
current = await transport.select() # e.g. "swd"
|
||||
available = await transport.list() # e.g. ["jtag", "swd"]
|
||||
speed = await transport.adapter_speed() # current kHz
|
||||
await transport.adapter_speed(4000) # set to 4 MHz
|
||||
"""
|
||||
|
||||
def __init__(self, conn: Connection) -> None:
|
||||
self._conn = conn
|
||||
|
||||
async def select(self) -> str:
|
||||
"""Get the currently selected transport.
|
||||
|
||||
Returns:
|
||||
Transport name string (e.g. "jtag", "swd", "swim").
|
||||
|
||||
Raises:
|
||||
OpenOCDError: If the command fails.
|
||||
"""
|
||||
response = await self._conn.send("transport select")
|
||||
response = response.strip()
|
||||
if not response:
|
||||
raise OpenOCDError("Empty response from 'transport select'")
|
||||
return response
|
||||
|
||||
async def list(self) -> list[str]:
|
||||
"""List transports available for the current adapter.
|
||||
|
||||
Returns:
|
||||
List of transport name strings.
|
||||
|
||||
Raises:
|
||||
OpenOCDError: If the command fails.
|
||||
"""
|
||||
response = await self._conn.send("transport list")
|
||||
response = response.strip()
|
||||
if not response:
|
||||
raise OpenOCDError("Empty response from 'transport list'")
|
||||
|
||||
# OpenOCD may return a Tcl list like "jtag swd" or one per line
|
||||
transports: list[str] = []
|
||||
for line in response.splitlines():
|
||||
for token in line.split():
|
||||
cleaned = token.strip("{}")
|
||||
if cleaned:
|
||||
transports.append(cleaned)
|
||||
return transports
|
||||
|
||||
async def adapter_info(self) -> str:
|
||||
"""Get adapter/interface information.
|
||||
|
||||
Tries ``adapter name`` first (newer OpenOCD), falls back to
|
||||
``adapter info`` for older versions.
|
||||
|
||||
Returns:
|
||||
Adapter description string.
|
||||
"""
|
||||
# "adapter name" is the preferred command in OpenOCD >= 0.12
|
||||
response = await self._conn.send("adapter name")
|
||||
response = response.strip()
|
||||
|
||||
if not response or "invalid" in response.lower() or "error" in response.lower():
|
||||
response = await self._conn.send("adapter info")
|
||||
response = response.strip()
|
||||
|
||||
if not response:
|
||||
raise OpenOCDError("Could not determine adapter info")
|
||||
return response
|
||||
|
||||
async def adapter_speed(self, khz: int | None = None) -> int:
|
||||
"""Get or set the adapter clock speed.
|
||||
|
||||
Args:
|
||||
khz: If provided, set the adapter speed to this value in kHz.
|
||||
If None, just query the current speed.
|
||||
|
||||
Returns:
|
||||
The current (or newly set) adapter speed in kHz.
|
||||
|
||||
Raises:
|
||||
OpenOCDError: If the command fails or response is not parseable.
|
||||
"""
|
||||
cmd = f"adapter speed {khz}" if khz is not None else "adapter speed"
|
||||
|
||||
response = await self._conn.send(cmd)
|
||||
response = response.strip()
|
||||
|
||||
# OpenOCD response is typically just a number, or
|
||||
# "adapter speed: 4000 kHz" depending on the interface
|
||||
speed = _parse_speed(response)
|
||||
if speed is None:
|
||||
raise OpenOCDError(f"Cannot parse adapter speed from: {response!r}")
|
||||
|
||||
if khz is not None:
|
||||
log.info("Adapter speed set to %d kHz", speed)
|
||||
return speed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _parse_speed(response: str) -> int | None:
|
||||
"""Extract a numeric kHz value from an adapter speed response.
|
||||
|
||||
Handles formats like:
|
||||
"4000"
|
||||
"adapter speed: 4000 kHz"
|
||||
"4000 kHz"
|
||||
"""
|
||||
# Try the whole thing as a plain integer
|
||||
try:
|
||||
return int(response)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Pull out the first integer-looking token
|
||||
for token in response.replace(":", " ").split():
|
||||
try:
|
||||
return int(token)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return None
|
||||
185
src/openocd/types.py
Normal file
185
src/openocd/types.py
Normal file
@ -0,0 +1,185 @@
|
||||
"""Shared dataclasses and enums used across the openocd-python package."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Target
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TargetState:
|
||||
"""Snapshot of target execution state."""
|
||||
|
||||
name: str
|
||||
state: Literal["running", "halted", "reset", "debug-running", "unknown"]
|
||||
current_pc: int | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Register:
|
||||
"""A single CPU register."""
|
||||
|
||||
name: str
|
||||
number: int
|
||||
value: int
|
||||
size: int # bits
|
||||
dirty: bool = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Flash
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FlashSector:
|
||||
"""One sector inside a flash bank."""
|
||||
|
||||
index: int
|
||||
offset: int
|
||||
size: int
|
||||
protected: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FlashBank:
|
||||
"""A flash bank reported by OpenOCD."""
|
||||
|
||||
index: int
|
||||
name: str
|
||||
base: int
|
||||
size: int
|
||||
bus_width: int
|
||||
chip_width: int
|
||||
target: str
|
||||
sectors: list[FlashSector] = field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JTAG
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TAPInfo:
|
||||
"""One TAP discovered on the JTAG chain."""
|
||||
|
||||
name: str
|
||||
chip: str
|
||||
tap_name: str
|
||||
idcode: int
|
||||
ir_length: int
|
||||
enabled: bool
|
||||
|
||||
|
||||
class JTAGState(str, Enum):
|
||||
"""IEEE 1149.1 TAP controller states."""
|
||||
|
||||
RESET = "RESET"
|
||||
IDLE = "IDLE"
|
||||
DRSELECT = "DRSELECT"
|
||||
DRCAPTURE = "DRCAPTURE"
|
||||
DRSHIFT = "DRSHIFT"
|
||||
DREXIT1 = "DREXIT1"
|
||||
DRPAUSE = "DRPAUSE"
|
||||
DREXIT2 = "DREXIT2"
|
||||
DRUPDATE = "DRUPDATE"
|
||||
IRSELECT = "IRSELECT"
|
||||
IRCAPTURE = "IRCAPTURE"
|
||||
IRSHIFT = "IRSHIFT"
|
||||
IREXIT1 = "IREXIT1"
|
||||
IRPAUSE = "IRPAUSE"
|
||||
IREXIT2 = "IREXIT2"
|
||||
IRUPDATE = "IRUPDATE"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Memory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MemoryRegion:
|
||||
"""A chunk of memory read from the target."""
|
||||
|
||||
address: int
|
||||
size: int
|
||||
data: bytes
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SVD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BitField:
|
||||
"""One decoded bitfield inside a register."""
|
||||
|
||||
name: str
|
||||
offset: int
|
||||
width: int
|
||||
value: int
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecodedRegister:
|
||||
"""A register value decoded into named bitfields via SVD."""
|
||||
|
||||
peripheral: str
|
||||
register: str
|
||||
address: int
|
||||
raw_value: int
|
||||
fields: list[BitField] = field(default_factory=list)
|
||||
|
||||
def __str__(self) -> str:
|
||||
header = f"{self.peripheral}.{self.register}"
|
||||
lines = [f"{header} @ 0x{self.address:08X} = 0x{self.raw_value:08X}"]
|
||||
for f in self.fields:
|
||||
bits = f"{f.offset + f.width - 1}:{f.offset}" if f.width > 1 else str(f.offset)
|
||||
lines.append(f" [{bits:>5s}] {f.name:<20s} = 0x{f.value:X} {f.description}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Breakpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Breakpoint:
|
||||
"""An active breakpoint."""
|
||||
|
||||
number: int
|
||||
type: Literal["hw", "sw"]
|
||||
address: int
|
||||
length: int
|
||||
enabled: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Watchpoint:
|
||||
"""An active watchpoint."""
|
||||
|
||||
number: int
|
||||
address: int
|
||||
length: int
|
||||
access: Literal["r", "w", "rw"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RTT
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RTTChannel:
|
||||
"""An RTT channel descriptor."""
|
||||
|
||||
index: int
|
||||
name: str
|
||||
size: int
|
||||
direction: Literal["up", "down"]
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
37
tests/conftest.py
Normal file
37
tests/conftest.py
Normal file
@ -0,0 +1,37 @@
|
||||
"""Shared pytest fixtures for openocd-python tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from openocd.connection.tcl_rpc import TclRpcConnection
|
||||
from openocd.session import Session
|
||||
from tests.mock_server import MockOpenOCDServer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_ocd():
|
||||
"""Start a MockOpenOCDServer, yield (host, port), stop on teardown."""
|
||||
server = MockOpenOCDServer()
|
||||
await server.start()
|
||||
host, port = server.address
|
||||
yield host, port, server
|
||||
await server.stop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def connection(mock_ocd):
|
||||
"""A TclRpcConnection connected to the mock server."""
|
||||
host, port, _server = mock_ocd
|
||||
conn = TclRpcConnection(timeout=5.0)
|
||||
await conn.connect(host, port)
|
||||
yield conn
|
||||
await conn.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def session(mock_ocd):
|
||||
"""A Session connected to the mock server via Session.connect()."""
|
||||
host, port, _server = mock_ocd
|
||||
sess = await Session.connect(host, port, timeout=5.0)
|
||||
yield sess
|
||||
await sess.close()
|
||||
255
tests/mock_server.py
Normal file
255
tests/mock_server.py
Normal file
@ -0,0 +1,255 @@
|
||||
"""Fake OpenOCD TCL RPC server for testing.
|
||||
|
||||
An asyncio TCP server that speaks the OpenOCD TCL RPC framing protocol:
|
||||
client sends: command_string + \\x1a
|
||||
server replies: response_string + \\x1a
|
||||
|
||||
Supports exact-match and regex-based command routing with pre-loaded
|
||||
responses that mirror real OpenOCD output.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
|
||||
SEPARATOR = b"\x1a"
|
||||
|
||||
|
||||
# -- Canned OpenOCD responses ------------------------------------------------
|
||||
|
||||
TARGETS_RESPONSE = """\
|
||||
TargetName Type Endian TapName State
|
||||
-- ------------------ ---------- ------ ------------------ ------------
|
||||
0* stm32f1x.cpu cortex_m little stm32f1x.cpu halted"""
|
||||
|
||||
REG_PC_RESPONSE = "pc (/32): 0x08001234"
|
||||
REG_SP_RESPONSE = "sp (/32): 0x20005000"
|
||||
REG_LR_RESPONSE = "lr (/32): 0x08000100"
|
||||
REG_XPSR_RESPONSE = "xPSR (/32): 0x61000000"
|
||||
|
||||
REG_ALL_RESPONSE = """\
|
||||
===== ARM registers
|
||||
(0) r0 (/32): 0x00000000
|
||||
(1) r1 (/32): 0x00000001
|
||||
(2) r2 (/32): 0x20001000
|
||||
(3) r3 (/32): 0x00000003
|
||||
(4) r4 (/32): 0x00000000
|
||||
(5) r5 (/32): 0x00000000
|
||||
(6) r6 (/32): 0x00000000
|
||||
(7) r7 (/32): 0x20004FF0
|
||||
(8) r8 (/32): 0x00000000
|
||||
(9) r9 (/32): 0x00000000
|
||||
(10) r10 (/32): 0x00000000
|
||||
(11) r11 (/32): 0x00000000
|
||||
(12) r12 (/32): 0x00000000
|
||||
(13) sp (/32): 0x20005000
|
||||
(14) lr (/32): 0x08000100
|
||||
(15) pc (/32): 0x08001234
|
||||
(16) xPSR (/32): 0x61000000
|
||||
(17) msp (/32): 0x20005000
|
||||
(18) psp (/32): 0x00000000
|
||||
(19) primask (/1): 0x00
|
||||
(20) basepri (/8): 0x00
|
||||
(21) faultmask (/1): 0x00
|
||||
(22) control (/3): 0x00 (dirty)"""
|
||||
|
||||
READ_MEMORY_RESPONSE = "20005000 080001a1 080001ab 080001ad"
|
||||
|
||||
FLASH_BANKS_RESPONSE = (
|
||||
"#0 : stm32f1x.flash (stm32f1x) at 0x08000000,"
|
||||
" size 0x00020000, buswidth 0, chipwidth 0"
|
||||
)
|
||||
|
||||
SCAN_CHAIN_RESPONSE = """\
|
||||
TapName Enabled IdCode Expected IrLen IrCap IrMask
|
||||
-- ------------------- -------- ---------- ---------- ----- ----- ------
|
||||
0 stm32f1x.cpu Y 0x3ba00477 0x3ba00477 4 0x01 0x0f"""
|
||||
|
||||
BP_LIST_RESPONSE = """\
|
||||
Breakpoint(IVA): 0x08001234, 0x2, 1
|
||||
Breakpoint(IVA): 0x08001300, 0x2, 0"""
|
||||
|
||||
RTT_CHANNELS_RESPONSE = """\
|
||||
Up-channels:
|
||||
0: Terminal 1024
|
||||
1: Log 512
|
||||
Down-channels:
|
||||
0: Terminal 16"""
|
||||
|
||||
TRANSPORT_SELECT_RESPONSE = "swd"
|
||||
TRANSPORT_LIST_RESPONSE = "jtag swd"
|
||||
ADAPTER_SPEED_RESPONSE = "4000"
|
||||
|
||||
|
||||
def _build_default_responses() -> list[tuple[re.Pattern[str], str | Callable[[str], str]]]:
|
||||
"""Build the default command-to-response routing table.
|
||||
|
||||
Returns a list of (compiled_regex, response) pairs. The first match wins.
|
||||
Response can be a string or a callable that receives the full command.
|
||||
"""
|
||||
routes: list[tuple[re.Pattern[str], str | Callable[[str], str]]] = [
|
||||
# target state control
|
||||
(re.compile(r"^targets$"), TARGETS_RESPONSE),
|
||||
(re.compile(r"^halt$"), ""),
|
||||
(re.compile(r"^resume"), ""),
|
||||
(re.compile(r"^step"), ""),
|
||||
(re.compile(r"^reset\s+"), ""),
|
||||
(re.compile(r"^wait_halt"), ""),
|
||||
|
||||
# individual register reads (must come before bare "reg")
|
||||
(re.compile(r"^reg\s+pc$"), REG_PC_RESPONSE),
|
||||
(re.compile(r"^reg\s+sp$"), REG_SP_RESPONSE),
|
||||
(re.compile(r"^reg\s+lr$"), REG_LR_RESPONSE),
|
||||
(re.compile(r"^reg\s+xPSR$"), REG_XPSR_RESPONSE),
|
||||
# register write (reg <name> <value>)
|
||||
(re.compile(r"^reg\s+\S+\s+0x"), ""),
|
||||
# bare "reg" -> full listing
|
||||
(re.compile(r"^reg$"), REG_ALL_RESPONSE),
|
||||
|
||||
# memory
|
||||
(re.compile(r"^read_memory\s+0x8000000\s+32\s+4$"), READ_MEMORY_RESPONSE),
|
||||
# generic read_memory -- return zeros for widths/counts we haven't mapped
|
||||
(re.compile(r"^read_memory\s+"), _generic_read_memory),
|
||||
(re.compile(r"^write_memory\s+"), ""),
|
||||
|
||||
# flash
|
||||
(re.compile(r"^flash banks$"), FLASH_BANKS_RESPONSE),
|
||||
(re.compile(r"^flash\s+"), ""),
|
||||
|
||||
# JTAG
|
||||
(re.compile(r"^scan_chain$"), SCAN_CHAIN_RESPONSE),
|
||||
(re.compile(r"^irscan\s+"), "0x01"),
|
||||
(re.compile(r"^drscan\s+"), "0xDEADBEEF"),
|
||||
(re.compile(r"^runtest\s+"), ""),
|
||||
(re.compile(r"^pathmove\s+"), ""),
|
||||
|
||||
# breakpoints
|
||||
(re.compile(r"^bp\s+0x"), ""),
|
||||
(re.compile(r"^bp$"), BP_LIST_RESPONSE),
|
||||
(re.compile(r"^rbp\s+"), ""),
|
||||
(re.compile(r"^wp\s+0x"), ""),
|
||||
(re.compile(r"^wp$"), ""),
|
||||
(re.compile(r"^rwp\s+"), ""),
|
||||
|
||||
# transport / adapter
|
||||
(re.compile(r"^transport\s+select$"), TRANSPORT_SELECT_RESPONSE),
|
||||
(re.compile(r"^transport\s+list$"), TRANSPORT_LIST_RESPONSE),
|
||||
(re.compile(r"^adapter\s+speed$"), ADAPTER_SPEED_RESPONSE),
|
||||
(re.compile(r"^adapter\s+speed\s+\d+"), ADAPTER_SPEED_RESPONSE),
|
||||
(re.compile(r"^adapter\s+name$"), "cmsis-dap"),
|
||||
|
||||
# RTT
|
||||
(re.compile(r"^rtt\s+channels$"), RTT_CHANNELS_RESPONSE),
|
||||
(re.compile(r"^rtt\s+setup\s+"), ""),
|
||||
(re.compile(r"^rtt\s+start$"), ""),
|
||||
(re.compile(r"^rtt\s+stop$"), ""),
|
||||
(re.compile(r"^rtt\s+channelread\s+"), "hello from target"),
|
||||
(re.compile(r"^rtt\s+channelwrite\s+"), ""),
|
||||
|
||||
# notifications
|
||||
(re.compile(r"^tcl_notifications\s+"), ""),
|
||||
]
|
||||
return routes
|
||||
|
||||
|
||||
def _generic_read_memory(cmd: str) -> str:
|
||||
"""Generate a plausible response for an arbitrary read_memory command.
|
||||
|
||||
Parses the count from the command and returns that many hex zeros.
|
||||
"""
|
||||
parts = cmd.split()
|
||||
# read_memory <addr> <width> <count>
|
||||
count = 1
|
||||
if len(parts) >= 4:
|
||||
with contextlib.suppress(ValueError):
|
||||
count = int(parts[3])
|
||||
return " ".join(["00"] * count)
|
||||
|
||||
|
||||
class MockOpenOCDServer:
|
||||
"""Asyncio TCP server that fakes OpenOCD TCL RPC responses.
|
||||
|
||||
Usage::
|
||||
|
||||
server = MockOpenOCDServer()
|
||||
await server.start()
|
||||
host, port = server.address
|
||||
# ... connect and test ...
|
||||
await server.stop()
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._server: asyncio.Server | None = None
|
||||
self._routes = _build_default_responses()
|
||||
self._host = "127.0.0.1"
|
||||
self._port = 0 # OS picks a free port
|
||||
# Track raw commands received, useful for assertions
|
||||
self.received_commands: list[str] = []
|
||||
|
||||
@property
|
||||
def address(self) -> tuple[str, int]:
|
||||
"""Return (host, port) the server is listening on."""
|
||||
if self._server is None:
|
||||
raise RuntimeError("Server not started")
|
||||
sock = self._server.sockets[0]
|
||||
return sock.getsockname()[:2]
|
||||
|
||||
def add_response(self, pattern: str, response: str | Callable[[str], str]) -> None:
|
||||
"""Prepend a custom response rule (takes priority over defaults)."""
|
||||
self._routes.insert(0, (re.compile(pattern), response))
|
||||
|
||||
async def start(self) -> None:
|
||||
self._server = await asyncio.start_server(
|
||||
self._handle_client, self._host, self._port
|
||||
)
|
||||
await self._server.start_serving()
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._server:
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
self._server = None
|
||||
|
||||
async def _handle_client(
|
||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
"""Handle one client connection, reading commands and sending responses."""
|
||||
buf = bytearray()
|
||||
try:
|
||||
while True:
|
||||
chunk = await reader.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
buf.extend(chunk)
|
||||
|
||||
# Process all complete commands in the buffer
|
||||
while True:
|
||||
idx = buf.find(SEPARATOR)
|
||||
if idx == -1:
|
||||
break
|
||||
command = bytes(buf[:idx]).decode("utf-8", errors="replace")
|
||||
buf = buf[idx + 1 :]
|
||||
|
||||
self.received_commands.append(command)
|
||||
response = self._resolve(command)
|
||||
|
||||
writer.write(response.encode("utf-8") + SEPARATOR)
|
||||
await writer.drain()
|
||||
except (asyncio.CancelledError, ConnectionResetError, BrokenPipeError):
|
||||
pass
|
||||
finally:
|
||||
writer.close()
|
||||
with contextlib.suppress(OSError):
|
||||
await writer.wait_closed()
|
||||
|
||||
def _resolve(self, command: str) -> str:
|
||||
"""Find the first matching route and return its response."""
|
||||
for pattern, response in self._routes:
|
||||
if pattern.search(command):
|
||||
if callable(response):
|
||||
return response(command)
|
||||
return response
|
||||
# Unrecognized command returns empty (success)
|
||||
return ""
|
||||
113
tests/test_connection.py
Normal file
113
tests/test_connection.py
Normal file
@ -0,0 +1,113 @@
|
||||
"""Tests for the TclRpcConnection class."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from openocd.connection.tcl_rpc import TclRpcConnection
|
||||
from openocd.errors import ConnectionError, TimeoutError
|
||||
|
||||
|
||||
async def test_connect_to_mock_server(mock_ocd):
|
||||
"""Verify we can open a connection to the mock server."""
|
||||
host, port, _server = mock_ocd
|
||||
conn = TclRpcConnection(timeout=5.0)
|
||||
await conn.connect(host, port)
|
||||
assert conn._writer is not None
|
||||
assert conn._reader is not None
|
||||
await conn.close()
|
||||
|
||||
|
||||
async def test_send_and_receive(connection):
|
||||
"""Send a command and verify we get the expected response."""
|
||||
resp = await connection.send("targets")
|
||||
assert "stm32f1x.cpu" in resp
|
||||
assert "halted" in resp
|
||||
|
||||
|
||||
async def test_separator_framing(mock_ocd):
|
||||
"""Verify the \\x1a framing works for multiple sequential commands."""
|
||||
host, port, _server = mock_ocd
|
||||
conn = TclRpcConnection(timeout=5.0)
|
||||
await conn.connect(host, port)
|
||||
|
||||
# Send several commands in sequence; each should get its own response
|
||||
resp1 = await conn.send("halt")
|
||||
resp2 = await conn.send("reg pc")
|
||||
resp3 = await conn.send("targets")
|
||||
|
||||
# halt returns empty
|
||||
assert resp1 == ""
|
||||
# reg pc returns a value
|
||||
assert "0x08001234" in resp2
|
||||
# targets returns the state table
|
||||
assert "stm32f1x.cpu" in resp3
|
||||
|
||||
await conn.close()
|
||||
|
||||
|
||||
async def test_connection_error_no_server():
|
||||
"""Connecting to a port with no listener should raise ConnectionError."""
|
||||
conn = TclRpcConnection(timeout=1.0)
|
||||
with pytest.raises(ConnectionError):
|
||||
await conn.connect("127.0.0.1", 1) # port 1 is unlikely to be open
|
||||
|
||||
|
||||
async def test_send_before_connect_raises():
|
||||
"""Sending a command before connect() should raise ConnectionError."""
|
||||
conn = TclRpcConnection()
|
||||
with pytest.raises(ConnectionError, match="Not connected"):
|
||||
await conn.send("targets")
|
||||
|
||||
|
||||
async def test_timeout_on_hung_server():
|
||||
"""A server that never sends \\x1a should trigger a TimeoutError."""
|
||||
# Start a server that accepts connections but never responds
|
||||
async def _hang(reader, writer):
|
||||
# Read the command but never reply
|
||||
await reader.read(4096)
|
||||
await asyncio.sleep(60)
|
||||
|
||||
server = await asyncio.start_server(_hang, "127.0.0.1", 0)
|
||||
await server.start_serving()
|
||||
sock = server.sockets[0]
|
||||
host, port = sock.getsockname()[:2]
|
||||
|
||||
conn = TclRpcConnection(timeout=0.3)
|
||||
await conn.connect(host, port)
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
await conn.send("targets")
|
||||
|
||||
await conn.close()
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
|
||||
|
||||
async def test_close_idempotent(connection):
|
||||
"""Calling close() multiple times should not raise."""
|
||||
await connection.close()
|
||||
await connection.close() # second call is a no-op
|
||||
|
||||
|
||||
async def test_concurrent_commands(mock_ocd):
|
||||
"""Multiple coroutines sharing one connection should serialize properly."""
|
||||
host, port, _server = mock_ocd
|
||||
conn = TclRpcConnection(timeout=5.0)
|
||||
await conn.connect(host, port)
|
||||
|
||||
async def _do_command(cmd: str) -> str:
|
||||
return await conn.send(cmd)
|
||||
|
||||
results = await asyncio.gather(
|
||||
_do_command("reg pc"),
|
||||
_do_command("reg sp"),
|
||||
_do_command("reg lr"),
|
||||
)
|
||||
|
||||
assert "0x08001234" in results[0]
|
||||
assert "0x20005000" in results[1]
|
||||
assert "0x08000100" in results[2]
|
||||
|
||||
await conn.close()
|
||||
93
tests/test_jtag.py
Normal file
93
tests/test_jtag.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Tests for the JTAG subsystem."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from openocd.types import TAPInfo
|
||||
|
||||
|
||||
async def test_scan_chain(session):
|
||||
"""scan_chain() should return a list of TAPInfo objects."""
|
||||
taps = await session.jtag.scan_chain()
|
||||
assert isinstance(taps, list)
|
||||
assert len(taps) == 1
|
||||
|
||||
|
||||
async def test_scan_chain_tap_fields(session):
|
||||
"""The returned TAPInfo should have all fields populated correctly."""
|
||||
taps = await session.jtag.scan_chain()
|
||||
tap = taps[0]
|
||||
assert isinstance(tap, TAPInfo)
|
||||
assert tap.name == "stm32f1x.cpu"
|
||||
assert tap.chip == "stm32f1x"
|
||||
assert tap.tap_name == "cpu"
|
||||
assert tap.idcode == 0x3BA00477
|
||||
assert tap.ir_length == 4
|
||||
assert tap.enabled is True
|
||||
|
||||
|
||||
async def test_scan_chain_frozen(session):
|
||||
"""TAPInfo should be immutable (frozen dataclass)."""
|
||||
taps = await session.jtag.scan_chain()
|
||||
tap = taps[0]
|
||||
with pytest.raises(AttributeError):
|
||||
tap.name = "something_else" # type: ignore[misc]
|
||||
|
||||
|
||||
async def test_irscan(session):
|
||||
"""irscan should return the shifted-out value as an int."""
|
||||
result = await session.jtag.irscan("stm32f1x.cpu", 0x0E)
|
||||
assert isinstance(result, int)
|
||||
assert result == 0x01
|
||||
|
||||
|
||||
async def test_drscan(session):
|
||||
"""drscan should return the shifted-out value as an int."""
|
||||
result = await session.jtag.drscan("stm32f1x.cpu", 32, 0x00000000)
|
||||
assert isinstance(result, int)
|
||||
assert result == 0xDEADBEEF
|
||||
|
||||
|
||||
async def test_runtest(session):
|
||||
"""runtest should complete without error."""
|
||||
await session.jtag.runtest(100)
|
||||
|
||||
|
||||
async def test_scan_chain_parsing_multiple_taps(mock_ocd):
|
||||
"""Verify the parser handles multiple TAPs in scan_chain output."""
|
||||
from openocd.jtag.chain import _parse_scan_chain
|
||||
|
||||
raw = """\
|
||||
TapName Enabled IdCode Expected IrLen IrCap IrMask
|
||||
-- ------------------- -------- ---------- ---------- ----- ----- ------
|
||||
0 stm32f1x.cpu Y 0x3ba00477 0x3ba00477 4 0x01 0x0f
|
||||
1 stm32f1x.bs N 0x06433041 0x06433041 5 0x01 0x1f"""
|
||||
|
||||
taps = _parse_scan_chain(raw)
|
||||
assert len(taps) == 2
|
||||
|
||||
assert taps[0].name == "stm32f1x.cpu"
|
||||
assert taps[0].enabled is True
|
||||
assert taps[0].idcode == 0x3BA00477
|
||||
assert taps[0].ir_length == 4
|
||||
|
||||
assert taps[1].name == "stm32f1x.bs"
|
||||
assert taps[1].chip == "stm32f1x"
|
||||
assert taps[1].tap_name == "bs"
|
||||
assert taps[1].enabled is False
|
||||
assert taps[1].idcode == 0x06433041
|
||||
assert taps[1].ir_length == 5
|
||||
|
||||
|
||||
def test_parse_scan_chain_empty():
|
||||
"""An empty scan_chain output should return an empty list."""
|
||||
from openocd.jtag.chain import _parse_scan_chain
|
||||
|
||||
result = _parse_scan_chain("")
|
||||
assert result == []
|
||||
|
||||
result = _parse_scan_chain(
|
||||
" TapName Enabled IdCode Expected IrLen IrCap IrMask\n"
|
||||
"-- ------------------- -------- ---------- ---------- ----- ----- ------\n"
|
||||
)
|
||||
assert result == []
|
||||
93
tests/test_memory.py
Normal file
93
tests/test_memory.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Tests for the Memory subsystem."""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
async def test_read_u32(session):
|
||||
"""read_u32 should return a list[int] with correctly parsed hex values."""
|
||||
values = await session.memory.read_u32(0x8000000, 4)
|
||||
assert isinstance(values, list)
|
||||
assert len(values) == 4
|
||||
assert values[0] == 0x20005000
|
||||
assert values[1] == 0x080001A1
|
||||
assert values[2] == 0x080001AB
|
||||
assert values[3] == 0x080001AD
|
||||
|
||||
|
||||
async def test_read_u32_single(session):
|
||||
"""read_u32 with count=1 should return a single-element list."""
|
||||
values = await session.memory.read_u32(0x20000000, 1)
|
||||
assert isinstance(values, list)
|
||||
assert len(values) == 1
|
||||
|
||||
|
||||
async def test_read_u8(session):
|
||||
"""read_u8 should return a list of 8-bit values."""
|
||||
values = await session.memory.read_u8(0x20000000, 4)
|
||||
assert isinstance(values, list)
|
||||
assert len(values) == 4
|
||||
# Generic mock returns zeros for unregistered addresses
|
||||
assert all(isinstance(v, int) for v in values)
|
||||
|
||||
|
||||
async def test_read_u16(session):
|
||||
"""read_u16 should return a list of 16-bit values."""
|
||||
values = await session.memory.read_u16(0x20000000, 2)
|
||||
assert isinstance(values, list)
|
||||
assert len(values) == 2
|
||||
|
||||
|
||||
async def test_read_bytes(session):
|
||||
"""read_bytes should return a bytes object of the requested size."""
|
||||
data = await session.memory.read_bytes(0x20000000, 8)
|
||||
assert isinstance(data, bytes)
|
||||
assert len(data) == 8
|
||||
|
||||
|
||||
async def test_write_u32(session):
|
||||
"""write_u32 should complete without error."""
|
||||
await session.memory.write_u32(0x20000000, [0xDEADBEEF, 0xCAFEBABE])
|
||||
|
||||
|
||||
async def test_write_u32_single_value(session):
|
||||
"""write_u32 with a single int should complete without error."""
|
||||
await session.memory.write_u32(0x20000000, 0x12345678)
|
||||
|
||||
|
||||
async def test_write_bytes(session):
|
||||
"""write_bytes should complete without error."""
|
||||
await session.memory.write_bytes(0x20000000, b"\x00\x01\x02\x03")
|
||||
|
||||
|
||||
async def test_hexdump_format(session):
|
||||
"""hexdump should return a properly formatted hex+ASCII dump."""
|
||||
dump = await session.memory.hexdump(0x20000000, 32)
|
||||
assert isinstance(dump, str)
|
||||
lines = dump.strip().splitlines()
|
||||
assert len(lines) == 2 # 32 bytes / 16 bytes per line = 2 lines
|
||||
|
||||
# Each line should start with an address
|
||||
assert lines[0].startswith("20000000:")
|
||||
assert lines[1].startswith("20000010:")
|
||||
|
||||
# Each line should contain the ASCII column delimited by pipes
|
||||
for line in lines:
|
||||
assert "|" in line
|
||||
|
||||
|
||||
async def test_hexdump_ascii_column(session):
|
||||
"""Hexdump ASCII column should use dots for non-printable bytes."""
|
||||
dump = await session.memory.hexdump(0x20000000, 16)
|
||||
# The mock returns all zeros, which are non-printable
|
||||
assert "|" in dump
|
||||
# Extract the ASCII portion between the pipes
|
||||
ascii_part = dump.split("|")[1]
|
||||
# All-zero bytes map to dots
|
||||
assert all(c == "." for c in ascii_part)
|
||||
|
||||
|
||||
async def test_read_u32_returns_ints(session):
|
||||
"""All values from read_u32 should be Python ints."""
|
||||
values = await session.memory.read_u32(0x8000000, 4)
|
||||
for v in values:
|
||||
assert isinstance(v, int)
|
||||
assert v >= 0
|
||||
107
tests/test_registers.py
Normal file
107
tests/test_registers.py
Normal file
@ -0,0 +1,107 @@
|
||||
"""Tests for the Registers subsystem."""
|
||||
from __future__ import annotations
|
||||
|
||||
from openocd.types import Register
|
||||
|
||||
|
||||
async def test_read_pc(session):
|
||||
"""read('pc') should return the correct value from the mock."""
|
||||
val = await session.registers.read("pc")
|
||||
assert val == 0x08001234
|
||||
|
||||
|
||||
async def test_read_sp(session):
|
||||
"""read('sp') should return the correct value."""
|
||||
val = await session.registers.read("sp")
|
||||
assert val == 0x20005000
|
||||
|
||||
|
||||
async def test_read_lr(session):
|
||||
"""read('lr') should return the correct value."""
|
||||
val = await session.registers.read("lr")
|
||||
assert val == 0x08000100
|
||||
|
||||
|
||||
async def test_read_xpsr(session):
|
||||
"""read('xPSR') should return the correct value."""
|
||||
val = await session.registers.read("xPSR")
|
||||
assert val == 0x61000000
|
||||
|
||||
|
||||
async def test_read_all(session):
|
||||
"""read_all() should return a dict of Register objects keyed by name."""
|
||||
regs = await session.registers.read_all()
|
||||
assert isinstance(regs, dict)
|
||||
assert len(regs) > 0
|
||||
|
||||
# Spot-check a few registers
|
||||
assert "pc" in regs
|
||||
assert "sp" in regs
|
||||
assert "r0" in regs
|
||||
assert "xPSR" in regs
|
||||
|
||||
|
||||
async def test_read_all_register_type(session):
|
||||
"""Each value in read_all() should be a Register dataclass."""
|
||||
regs = await session.registers.read_all()
|
||||
for name, reg in regs.items():
|
||||
assert isinstance(reg, Register)
|
||||
assert reg.name == name
|
||||
assert isinstance(reg.number, int)
|
||||
assert isinstance(reg.value, int)
|
||||
assert isinstance(reg.size, int)
|
||||
assert isinstance(reg.dirty, bool)
|
||||
|
||||
|
||||
async def test_read_all_pc_value(session):
|
||||
"""The pc register from read_all() should have the correct value."""
|
||||
regs = await session.registers.read_all()
|
||||
pc = regs["pc"]
|
||||
assert pc.value == 0x08001234
|
||||
assert pc.size == 32
|
||||
assert pc.number == 15
|
||||
|
||||
|
||||
async def test_read_all_dirty_flag(session):
|
||||
"""The control register should have dirty=True in our mock data."""
|
||||
regs = await session.registers.read_all()
|
||||
control = regs["control"]
|
||||
assert control.dirty is True
|
||||
|
||||
|
||||
async def test_convenience_pc(session):
|
||||
"""The pc() convenience method should match read('pc')."""
|
||||
val = await session.registers.pc()
|
||||
assert val == 0x08001234
|
||||
|
||||
|
||||
async def test_convenience_sp(session):
|
||||
"""The sp() convenience method should match read('sp')."""
|
||||
val = await session.registers.sp()
|
||||
assert val == 0x20005000
|
||||
|
||||
|
||||
async def test_convenience_lr(session):
|
||||
"""The lr() convenience method should match read('lr')."""
|
||||
val = await session.registers.lr()
|
||||
assert val == 0x08000100
|
||||
|
||||
|
||||
async def test_convenience_xpsr(session):
|
||||
"""The xpsr() convenience method should match read('xPSR')."""
|
||||
val = await session.registers.xpsr()
|
||||
assert val == 0x61000000
|
||||
|
||||
|
||||
async def test_write(session):
|
||||
"""write() should complete without error."""
|
||||
await session.registers.write("r0", 0xDEADBEEF)
|
||||
|
||||
|
||||
async def test_read_many(session):
|
||||
"""read_many() should return values for all requested registers."""
|
||||
results = await session.registers.read_many(["pc", "sp", "lr"])
|
||||
assert len(results) == 3
|
||||
assert results["pc"] == 0x08001234
|
||||
assert results["sp"] == 0x20005000
|
||||
assert results["lr"] == 0x08000100
|
||||
98
tests/test_session.py
Normal file
98
tests/test_session.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""Tests for the Session class."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from openocd.breakpoints import BreakpointManager
|
||||
from openocd.flash import Flash
|
||||
from openocd.jtag import JTAGController
|
||||
from openocd.memory import Memory
|
||||
from openocd.registers import Registers
|
||||
from openocd.rtt import RTTManager
|
||||
from openocd.session import Session
|
||||
from openocd.svd import SVDManager
|
||||
from openocd.target import Target
|
||||
from openocd.transport import Transport
|
||||
|
||||
|
||||
async def test_connect_to_mock(mock_ocd):
|
||||
"""Session.connect() should successfully connect to the mock server."""
|
||||
host, port, _server = mock_ocd
|
||||
sess = await Session.connect(host, port, timeout=5.0)
|
||||
assert sess is not None
|
||||
await sess.close()
|
||||
|
||||
|
||||
async def test_raw_command(session):
|
||||
"""session.command() should pass through to the underlying connection."""
|
||||
resp = await session.command("targets")
|
||||
assert "stm32f1x.cpu" in resp
|
||||
|
||||
|
||||
async def test_context_manager(mock_ocd):
|
||||
"""Session should work as an async context manager."""
|
||||
host, port, _server = mock_ocd
|
||||
async with await Session.connect(host, port, timeout=5.0) as sess:
|
||||
resp = await sess.command("halt")
|
||||
assert resp == ""
|
||||
# After exiting the context, the connection is closed.
|
||||
# Attempting to send should raise.
|
||||
from openocd.errors import ConnectionError
|
||||
with pytest.raises(ConnectionError):
|
||||
await sess.command("targets")
|
||||
|
||||
|
||||
async def test_subsystem_target_type(session):
|
||||
"""session.target should return a Target instance."""
|
||||
assert isinstance(session.target, Target)
|
||||
|
||||
|
||||
async def test_subsystem_memory_type(session):
|
||||
"""session.memory should return a Memory instance."""
|
||||
assert isinstance(session.memory, Memory)
|
||||
|
||||
|
||||
async def test_subsystem_registers_type(session):
|
||||
"""session.registers should return a Registers instance."""
|
||||
assert isinstance(session.registers, Registers)
|
||||
|
||||
|
||||
async def test_subsystem_flash_type(session):
|
||||
"""session.flash should return a Flash instance."""
|
||||
assert isinstance(session.flash, Flash)
|
||||
|
||||
|
||||
async def test_subsystem_jtag_type(session):
|
||||
"""session.jtag should return a JTAGController instance."""
|
||||
assert isinstance(session.jtag, JTAGController)
|
||||
|
||||
|
||||
async def test_subsystem_breakpoints_type(session):
|
||||
"""session.breakpoints should return a BreakpointManager instance."""
|
||||
assert isinstance(session.breakpoints, BreakpointManager)
|
||||
|
||||
|
||||
async def test_subsystem_rtt_type(session):
|
||||
"""session.rtt should return an RTTManager instance."""
|
||||
assert isinstance(session.rtt, RTTManager)
|
||||
|
||||
|
||||
async def test_subsystem_svd_type(session):
|
||||
"""session.svd should return an SVDManager instance."""
|
||||
assert isinstance(session.svd, SVDManager)
|
||||
|
||||
|
||||
async def test_subsystem_transport_type(session):
|
||||
"""session.transport should return a Transport instance."""
|
||||
assert isinstance(session.transport, Transport)
|
||||
|
||||
|
||||
async def test_subsystem_lazy_initialization(session):
|
||||
"""Accessing the same property twice should return the same object."""
|
||||
t1 = session.target
|
||||
t2 = session.target
|
||||
assert t1 is t2
|
||||
|
||||
m1 = session.memory
|
||||
m2 = session.memory
|
||||
assert m1 is m2
|
||||
249
tests/test_svd.py
Normal file
249
tests/test_svd.py
Normal file
@ -0,0 +1,249 @@
|
||||
"""Tests for SVD decoding (no hardware required).
|
||||
|
||||
These tests exercise the bitfield decoder and DecodedRegister formatting
|
||||
using synthetic data, without needing an SVD file or a mock server.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from openocd.svd.decoder import decode_register
|
||||
from openocd.types import BitField, DecodedRegister
|
||||
|
||||
# -- Fake SVD objects to avoid needing a real .svd file -----------------------
|
||||
|
||||
@dataclass
|
||||
class FakeSVDField:
|
||||
name: str
|
||||
bit_offset: int
|
||||
bit_width: int
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class FakeSVDRegister:
|
||||
name: str
|
||||
address_offset: int
|
||||
fields: list[FakeSVDField]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FakeSVDPeripheral:
|
||||
name: str
|
||||
base_address: int
|
||||
registers: list[FakeSVDRegister]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gpioa_odr():
|
||||
"""A fake GPIOA.ODR register with two bitfields."""
|
||||
fields = [
|
||||
FakeSVDField(
|
||||
name="ODR0", bit_offset=0, bit_width=1,
|
||||
description="Port output data bit 0",
|
||||
),
|
||||
FakeSVDField(
|
||||
name="ODR1", bit_offset=1, bit_width=1,
|
||||
description="Port output data bit 1",
|
||||
),
|
||||
FakeSVDField(
|
||||
name="ODR15_2", bit_offset=2, bit_width=14,
|
||||
description="Port output data bits 15:2",
|
||||
),
|
||||
]
|
||||
register = FakeSVDRegister(name="ODR", address_offset=0x14, fields=fields)
|
||||
peripheral = FakeSVDPeripheral(
|
||||
name="GPIOA", base_address=0x40010800, registers=[register]
|
||||
)
|
||||
return peripheral, register
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def usart_cr1():
|
||||
"""A fake USART1.CR1 register with multiple bitfields."""
|
||||
fields = [
|
||||
FakeSVDField(name="UE", bit_offset=0, bit_width=1, description="USART enable"),
|
||||
FakeSVDField(name="RE", bit_offset=2, bit_width=1, description="Receiver enable"),
|
||||
FakeSVDField(name="TE", bit_offset=3, bit_width=1, description="Transmitter enable"),
|
||||
FakeSVDField(
|
||||
name="RXNEIE", bit_offset=5, bit_width=1,
|
||||
description="RXNE interrupt enable",
|
||||
),
|
||||
FakeSVDField(
|
||||
name="TCIE", bit_offset=6, bit_width=1,
|
||||
description="Transmission complete IE",
|
||||
),
|
||||
FakeSVDField(
|
||||
name="TXEIE", bit_offset=7, bit_width=1,
|
||||
description="TXE interrupt enable",
|
||||
),
|
||||
FakeSVDField(name="M", bit_offset=12, bit_width=1, description="Word length"),
|
||||
FakeSVDField(
|
||||
name="OVER8", bit_offset=15, bit_width=1,
|
||||
description="Oversampling mode",
|
||||
),
|
||||
]
|
||||
register = FakeSVDRegister(name="CR1", address_offset=0x0C, fields=fields)
|
||||
peripheral = FakeSVDPeripheral(
|
||||
name="USART1", base_address=0x40013800, registers=[register]
|
||||
)
|
||||
return peripheral, register
|
||||
|
||||
|
||||
def test_decode_register_basic(gpioa_odr):
|
||||
"""decode_register should extract bitfield values from a raw integer."""
|
||||
peripheral, register = gpioa_odr
|
||||
# Set ODR0=1, ODR1=0, ODR15_2 = 0x0005 (bits 2..15 = 0b00000000010100 = 0x14 shifted)
|
||||
raw = 0b0000000000010101 # ODR0=1, ODR1=0, ODR15_2=0x0005
|
||||
decoded = decode_register(peripheral, register, raw)
|
||||
|
||||
assert isinstance(decoded, DecodedRegister)
|
||||
assert decoded.peripheral == "GPIOA"
|
||||
assert decoded.register == "ODR"
|
||||
assert decoded.address == 0x40010800 + 0x14
|
||||
assert decoded.raw_value == raw
|
||||
|
||||
|
||||
def test_decode_bitfield_extraction(gpioa_odr):
|
||||
"""Individual bitfield values should be correctly masked and shifted."""
|
||||
peripheral, register = gpioa_odr
|
||||
raw = 0b0000000000010101
|
||||
decoded = decode_register(peripheral, register, raw)
|
||||
|
||||
fields_by_name = {f.name: f for f in decoded.fields}
|
||||
|
||||
assert fields_by_name["ODR0"].value == 1
|
||||
assert fields_by_name["ODR1"].value == 0
|
||||
assert fields_by_name["ODR15_2"].value == 0x0005
|
||||
|
||||
|
||||
def test_decode_all_ones(gpioa_odr):
|
||||
"""All-ones value should set all fields to their max values."""
|
||||
peripheral, register = gpioa_odr
|
||||
raw = 0xFFFF
|
||||
decoded = decode_register(peripheral, register, raw)
|
||||
|
||||
fields_by_name = {f.name: f for f in decoded.fields}
|
||||
|
||||
assert fields_by_name["ODR0"].value == 1
|
||||
assert fields_by_name["ODR1"].value == 1
|
||||
assert fields_by_name["ODR15_2"].value == (1 << 14) - 1 # 0x3FFF
|
||||
|
||||
|
||||
def test_decode_all_zeros(gpioa_odr):
|
||||
"""All-zeros value should yield all-zero fields."""
|
||||
peripheral, register = gpioa_odr
|
||||
decoded = decode_register(peripheral, register, 0x0000)
|
||||
|
||||
for field in decoded.fields:
|
||||
assert field.value == 0
|
||||
|
||||
|
||||
def test_bitfield_type(gpioa_odr):
|
||||
"""Each field in a DecodedRegister should be a BitField dataclass."""
|
||||
peripheral, register = gpioa_odr
|
||||
decoded = decode_register(peripheral, register, 0xAAAA)
|
||||
|
||||
for field in decoded.fields:
|
||||
assert isinstance(field, BitField)
|
||||
assert isinstance(field.name, str)
|
||||
assert isinstance(field.offset, int)
|
||||
assert isinstance(field.width, int)
|
||||
assert isinstance(field.value, int)
|
||||
assert isinstance(field.description, str)
|
||||
|
||||
|
||||
def test_fields_sorted_by_offset(gpioa_odr):
|
||||
"""Decoded fields should be sorted by bit offset (low to high)."""
|
||||
peripheral, register = gpioa_odr
|
||||
decoded = decode_register(peripheral, register, 0x1234)
|
||||
|
||||
offsets = [f.offset for f in decoded.fields]
|
||||
assert offsets == sorted(offsets)
|
||||
|
||||
|
||||
def test_decoded_register_str(gpioa_odr):
|
||||
"""__str__ should produce a multi-line representation with field details."""
|
||||
peripheral, register = gpioa_odr
|
||||
raw = 0b0000000000010101
|
||||
decoded = decode_register(peripheral, register, raw)
|
||||
|
||||
text = str(decoded)
|
||||
assert "GPIOA.ODR" in text
|
||||
assert "0X40010814" in text.upper()
|
||||
assert "ODR0" in text
|
||||
assert "ODR1" in text
|
||||
assert "ODR15_2" in text
|
||||
|
||||
|
||||
def test_decoded_register_str_shows_values(gpioa_odr):
|
||||
"""The string representation should show each field's hex value."""
|
||||
peripheral, register = gpioa_odr
|
||||
decoded = decode_register(peripheral, register, 0x0001)
|
||||
|
||||
text = str(decoded)
|
||||
# ODR0 = 1 should appear as "0x1"
|
||||
assert "0x1" in text
|
||||
|
||||
|
||||
def test_decode_complex_register(usart_cr1):
|
||||
"""Decode a multi-field register and verify specific field values."""
|
||||
peripheral, register = usart_cr1
|
||||
# UE=1, RE=1, TE=1 -> bits 0,2,3 set -> raw = 0x000D
|
||||
raw = 0x000D
|
||||
decoded = decode_register(peripheral, register, raw)
|
||||
|
||||
fields_by_name = {f.name: f for f in decoded.fields}
|
||||
|
||||
assert fields_by_name["UE"].value == 1
|
||||
assert fields_by_name["RE"].value == 1
|
||||
assert fields_by_name["TE"].value == 1
|
||||
assert fields_by_name["RXNEIE"].value == 0
|
||||
assert fields_by_name["M"].value == 0
|
||||
assert fields_by_name["OVER8"].value == 0
|
||||
|
||||
|
||||
def test_decode_address_calculation(usart_cr1):
|
||||
"""The decoded address should be base + offset."""
|
||||
peripheral, register = usart_cr1
|
||||
decoded = decode_register(peripheral, register, 0)
|
||||
assert decoded.address == 0x40013800 + 0x0C
|
||||
|
||||
|
||||
def test_decoded_register_fields_list(gpioa_odr):
|
||||
"""fields should be a plain list, not some other iterable."""
|
||||
peripheral, register = gpioa_odr
|
||||
decoded = decode_register(peripheral, register, 0)
|
||||
assert isinstance(decoded.fields, list)
|
||||
assert len(decoded.fields) == 3
|
||||
|
||||
|
||||
def test_bitfield_frozen():
|
||||
"""BitField should be immutable (frozen dataclass)."""
|
||||
bf = BitField(name="TEST", offset=0, width=1, value=1, description="test")
|
||||
with pytest.raises(AttributeError):
|
||||
bf.value = 2 # type: ignore[misc]
|
||||
|
||||
|
||||
def test_decoded_register_str_single_bit_range(gpioa_odr):
|
||||
"""Single-bit fields should show just the bit number, not a range."""
|
||||
peripheral, register = gpioa_odr
|
||||
decoded = decode_register(peripheral, register, 0x0001)
|
||||
text = str(decoded)
|
||||
# ODR0 is at offset 0, width 1 -> should show "[ 0]" not "[0:0]"
|
||||
lines = text.strip().splitlines()
|
||||
# Find the ODR0 line
|
||||
odr0_line = [ln for ln in lines if "ODR0 " in ln or "ODR0" in ln.split()][0]
|
||||
assert "0]" in odr0_line
|
||||
|
||||
|
||||
def test_decoded_register_str_multi_bit_range(gpioa_odr):
|
||||
"""Multi-bit fields should show a bit range like [15:2]."""
|
||||
peripheral, register = gpioa_odr
|
||||
decoded = decode_register(peripheral, register, 0xFFFF)
|
||||
text = str(decoded)
|
||||
lines = text.strip().splitlines()
|
||||
odr15_2_line = [ln for ln in lines if "ODR15_2" in ln][0]
|
||||
assert "15:2" in odr15_2_line
|
||||
74
tests/test_target.py
Normal file
74
tests/test_target.py
Normal file
@ -0,0 +1,74 @@
|
||||
"""Tests for the Target subsystem."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from openocd.types import TargetState
|
||||
|
||||
|
||||
async def test_state_returns_target_state(session):
|
||||
"""target.state() should return a TargetState with correct fields."""
|
||||
state = await session.target.state()
|
||||
assert isinstance(state, TargetState)
|
||||
assert state.name == "stm32f1x.cpu"
|
||||
assert state.state == "halted"
|
||||
# When halted, the mock returns pc = 0x08001234
|
||||
assert state.current_pc == 0x08001234
|
||||
|
||||
|
||||
async def test_halt(session):
|
||||
"""target.halt() should return a TargetState."""
|
||||
state = await session.target.halt()
|
||||
assert isinstance(state, TargetState)
|
||||
assert state.state == "halted"
|
||||
|
||||
|
||||
async def test_resume(session):
|
||||
"""target.resume() should complete without error."""
|
||||
await session.target.resume()
|
||||
|
||||
|
||||
async def test_resume_with_address(session):
|
||||
"""target.resume(address=...) should complete without error."""
|
||||
await session.target.resume(address=0x08000000)
|
||||
|
||||
|
||||
async def test_step(session):
|
||||
"""target.step() should return a TargetState."""
|
||||
state = await session.target.step()
|
||||
assert isinstance(state, TargetState)
|
||||
|
||||
|
||||
async def test_step_with_address(session):
|
||||
"""target.step(address=...) should complete without error."""
|
||||
state = await session.target.step(address=0x08001234)
|
||||
assert isinstance(state, TargetState)
|
||||
|
||||
|
||||
async def test_reset_halt(session):
|
||||
"""target.reset('halt') should complete without error."""
|
||||
await session.target.reset("halt")
|
||||
|
||||
|
||||
async def test_reset_run(session):
|
||||
"""target.reset('run') should complete without error."""
|
||||
await session.target.reset("run")
|
||||
|
||||
|
||||
async def test_reset_init(session):
|
||||
"""target.reset('init') should complete without error."""
|
||||
await session.target.reset("init")
|
||||
|
||||
|
||||
async def test_state_pc_field(session):
|
||||
"""When halted, current_pc should be populated from reg pc."""
|
||||
state = await session.target.state()
|
||||
assert state.current_pc is not None
|
||||
assert state.current_pc == 0x08001234
|
||||
|
||||
|
||||
async def test_state_frozen_dataclass(session):
|
||||
"""TargetState should be immutable (frozen dataclass)."""
|
||||
state = await session.target.state()
|
||||
with pytest.raises(AttributeError):
|
||||
state.name = "something_else" # type: ignore[misc]
|
||||
Loading…
x
Reference in New Issue
Block a user