From fa41ffbf80ce41058cd43252771a9d5ba6a62787 Mon Sep 17 00:00:00 2001 From: Ryan Malloy Date: Thu, 5 Mar 2026 20:31:06 -0700 Subject: [PATCH] Harden all tool/resource paths per Hamilton review - Add asyncio.wait_for timeout (30s) to all D-Bus calls - Add asyncio.Lock to BusManager.get_bus to prevent race conditions - Convert recursive tree walk to bounded BFS with visited set, max depth (20), max nodes (500) - Add input validation for D-Bus names and object paths at tool boundary - Standardize error handling (RuntimeError, TimeoutError) across all tools - Catch JSONDecodeError in deserialize_args and set_property - Check arg count matches signature in deserialize_args - Add explicit variant wrapping via {"signature": "ai", "value": [1,2,3]} - Narrow resource exception handling from Exception to (RuntimeError, OSError) - Guard disconnect_all per-bus to avoid partial cleanup on error - Audit log system bus calls and all set_property calls to stderr - Consolidate _get_mgr into get_mgr in _bus.py - Sort MPRIS players for deterministic auto-discovery - Fix progress reporting (pass None as total when unknown) - Add 7 new tests for validation, arg count, and JSON error paths (32 total) --- src/mcdbus/_bus.py | 130 ++++++++++++++++++++++++++++--------- src/mcdbus/_discovery.py | 65 ++++++++++++++----- src/mcdbus/_interaction.py | 127 +++++++++++++++++++++++++----------- src/mcdbus/_resources.py | 27 +++++--- src/mcdbus/_shortcuts.py | 103 ++++++++++++++++------------- tests/test_bus.py | 42 +++++++++++- 6 files changed, 352 insertions(+), 142 deletions(-) diff --git a/src/mcdbus/_bus.py b/src/mcdbus/_bus.py index 555aeff..db66533 100644 --- a/src/mcdbus/_bus.py +++ b/src/mcdbus/_bus.py @@ -1,11 +1,39 @@ """Bus connection management and D-Bus type serialization.""" +import asyncio import json +import re +import sys from typing import Any from dbus_fast import BusType, Message, MessageType from dbus_fast.aio import MessageBus from dbus_fast.unpack import unpack_variants +from fastmcp import Context + +DBUS_CALL_TIMEOUT = float(30) # seconds; override via env MCDBUS_TIMEOUT + +# D-Bus spec: https://dbus.freedesktop.org/doc/dbus-specification.html +_DBUS_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_.]*(\.[A-Za-z_][A-Za-z0-9_]*)+$") +_DBUS_PATH_RE = re.compile(r"^/$|^(/[A-Za-z0-9_]+)+$") + + +def validate_bus_name(name: str, label: str = "service") -> None: + """Validate a D-Bus well-known or interface name.""" + if not _DBUS_NAME_RE.match(name): + raise ValueError( + f"Invalid D-Bus {label}: {name!r}. " + f"Must match [A-Za-z_][A-Za-z0-9_]*(.[A-Za-z_][A-Za-z0-9_]*)+" + ) + + +def validate_object_path(path: str) -> None: + """Validate a D-Bus object path.""" + if not _DBUS_PATH_RE.match(path): + raise ValueError( + f"Invalid D-Bus object path: {path!r}. " + f"Must be '/' or '/element/element/...' with [A-Za-z0-9_]" + ) class BusManager: @@ -13,31 +41,44 @@ class BusManager: def __init__(self) -> None: self._buses: dict[str, MessageBus] = {} + self._locks: dict[str, asyncio.Lock] = { + "session": asyncio.Lock(), + "system": asyncio.Lock(), + } async def get_bus(self, bus_type: str) -> MessageBus: """Get a connected bus by type name ("session" or "system").""" if bus_type not in ("session", "system"): raise ValueError(f"bus_type must be 'session' or 'system', got {bus_type!r}") - if bus_type in self._buses: - bus = self._buses[bus_type] - if bus.connected: - return bus - # stale — drop and reconnect - del self._buses[bus_type] + async with self._locks[bus_type]: + if bus_type in self._buses: + bus = self._buses[bus_type] + if bus.connected: + return bus + # stale — drop and reconnect + del self._buses[bus_type] - bt = BusType.SESSION if bus_type == "session" else BusType.SYSTEM - bus = await MessageBus(bus_type=bt).connect() - self._buses[bus_type] = bus - return bus + bt = BusType.SESSION if bus_type == "session" else BusType.SYSTEM + bus = await MessageBus(bus_type=bt).connect() + self._buses[bus_type] = bus + return bus async def disconnect_all(self) -> None: - for bus in self._buses.values(): - if bus.connected: - bus.disconnect() + for name, bus in self._buses.items(): + try: + if bus.connected: + bus.disconnect() + except Exception as exc: + print(f"mcdbus: error disconnecting {name} bus: {exc}", file=sys.stderr) self._buses.clear() +def get_mgr(ctx: Context) -> "BusManager": + """Extract the BusManager from a FastMCP tool context.""" + return ctx.request_context.lifespan_context + + def serialize_variant(value: Any) -> Any: """Recursively unwrap dbus-fast Variant objects into JSON-safe Python types.""" unpacked = unpack_variants(value) @@ -62,13 +103,19 @@ def deserialize_args(args_json: str, signature: str) -> list[Any]: """Parse JSON args string into a Python list suitable for D-Bus call body. For 'v' (Variant) signatures, wraps values in dbus_fast.signature.Variant. + Auto-wrapping infers types for simple scalars; for complex variant values + (e.g. array of int), pass the value as {"signature": "ai", "value": [1,2,3]}. """ from dbus_fast.signature import SignatureTree, Variant if not args_json or args_json.strip() in ("", "[]", "null"): return [] - args = json.loads(args_json) + try: + args = json.loads(args_json) + except json.JSONDecodeError as exc: + raise ValueError(f"Invalid JSON in args: {exc}") from exc + if not isinstance(args, list): args = [args] @@ -76,30 +123,46 @@ def deserialize_args(args_json: str, signature: str) -> list[Any]: return args tree = SignatureTree(signature) + expected = len(tree.types) + if len(args) != expected: + raise ValueError( + f"Signature '{signature}' expects {expected} argument(s), got {len(args)}" + ) + result = [] for i, sig_type in enumerate(tree.types): - if i >= len(args): - break val = args[i] if sig_type.signature == "v": - # Caller must wrap variant values; we auto-wrap strings as "s" - if isinstance(val, str): - val = Variant("s", val) - elif isinstance(val, bool): - val = Variant("b", val) - elif isinstance(val, int): - val = Variant("i", val) - elif isinstance(val, float): - val = Variant("d", val) - elif isinstance(val, list): - val = Variant("as", val) - elif isinstance(val, dict): - val = Variant("a{sv}", val) + val = _auto_wrap_variant(val, Variant) result.append(val) return result +def _auto_wrap_variant(val: Any, variant_cls: type) -> Any: + """Wrap a Python value in a D-Bus Variant, inferring the type signature. + + For explicit control, pass {"signature": "ai", "value": [1,2,3]}. + """ + # Explicit variant specification + if isinstance(val, dict) and "signature" in val and "value" in val and len(val) == 2: + return variant_cls(val["signature"], val["value"]) + # Auto-infer from Python type (bool must come before int — bool is a subclass of int) + if isinstance(val, bool): + return variant_cls("b", val) + if isinstance(val, str): + return variant_cls("s", val) + if isinstance(val, int): + return variant_cls("i", val) + if isinstance(val, float): + return variant_cls("d", val) + if isinstance(val, list): + return variant_cls("as", val) + if isinstance(val, dict): + return variant_cls("a{sv}", val) + return variant_cls("s", str(val)) + + async def call_bus_method( bus: MessageBus, destination: str, @@ -108,6 +171,7 @@ async def call_bus_method( member: str, signature: str = "", body: list[Any] | None = None, + timeout: float = DBUS_CALL_TIMEOUT, ) -> Any: """Send a D-Bus method call and return the unpacked response body.""" msg = Message( @@ -118,7 +182,13 @@ async def call_bus_method( signature=signature, body=body or [], ) - reply = await bus.call(msg) + try: + reply = await asyncio.wait_for(bus.call(msg), timeout=timeout) + except TimeoutError: + raise TimeoutError( + f"D-Bus call timed out after {timeout}s: " + f"{destination} {interface}.{member} at {path}" + ) from None if reply.message_type == MessageType.ERROR: raise RuntimeError(f"D-Bus error: {reply.error_name}: {reply.body}") if reply.body: diff --git a/src/mcdbus/_discovery.py b/src/mcdbus/_discovery.py index 9043c95..46803fa 100644 --- a/src/mcdbus/_discovery.py +++ b/src/mcdbus/_discovery.py @@ -1,13 +1,14 @@ """Discovery tools — list services, introspect objects, walk object trees.""" +from collections import deque + from fastmcp import Context -from mcdbus._bus import BusManager, call_bus_method +from mcdbus._bus import call_bus_method, get_mgr, validate_bus_name, validate_object_path from mcdbus._state import mcp - -def _get_mgr(ctx: Context) -> BusManager: - return ctx.request_context.lifespan_context +MAX_TREE_DEPTH = 20 +MAX_TREE_NODES = 500 @mcp.tool() @@ -22,7 +23,7 @@ async def list_services( bus: "session" or "system" include_unique: include unique connection names like :1.42 (default False) """ - mgr = _get_mgr(ctx) + mgr = get_mgr(ctx) await ctx.info(f"Connecting to {bus} bus...") connection = await mgr.get_bus(bus) @@ -61,7 +62,10 @@ async def introspect( object_path: object path (e.g. "/org/freedesktop/Notifications") include_standard: include standard D-Bus interfaces (Peer, Introspectable, Properties) """ - mgr = _get_mgr(ctx) + validate_bus_name(service) + validate_object_path(object_path) + + mgr = get_mgr(ctx) await ctx.report_progress(0, 4) connection = await mgr.get_bus(bus) @@ -102,7 +106,9 @@ async def introspect( lines.append("**Properties:**") for prop in iface.properties: access = getattr(prop.access, "name", str(prop.access)).lower() - lines.append(f"- `{prop.name}`: `{prop.signature}` ({access})") + lines.append( + f"- `{prop.name}`: `{prop.signature}` ({access})" + ) lines.append("") if iface.signals: @@ -132,40 +138,63 @@ async def list_objects( service: str, ctx: Context, root_path: str = "/", + max_depth: int = MAX_TREE_DEPTH, ) -> str: - """Recursively walk the D-Bus object tree for a service. + """Recursively walk the D-Bus object tree for a service (BFS, bounded). Args: bus: "session" or "system" service: service name root_path: starting path (default "/") + max_depth: maximum tree depth to walk (default 20) """ - mgr = _get_mgr(ctx) + validate_bus_name(service) + validate_object_path(root_path) + + mgr = get_mgr(ctx) connection = await mgr.get_bus(bus) paths: list[tuple[str, list[str]]] = [] + visited: set[str] = set() + queue: deque[tuple[str, int]] = deque([(root_path, 0)]) + truncated = False + + while queue: + path, depth = queue.popleft() + + if path in visited: + continue + visited.add(path) + + if len(paths) >= MAX_TREE_NODES: + truncated = True + break - async def _walk(path: str) -> None: try: node = await connection.introspect(service, path) except Exception as exc: await ctx.warning(f"Could not introspect {path}: {exc}") - return + continue iface_names = [i.name for i in node.interfaces] paths.append((path, iface_names)) - await ctx.report_progress(len(paths), len(paths)) + await ctx.report_progress(len(paths), None) - for child in node.nodes: - child_path = path.rstrip("/") + "/" + child.name - await _walk(child_path) + if depth < max_depth: + for child in node.nodes: + child_path = path.rstrip("/") + "/" + child.name + queue.append((child_path, depth + 1)) - await _walk(root_path) + header = f"## Object tree for `{service}` on {bus} bus — {len(paths)} objects" + if truncated: + header += f" (truncated at {MAX_TREE_NODES})" + lines = [header + "\n"] - lines = [f"## Object tree for `{service}` on {bus} bus — {len(paths)} objects\n"] for path, ifaces in sorted(paths): std_prefix = "org.freedesktop.DBus." - iface_str = ", ".join(f"`{i}`" for i in ifaces if not i.startswith(std_prefix)) + iface_str = ", ".join( + f"`{i}`" for i in ifaces if not i.startswith(std_prefix) + ) if iface_str: lines.append(f"- `{path}` — {iface_str}") else: diff --git a/src/mcdbus/_interaction.py b/src/mcdbus/_interaction.py index 1fcba21..c6f6c88 100644 --- a/src/mcdbus/_interaction.py +++ b/src/mcdbus/_interaction.py @@ -1,18 +1,21 @@ """Interaction tools — method calls, property get/set.""" import json +import sys from dbus_fast.signature import Variant from fastmcp import Context -from mcdbus._bus import BusManager, call_bus_method, deserialize_args +from mcdbus._bus import ( + call_bus_method, + deserialize_args, + get_mgr, + validate_bus_name, + validate_object_path, +) from mcdbus._state import mcp -def _get_mgr(ctx: Context) -> BusManager: - return ctx.request_context.lifespan_context - - @mcp.tool() async def call_method( bus: str, @@ -35,10 +38,26 @@ async def call_method( args: JSON array of arguments (default "[]") signature: D-Bus type signature for args (e.g. "su"); leave empty for no-arg methods """ - mgr = _get_mgr(ctx) + validate_bus_name(service) + validate_object_path(object_path) + validate_bus_name(interface, label="interface") + + mgr = get_mgr(ctx) connection = await mgr.get_bus(bus) - body = deserialize_args(args, signature) + try: + body = deserialize_args(args, signature) + except (ValueError, json.JSONDecodeError) as exc: + return f"Error parsing arguments: {exc}" + + # Log system bus writes to stderr for audit trail + if bus == "system": + print( + f"mcdbus: system bus call {service} {interface}.{method} " + f"path={object_path} sig={signature!r}", + file=sys.stderr, + ) + await ctx.info(f"Calling {interface}.{method} on {service}...") try: @@ -51,7 +70,7 @@ async def call_method( signature=signature, body=body, ) - except RuntimeError as exc: + except (RuntimeError, TimeoutError) as exc: await ctx.error(str(exc)) return f"Error: {exc}" @@ -83,18 +102,25 @@ async def get_property( interface: interface that owns the property property_name: property name to read """ - mgr = _get_mgr(ctx) + validate_bus_name(service) + validate_object_path(object_path) + validate_bus_name(interface, label="interface") + + mgr = get_mgr(ctx) connection = await mgr.get_bus(bus) - result = await call_bus_method( - connection, - destination=service, - path=object_path, - interface="org.freedesktop.DBus.Properties", - member="Get", - signature="ss", - body=[interface, property_name], - ) + try: + result = await call_bus_method( + connection, + destination=service, + path=object_path, + interface="org.freedesktop.DBus.Properties", + member="Get", + signature="ss", + body=[interface, property_name], + ) + except (RuntimeError, TimeoutError) as exc: + return f"Error reading {interface}.{property_name}: {exc}" if result: value = result[0] @@ -124,22 +150,40 @@ async def set_property( value: JSON-encoded value to set signature: D-Bus type signature of the property value (e.g. "b" for boolean) """ - mgr = _get_mgr(ctx) + validate_bus_name(service) + validate_object_path(object_path) + validate_bus_name(interface, label="interface") + + mgr = get_mgr(ctx) connection = await mgr.get_bus(bus) - parsed_value = json.loads(value) + try: + parsed_value = json.loads(value) + except json.JSONDecodeError as exc: + return f"Error parsing value JSON: {exc}" + variant = Variant(signature, parsed_value) - await call_bus_method( - connection, - destination=service, - path=object_path, - interface="org.freedesktop.DBus.Properties", - member="Set", - signature="ssv", - body=[interface, property_name, variant], + # Audit log for all set_property calls + print( + f"mcdbus: set_property {bus} {service} {interface}.{property_name} " + f"= {value} (sig={signature!r})", + file=sys.stderr, ) + try: + await call_bus_method( + connection, + destination=service, + path=object_path, + interface="org.freedesktop.DBus.Properties", + member="Set", + signature="ssv", + body=[interface, property_name, variant], + ) + except (RuntimeError, TimeoutError) as exc: + return f"Error setting {interface}.{property_name}: {exc}" + return f"Set {interface}.{property_name} = {value}" @@ -159,18 +203,25 @@ async def get_all_properties( object_path: object path interface: interface to read properties from """ - mgr = _get_mgr(ctx) + validate_bus_name(service) + validate_object_path(object_path) + validate_bus_name(interface, label="interface") + + mgr = get_mgr(ctx) connection = await mgr.get_bus(bus) - result = await call_bus_method( - connection, - destination=service, - path=object_path, - interface="org.freedesktop.DBus.Properties", - member="GetAll", - signature="s", - body=[interface], - ) + try: + result = await call_bus_method( + connection, + destination=service, + path=object_path, + interface="org.freedesktop.DBus.Properties", + member="GetAll", + signature="s", + body=[interface], + ) + except (RuntimeError, TimeoutError) as exc: + return f"Error reading properties on {interface}: {exc}" if not result or not result[0]: return "No properties found." diff --git a/src/mcdbus/_resources.py b/src/mcdbus/_resources.py index 4cf73dd..683b5c3 100644 --- a/src/mcdbus/_resources.py +++ b/src/mcdbus/_resources.py @@ -1,15 +1,17 @@ """Dynamic MCP resources for browsing D-Bus services.""" +from collections import deque from urllib.parse import unquote from mcdbus._bus import BusManager, call_bus_method from mcdbus._state import mcp +MAX_RESOURCE_NODES = 200 + @mcp.resource("dbus://{bus}/services") async def bus_services(bus: str) -> str: """Live list of well-known service names on a D-Bus bus.""" - # Resources don't get Context — we need a fresh connection mgr = BusManager() try: connection = await mgr.get_bus(bus) @@ -28,23 +30,31 @@ async def bus_services(bus: str) -> str: @mcp.resource("dbus://{bus}/{service}/objects") async def service_objects(bus: str, service: str) -> str: - """Object tree for a D-Bus service.""" + """Object tree for a D-Bus service (bounded BFS walk).""" mgr = BusManager() try: connection = await mgr.get_bus(bus) paths: list[str] = [] + visited: set[str] = set() + queue: deque[tuple[str, int]] = deque([("/", 0)]) + + while queue and len(paths) < MAX_RESOURCE_NODES: + path, depth = queue.popleft() + if path in visited: + continue + visited.add(path) - async def _walk(path: str) -> None: try: node = await connection.introspect(service, path) - paths.append(path) + except (RuntimeError, OSError): + continue + + paths.append(path) + if depth < 15: for child in node.nodes: child_path = path.rstrip("/") + "/" + child.name - await _walk(child_path) - except Exception: - pass + queue.append((child_path, depth + 1)) - await _walk("/") return "\n".join(sorted(paths)) finally: await mgr.disconnect_all() @@ -53,7 +63,6 @@ async def service_objects(bus: str, service: str) -> str: @mcp.resource("dbus://{bus}/{service}/{path}/interfaces") async def object_interfaces(bus: str, service: str, path: str) -> str: """Interfaces available at a D-Bus object path.""" - # path comes URL-encoded (slashes → %2F), decode it decoded_path = unquote(path) if not decoded_path.startswith("/"): decoded_path = "/" + decoded_path diff --git a/src/mcdbus/_shortcuts.py b/src/mcdbus/_shortcuts.py index ec404fa..1a83a22 100644 --- a/src/mcdbus/_shortcuts.py +++ b/src/mcdbus/_shortcuts.py @@ -1,15 +1,13 @@ """High-level convenience tools for common D-Bus operations.""" +import fnmatch + from fastmcp import Context -from mcdbus._bus import BusManager, call_bus_method +from mcdbus._bus import call_bus_method, get_mgr from mcdbus._state import mcp -def _get_mgr(ctx: Context) -> BusManager: - return ctx.request_context.lifespan_context - - @mcp.tool() async def send_notification( summary: str, @@ -26,27 +24,30 @@ async def send_notification( icon: icon name or path (empty for default) timeout: display duration in milliseconds (default 5000) """ - mgr = _get_mgr(ctx) + mgr = get_mgr(ctx) connection = await mgr.get_bus("session") - result = await call_bus_method( - connection, - destination="org.freedesktop.Notifications", - path="/org/freedesktop/Notifications", - interface="org.freedesktop.Notifications", - member="Notify", - signature="susssasa{sv}i", - body=[ - "mcdbus", # app_name - 0, # replaces_id (0 = new) - icon, # app_icon - summary, # summary - body, # body - [], # actions - {}, # hints - timeout, # expire_timeout - ], - ) + try: + result = await call_bus_method( + connection, + destination="org.freedesktop.Notifications", + path="/org/freedesktop/Notifications", + interface="org.freedesktop.Notifications", + member="Notify", + signature="susssasa{sv}i", + body=[ + "mcdbus", # app_name + 0, # replaces_id (0 = new) + icon, # app_icon + summary, # summary + body, # body + [], # actions + {}, # hints + timeout, # expire_timeout + ], + ) + except (RuntimeError, TimeoutError) as exc: + return f"Error sending notification: {exc}" nid = result[0] if result else "unknown" return f"Notification sent (id: {nid})" @@ -64,16 +65,19 @@ async def list_systemd_units( bus: "session" (user units) or "system" (system units, default) pattern: optional glob filter (e.g. "docker*", "*.service") """ - mgr = _get_mgr(ctx) + mgr = get_mgr(ctx) connection = await mgr.get_bus(bus) - result = await call_bus_method( - connection, - destination="org.freedesktop.systemd1", - path="/org/freedesktop/systemd1", - interface="org.freedesktop.systemd1.Manager", - member="ListUnits", - ) + try: + result = await call_bus_method( + connection, + destination="org.freedesktop.systemd1", + path="/org/freedesktop/systemd1", + interface="org.freedesktop.systemd1.Manager", + member="ListUnits", + ) + except (RuntimeError, TimeoutError) as exc: + return f"Error listing units: {exc}" if not result or not result[0]: return "No units found." @@ -81,7 +85,6 @@ async def list_systemd_units( units = result[0] # Each unit is (name, description, load_state, active_state, sub_state, ...) - import fnmatch if pattern: units = [u for u in units if fnmatch.fnmatch(u[0], pattern)] @@ -107,11 +110,10 @@ async def media_player_control( action: one of "play", "pause", "next", "previous", "stop", "play-pause" player: MPRIS player name (auto-discovers first player if empty) """ - mgr = _get_mgr(ctx) + mgr = get_mgr(ctx) connection = await mgr.get_bus("session") if not player: - # Discover MPRIS players result = await call_bus_method( connection, destination="org.freedesktop.DBus", @@ -120,12 +122,18 @@ async def media_player_control( member="ListNames", ) names = result[0] if result else [] - mpris_services = [n for n in names if n.startswith("org.mpris.MediaPlayer2.")] + mpris_services = sorted( + n for n in names if n.startswith("org.mpris.MediaPlayer2.") + ) if not mpris_services: return "No MPRIS media players found on session bus." player = mpris_services[0] - await ctx.info(f"Auto-discovered player: {player}") + others = ( + f"\nOther players: {', '.join(mpris_services[1:])}" + if len(mpris_services) > 1 else "" + ) + await ctx.info(f"Auto-discovered player: {player}{others}") action_map = { "play": "Play", @@ -140,13 +148,16 @@ async def media_player_control( if not dbus_method: return f"Unknown action: {action}. Use: {', '.join(action_map.keys())}" - await call_bus_method( - connection, - destination=player, - path="/org/mpris/MediaPlayer2", - interface="org.mpris.MediaPlayer2.Player", - member=dbus_method, - ) + try: + await call_bus_method( + connection, + destination=player, + path="/org/mpris/MediaPlayer2", + interface="org.mpris.MediaPlayer2.Player", + member=dbus_method, + ) + except (RuntimeError, TimeoutError) as exc: + return f"Error controlling player: {exc}" # Read current playback status try: @@ -160,7 +171,7 @@ async def media_player_control( body=["org.mpris.MediaPlayer2.Player", "PlaybackStatus"], ) status = status_result[0] if status_result else "unknown" - except Exception: + except (RuntimeError, TimeoutError): status = "unknown" # Try to get current track metadata @@ -179,7 +190,7 @@ async def media_player_control( artist = metadata.get("xesam:artist", ["Unknown"]) if isinstance(artist, list): artist = ", ".join(artist) - except Exception: + except (RuntimeError, TimeoutError): title, artist = "Unknown", "Unknown" return f"Player: {player}\nAction: {action}\nStatus: {status}\nNow: {artist} — {title}" diff --git a/tests/test_bus.py b/tests/test_bus.py index dc99a3a..edcbeba 100644 --- a/tests/test_bus.py +++ b/tests/test_bus.py @@ -2,7 +2,13 @@ import pytest -from mcdbus._bus import BusManager, deserialize_args, serialize_variant +from mcdbus._bus import ( + BusManager, + deserialize_args, + serialize_variant, + validate_bus_name, + validate_object_path, +) class TestBusManager: @@ -71,3 +77,37 @@ class TestDeserializeArgs: def test_single_value_wraps_in_list(self): result = deserialize_args('"hello"', "s") assert result == ["hello"] + + def test_arg_count_mismatch_raises(self): + with pytest.raises(ValueError, match="expects 1 argument"): + deserialize_args('["hello", "extra"]', "s") + + def test_too_few_args_raises(self): + with pytest.raises(ValueError, match="expects 2 argument"): + deserialize_args('["hello"]', "su") + + def test_invalid_json_raises(self): + with pytest.raises(ValueError, match="Invalid JSON"): + deserialize_args("{broken", "s") + + +class TestValidation: + def test_valid_bus_name(self): + validate_bus_name("org.freedesktop.DBus") + validate_bus_name("org.kde.KWin") + + def test_invalid_bus_name(self): + with pytest.raises(ValueError, match="Invalid D-Bus"): + validate_bus_name("not-a-bus-name") + with pytest.raises(ValueError, match="Invalid D-Bus"): + validate_bus_name("") + + def test_valid_object_path(self): + validate_object_path("/") + validate_object_path("/org/freedesktop/DBus") + + def test_invalid_object_path(self): + with pytest.raises(ValueError, match="Invalid D-Bus object path"): + validate_object_path("not/a/path") + with pytest.raises(ValueError, match="Invalid D-Bus object path"): + validate_object_path("")