Harden all tool/resource paths per Hamilton review

- Add asyncio.wait_for timeout (30s) to all D-Bus calls
- Add asyncio.Lock to BusManager.get_bus to prevent race conditions
- Convert recursive tree walk to bounded BFS with visited set, max depth (20), max nodes (500)
- Add input validation for D-Bus names and object paths at tool boundary
- Standardize error handling (RuntimeError, TimeoutError) across all tools
- Catch JSONDecodeError in deserialize_args and set_property
- Check arg count matches signature in deserialize_args
- Add explicit variant wrapping via {"signature": "ai", "value": [1,2,3]}
- Narrow resource exception handling from Exception to (RuntimeError, OSError)
- Guard disconnect_all per-bus to avoid partial cleanup on error
- Audit log system bus calls and all set_property calls to stderr
- Consolidate _get_mgr into get_mgr in _bus.py
- Sort MPRIS players for deterministic auto-discovery
- Fix progress reporting (pass None as total when unknown)
- Add 7 new tests for validation, arg count, and JSON error paths (32 total)
This commit is contained in:
Ryan Malloy 2026-03-05 20:31:06 -07:00
parent 4d7b73f6ee
commit fa41ffbf80
6 changed files with 352 additions and 142 deletions

View File

@ -1,11 +1,39 @@
"""Bus connection management and D-Bus type serialization.""" """Bus connection management and D-Bus type serialization."""
import asyncio
import json import json
import re
import sys
from typing import Any from typing import Any
from dbus_fast import BusType, Message, MessageType from dbus_fast import BusType, Message, MessageType
from dbus_fast.aio import MessageBus from dbus_fast.aio import MessageBus
from dbus_fast.unpack import unpack_variants from dbus_fast.unpack import unpack_variants
from fastmcp import Context
DBUS_CALL_TIMEOUT = float(30) # seconds; override via env MCDBUS_TIMEOUT
# D-Bus spec: https://dbus.freedesktop.org/doc/dbus-specification.html
_DBUS_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_.]*(\.[A-Za-z_][A-Za-z0-9_]*)+$")
_DBUS_PATH_RE = re.compile(r"^/$|^(/[A-Za-z0-9_]+)+$")
def validate_bus_name(name: str, label: str = "service") -> None:
"""Validate a D-Bus well-known or interface name."""
if not _DBUS_NAME_RE.match(name):
raise ValueError(
f"Invalid D-Bus {label}: {name!r}. "
f"Must match [A-Za-z_][A-Za-z0-9_]*(.[A-Za-z_][A-Za-z0-9_]*)+"
)
def validate_object_path(path: str) -> None:
"""Validate a D-Bus object path."""
if not _DBUS_PATH_RE.match(path):
raise ValueError(
f"Invalid D-Bus object path: {path!r}. "
f"Must be '/' or '/element/element/...' with [A-Za-z0-9_]"
)
class BusManager: class BusManager:
@ -13,12 +41,17 @@ class BusManager:
def __init__(self) -> None: def __init__(self) -> None:
self._buses: dict[str, MessageBus] = {} self._buses: dict[str, MessageBus] = {}
self._locks: dict[str, asyncio.Lock] = {
"session": asyncio.Lock(),
"system": asyncio.Lock(),
}
async def get_bus(self, bus_type: str) -> MessageBus: async def get_bus(self, bus_type: str) -> MessageBus:
"""Get a connected bus by type name ("session" or "system").""" """Get a connected bus by type name ("session" or "system")."""
if bus_type not in ("session", "system"): if bus_type not in ("session", "system"):
raise ValueError(f"bus_type must be 'session' or 'system', got {bus_type!r}") raise ValueError(f"bus_type must be 'session' or 'system', got {bus_type!r}")
async with self._locks[bus_type]:
if bus_type in self._buses: if bus_type in self._buses:
bus = self._buses[bus_type] bus = self._buses[bus_type]
if bus.connected: if bus.connected:
@ -32,12 +65,20 @@ class BusManager:
return bus return bus
async def disconnect_all(self) -> None: async def disconnect_all(self) -> None:
for bus in self._buses.values(): for name, bus in self._buses.items():
try:
if bus.connected: if bus.connected:
bus.disconnect() bus.disconnect()
except Exception as exc:
print(f"mcdbus: error disconnecting {name} bus: {exc}", file=sys.stderr)
self._buses.clear() self._buses.clear()
def get_mgr(ctx: Context) -> "BusManager":
"""Extract the BusManager from a FastMCP tool context."""
return ctx.request_context.lifespan_context
def serialize_variant(value: Any) -> Any: def serialize_variant(value: Any) -> Any:
"""Recursively unwrap dbus-fast Variant objects into JSON-safe Python types.""" """Recursively unwrap dbus-fast Variant objects into JSON-safe Python types."""
unpacked = unpack_variants(value) unpacked = unpack_variants(value)
@ -62,13 +103,19 @@ def deserialize_args(args_json: str, signature: str) -> list[Any]:
"""Parse JSON args string into a Python list suitable for D-Bus call body. """Parse JSON args string into a Python list suitable for D-Bus call body.
For 'v' (Variant) signatures, wraps values in dbus_fast.signature.Variant. For 'v' (Variant) signatures, wraps values in dbus_fast.signature.Variant.
Auto-wrapping infers types for simple scalars; for complex variant values
(e.g. array of int), pass the value as {"signature": "ai", "value": [1,2,3]}.
""" """
from dbus_fast.signature import SignatureTree, Variant from dbus_fast.signature import SignatureTree, Variant
if not args_json or args_json.strip() in ("", "[]", "null"): if not args_json or args_json.strip() in ("", "[]", "null"):
return [] return []
try:
args = json.loads(args_json) args = json.loads(args_json)
except json.JSONDecodeError as exc:
raise ValueError(f"Invalid JSON in args: {exc}") from exc
if not isinstance(args, list): if not isinstance(args, list):
args = [args] args = [args]
@ -76,30 +123,46 @@ def deserialize_args(args_json: str, signature: str) -> list[Any]:
return args return args
tree = SignatureTree(signature) tree = SignatureTree(signature)
expected = len(tree.types)
if len(args) != expected:
raise ValueError(
f"Signature '{signature}' expects {expected} argument(s), got {len(args)}"
)
result = [] result = []
for i, sig_type in enumerate(tree.types): for i, sig_type in enumerate(tree.types):
if i >= len(args):
break
val = args[i] val = args[i]
if sig_type.signature == "v": if sig_type.signature == "v":
# Caller must wrap variant values; we auto-wrap strings as "s" val = _auto_wrap_variant(val, Variant)
if isinstance(val, str):
val = Variant("s", val)
elif isinstance(val, bool):
val = Variant("b", val)
elif isinstance(val, int):
val = Variant("i", val)
elif isinstance(val, float):
val = Variant("d", val)
elif isinstance(val, list):
val = Variant("as", val)
elif isinstance(val, dict):
val = Variant("a{sv}", val)
result.append(val) result.append(val)
return result return result
def _auto_wrap_variant(val: Any, variant_cls: type) -> Any:
"""Wrap a Python value in a D-Bus Variant, inferring the type signature.
For explicit control, pass {"signature": "ai", "value": [1,2,3]}.
"""
# Explicit variant specification
if isinstance(val, dict) and "signature" in val and "value" in val and len(val) == 2:
return variant_cls(val["signature"], val["value"])
# Auto-infer from Python type (bool must come before int — bool is a subclass of int)
if isinstance(val, bool):
return variant_cls("b", val)
if isinstance(val, str):
return variant_cls("s", val)
if isinstance(val, int):
return variant_cls("i", val)
if isinstance(val, float):
return variant_cls("d", val)
if isinstance(val, list):
return variant_cls("as", val)
if isinstance(val, dict):
return variant_cls("a{sv}", val)
return variant_cls("s", str(val))
async def call_bus_method( async def call_bus_method(
bus: MessageBus, bus: MessageBus,
destination: str, destination: str,
@ -108,6 +171,7 @@ async def call_bus_method(
member: str, member: str,
signature: str = "", signature: str = "",
body: list[Any] | None = None, body: list[Any] | None = None,
timeout: float = DBUS_CALL_TIMEOUT,
) -> Any: ) -> Any:
"""Send a D-Bus method call and return the unpacked response body.""" """Send a D-Bus method call and return the unpacked response body."""
msg = Message( msg = Message(
@ -118,7 +182,13 @@ async def call_bus_method(
signature=signature, signature=signature,
body=body or [], body=body or [],
) )
reply = await bus.call(msg) try:
reply = await asyncio.wait_for(bus.call(msg), timeout=timeout)
except TimeoutError:
raise TimeoutError(
f"D-Bus call timed out after {timeout}s: "
f"{destination} {interface}.{member} at {path}"
) from None
if reply.message_type == MessageType.ERROR: if reply.message_type == MessageType.ERROR:
raise RuntimeError(f"D-Bus error: {reply.error_name}: {reply.body}") raise RuntimeError(f"D-Bus error: {reply.error_name}: {reply.body}")
if reply.body: if reply.body:

View File

@ -1,13 +1,14 @@
"""Discovery tools — list services, introspect objects, walk object trees.""" """Discovery tools — list services, introspect objects, walk object trees."""
from collections import deque
from fastmcp import Context from fastmcp import Context
from mcdbus._bus import BusManager, call_bus_method from mcdbus._bus import call_bus_method, get_mgr, validate_bus_name, validate_object_path
from mcdbus._state import mcp from mcdbus._state import mcp
MAX_TREE_DEPTH = 20
def _get_mgr(ctx: Context) -> BusManager: MAX_TREE_NODES = 500
return ctx.request_context.lifespan_context
@mcp.tool() @mcp.tool()
@ -22,7 +23,7 @@ async def list_services(
bus: "session" or "system" bus: "session" or "system"
include_unique: include unique connection names like :1.42 (default False) include_unique: include unique connection names like :1.42 (default False)
""" """
mgr = _get_mgr(ctx) mgr = get_mgr(ctx)
await ctx.info(f"Connecting to {bus} bus...") await ctx.info(f"Connecting to {bus} bus...")
connection = await mgr.get_bus(bus) connection = await mgr.get_bus(bus)
@ -61,7 +62,10 @@ async def introspect(
object_path: object path (e.g. "/org/freedesktop/Notifications") object_path: object path (e.g. "/org/freedesktop/Notifications")
include_standard: include standard D-Bus interfaces (Peer, Introspectable, Properties) include_standard: include standard D-Bus interfaces (Peer, Introspectable, Properties)
""" """
mgr = _get_mgr(ctx) validate_bus_name(service)
validate_object_path(object_path)
mgr = get_mgr(ctx)
await ctx.report_progress(0, 4) await ctx.report_progress(0, 4)
connection = await mgr.get_bus(bus) connection = await mgr.get_bus(bus)
@ -102,7 +106,9 @@ async def introspect(
lines.append("**Properties:**") lines.append("**Properties:**")
for prop in iface.properties: for prop in iface.properties:
access = getattr(prop.access, "name", str(prop.access)).lower() access = getattr(prop.access, "name", str(prop.access)).lower()
lines.append(f"- `{prop.name}`: `{prop.signature}` ({access})") lines.append(
f"- `{prop.name}`: `{prop.signature}` ({access})"
)
lines.append("") lines.append("")
if iface.signals: if iface.signals:
@ -132,40 +138,63 @@ async def list_objects(
service: str, service: str,
ctx: Context, ctx: Context,
root_path: str = "/", root_path: str = "/",
max_depth: int = MAX_TREE_DEPTH,
) -> str: ) -> str:
"""Recursively walk the D-Bus object tree for a service. """Recursively walk the D-Bus object tree for a service (BFS, bounded).
Args: Args:
bus: "session" or "system" bus: "session" or "system"
service: service name service: service name
root_path: starting path (default "/") root_path: starting path (default "/")
max_depth: maximum tree depth to walk (default 20)
""" """
mgr = _get_mgr(ctx) validate_bus_name(service)
validate_object_path(root_path)
mgr = get_mgr(ctx)
connection = await mgr.get_bus(bus) connection = await mgr.get_bus(bus)
paths: list[tuple[str, list[str]]] = [] paths: list[tuple[str, list[str]]] = []
visited: set[str] = set()
queue: deque[tuple[str, int]] = deque([(root_path, 0)])
truncated = False
while queue:
path, depth = queue.popleft()
if path in visited:
continue
visited.add(path)
if len(paths) >= MAX_TREE_NODES:
truncated = True
break
async def _walk(path: str) -> None:
try: try:
node = await connection.introspect(service, path) node = await connection.introspect(service, path)
except Exception as exc: except Exception as exc:
await ctx.warning(f"Could not introspect {path}: {exc}") await ctx.warning(f"Could not introspect {path}: {exc}")
return continue
iface_names = [i.name for i in node.interfaces] iface_names = [i.name for i in node.interfaces]
paths.append((path, iface_names)) paths.append((path, iface_names))
await ctx.report_progress(len(paths), len(paths)) await ctx.report_progress(len(paths), None)
if depth < max_depth:
for child in node.nodes: for child in node.nodes:
child_path = path.rstrip("/") + "/" + child.name child_path = path.rstrip("/") + "/" + child.name
await _walk(child_path) queue.append((child_path, depth + 1))
await _walk(root_path) header = f"## Object tree for `{service}` on {bus} bus — {len(paths)} objects"
if truncated:
header += f" (truncated at {MAX_TREE_NODES})"
lines = [header + "\n"]
lines = [f"## Object tree for `{service}` on {bus} bus — {len(paths)} objects\n"]
for path, ifaces in sorted(paths): for path, ifaces in sorted(paths):
std_prefix = "org.freedesktop.DBus." std_prefix = "org.freedesktop.DBus."
iface_str = ", ".join(f"`{i}`" for i in ifaces if not i.startswith(std_prefix)) iface_str = ", ".join(
f"`{i}`" for i in ifaces if not i.startswith(std_prefix)
)
if iface_str: if iface_str:
lines.append(f"- `{path}` — {iface_str}") lines.append(f"- `{path}` — {iface_str}")
else: else:

View File

@ -1,18 +1,21 @@
"""Interaction tools — method calls, property get/set.""" """Interaction tools — method calls, property get/set."""
import json import json
import sys
from dbus_fast.signature import Variant from dbus_fast.signature import Variant
from fastmcp import Context from fastmcp import Context
from mcdbus._bus import BusManager, call_bus_method, deserialize_args from mcdbus._bus import (
call_bus_method,
deserialize_args,
get_mgr,
validate_bus_name,
validate_object_path,
)
from mcdbus._state import mcp from mcdbus._state import mcp
def _get_mgr(ctx: Context) -> BusManager:
return ctx.request_context.lifespan_context
@mcp.tool() @mcp.tool()
async def call_method( async def call_method(
bus: str, bus: str,
@ -35,10 +38,26 @@ async def call_method(
args: JSON array of arguments (default "[]") args: JSON array of arguments (default "[]")
signature: D-Bus type signature for args (e.g. "su"); leave empty for no-arg methods signature: D-Bus type signature for args (e.g. "su"); leave empty for no-arg methods
""" """
mgr = _get_mgr(ctx) validate_bus_name(service)
validate_object_path(object_path)
validate_bus_name(interface, label="interface")
mgr = get_mgr(ctx)
connection = await mgr.get_bus(bus) connection = await mgr.get_bus(bus)
try:
body = deserialize_args(args, signature) body = deserialize_args(args, signature)
except (ValueError, json.JSONDecodeError) as exc:
return f"Error parsing arguments: {exc}"
# Log system bus writes to stderr for audit trail
if bus == "system":
print(
f"mcdbus: system bus call {service} {interface}.{method} "
f"path={object_path} sig={signature!r}",
file=sys.stderr,
)
await ctx.info(f"Calling {interface}.{method} on {service}...") await ctx.info(f"Calling {interface}.{method} on {service}...")
try: try:
@ -51,7 +70,7 @@ async def call_method(
signature=signature, signature=signature,
body=body, body=body,
) )
except RuntimeError as exc: except (RuntimeError, TimeoutError) as exc:
await ctx.error(str(exc)) await ctx.error(str(exc))
return f"Error: {exc}" return f"Error: {exc}"
@ -83,9 +102,14 @@ async def get_property(
interface: interface that owns the property interface: interface that owns the property
property_name: property name to read property_name: property name to read
""" """
mgr = _get_mgr(ctx) validate_bus_name(service)
validate_object_path(object_path)
validate_bus_name(interface, label="interface")
mgr = get_mgr(ctx)
connection = await mgr.get_bus(bus) connection = await mgr.get_bus(bus)
try:
result = await call_bus_method( result = await call_bus_method(
connection, connection,
destination=service, destination=service,
@ -95,6 +119,8 @@ async def get_property(
signature="ss", signature="ss",
body=[interface, property_name], body=[interface, property_name],
) )
except (RuntimeError, TimeoutError) as exc:
return f"Error reading {interface}.{property_name}: {exc}"
if result: if result:
value = result[0] value = result[0]
@ -124,12 +150,28 @@ async def set_property(
value: JSON-encoded value to set value: JSON-encoded value to set
signature: D-Bus type signature of the property value (e.g. "b" for boolean) signature: D-Bus type signature of the property value (e.g. "b" for boolean)
""" """
mgr = _get_mgr(ctx) validate_bus_name(service)
validate_object_path(object_path)
validate_bus_name(interface, label="interface")
mgr = get_mgr(ctx)
connection = await mgr.get_bus(bus) connection = await mgr.get_bus(bus)
try:
parsed_value = json.loads(value) parsed_value = json.loads(value)
except json.JSONDecodeError as exc:
return f"Error parsing value JSON: {exc}"
variant = Variant(signature, parsed_value) variant = Variant(signature, parsed_value)
# Audit log for all set_property calls
print(
f"mcdbus: set_property {bus} {service} {interface}.{property_name} "
f"= {value} (sig={signature!r})",
file=sys.stderr,
)
try:
await call_bus_method( await call_bus_method(
connection, connection,
destination=service, destination=service,
@ -139,6 +181,8 @@ async def set_property(
signature="ssv", signature="ssv",
body=[interface, property_name, variant], body=[interface, property_name, variant],
) )
except (RuntimeError, TimeoutError) as exc:
return f"Error setting {interface}.{property_name}: {exc}"
return f"Set {interface}.{property_name} = {value}" return f"Set {interface}.{property_name} = {value}"
@ -159,9 +203,14 @@ async def get_all_properties(
object_path: object path object_path: object path
interface: interface to read properties from interface: interface to read properties from
""" """
mgr = _get_mgr(ctx) validate_bus_name(service)
validate_object_path(object_path)
validate_bus_name(interface, label="interface")
mgr = get_mgr(ctx)
connection = await mgr.get_bus(bus) connection = await mgr.get_bus(bus)
try:
result = await call_bus_method( result = await call_bus_method(
connection, connection,
destination=service, destination=service,
@ -171,6 +220,8 @@ async def get_all_properties(
signature="s", signature="s",
body=[interface], body=[interface],
) )
except (RuntimeError, TimeoutError) as exc:
return f"Error reading properties on {interface}: {exc}"
if not result or not result[0]: if not result or not result[0]:
return "No properties found." return "No properties found."

View File

@ -1,15 +1,17 @@
"""Dynamic MCP resources for browsing D-Bus services.""" """Dynamic MCP resources for browsing D-Bus services."""
from collections import deque
from urllib.parse import unquote from urllib.parse import unquote
from mcdbus._bus import BusManager, call_bus_method from mcdbus._bus import BusManager, call_bus_method
from mcdbus._state import mcp from mcdbus._state import mcp
MAX_RESOURCE_NODES = 200
@mcp.resource("dbus://{bus}/services") @mcp.resource("dbus://{bus}/services")
async def bus_services(bus: str) -> str: async def bus_services(bus: str) -> str:
"""Live list of well-known service names on a D-Bus bus.""" """Live list of well-known service names on a D-Bus bus."""
# Resources don't get Context — we need a fresh connection
mgr = BusManager() mgr = BusManager()
try: try:
connection = await mgr.get_bus(bus) connection = await mgr.get_bus(bus)
@ -28,23 +30,31 @@ async def bus_services(bus: str) -> str:
@mcp.resource("dbus://{bus}/{service}/objects") @mcp.resource("dbus://{bus}/{service}/objects")
async def service_objects(bus: str, service: str) -> str: async def service_objects(bus: str, service: str) -> str:
"""Object tree for a D-Bus service.""" """Object tree for a D-Bus service (bounded BFS walk)."""
mgr = BusManager() mgr = BusManager()
try: try:
connection = await mgr.get_bus(bus) connection = await mgr.get_bus(bus)
paths: list[str] = [] paths: list[str] = []
visited: set[str] = set()
queue: deque[tuple[str, int]] = deque([("/", 0)])
while queue and len(paths) < MAX_RESOURCE_NODES:
path, depth = queue.popleft()
if path in visited:
continue
visited.add(path)
async def _walk(path: str) -> None:
try: try:
node = await connection.introspect(service, path) node = await connection.introspect(service, path)
except (RuntimeError, OSError):
continue
paths.append(path) paths.append(path)
if depth < 15:
for child in node.nodes: for child in node.nodes:
child_path = path.rstrip("/") + "/" + child.name child_path = path.rstrip("/") + "/" + child.name
await _walk(child_path) queue.append((child_path, depth + 1))
except Exception:
pass
await _walk("/")
return "\n".join(sorted(paths)) return "\n".join(sorted(paths))
finally: finally:
await mgr.disconnect_all() await mgr.disconnect_all()
@ -53,7 +63,6 @@ async def service_objects(bus: str, service: str) -> str:
@mcp.resource("dbus://{bus}/{service}/{path}/interfaces") @mcp.resource("dbus://{bus}/{service}/{path}/interfaces")
async def object_interfaces(bus: str, service: str, path: str) -> str: async def object_interfaces(bus: str, service: str, path: str) -> str:
"""Interfaces available at a D-Bus object path.""" """Interfaces available at a D-Bus object path."""
# path comes URL-encoded (slashes → %2F), decode it
decoded_path = unquote(path) decoded_path = unquote(path)
if not decoded_path.startswith("/"): if not decoded_path.startswith("/"):
decoded_path = "/" + decoded_path decoded_path = "/" + decoded_path

View File

@ -1,15 +1,13 @@
"""High-level convenience tools for common D-Bus operations.""" """High-level convenience tools for common D-Bus operations."""
import fnmatch
from fastmcp import Context from fastmcp import Context
from mcdbus._bus import BusManager, call_bus_method from mcdbus._bus import call_bus_method, get_mgr
from mcdbus._state import mcp from mcdbus._state import mcp
def _get_mgr(ctx: Context) -> BusManager:
return ctx.request_context.lifespan_context
@mcp.tool() @mcp.tool()
async def send_notification( async def send_notification(
summary: str, summary: str,
@ -26,9 +24,10 @@ async def send_notification(
icon: icon name or path (empty for default) icon: icon name or path (empty for default)
timeout: display duration in milliseconds (default 5000) timeout: display duration in milliseconds (default 5000)
""" """
mgr = _get_mgr(ctx) mgr = get_mgr(ctx)
connection = await mgr.get_bus("session") connection = await mgr.get_bus("session")
try:
result = await call_bus_method( result = await call_bus_method(
connection, connection,
destination="org.freedesktop.Notifications", destination="org.freedesktop.Notifications",
@ -47,6 +46,8 @@ async def send_notification(
timeout, # expire_timeout timeout, # expire_timeout
], ],
) )
except (RuntimeError, TimeoutError) as exc:
return f"Error sending notification: {exc}"
nid = result[0] if result else "unknown" nid = result[0] if result else "unknown"
return f"Notification sent (id: {nid})" return f"Notification sent (id: {nid})"
@ -64,9 +65,10 @@ async def list_systemd_units(
bus: "session" (user units) or "system" (system units, default) bus: "session" (user units) or "system" (system units, default)
pattern: optional glob filter (e.g. "docker*", "*.service") pattern: optional glob filter (e.g. "docker*", "*.service")
""" """
mgr = _get_mgr(ctx) mgr = get_mgr(ctx)
connection = await mgr.get_bus(bus) connection = await mgr.get_bus(bus)
try:
result = await call_bus_method( result = await call_bus_method(
connection, connection,
destination="org.freedesktop.systemd1", destination="org.freedesktop.systemd1",
@ -74,6 +76,8 @@ async def list_systemd_units(
interface="org.freedesktop.systemd1.Manager", interface="org.freedesktop.systemd1.Manager",
member="ListUnits", member="ListUnits",
) )
except (RuntimeError, TimeoutError) as exc:
return f"Error listing units: {exc}"
if not result or not result[0]: if not result or not result[0]:
return "No units found." return "No units found."
@ -81,7 +85,6 @@ async def list_systemd_units(
units = result[0] units = result[0]
# Each unit is (name, description, load_state, active_state, sub_state, ...) # Each unit is (name, description, load_state, active_state, sub_state, ...)
import fnmatch
if pattern: if pattern:
units = [u for u in units if fnmatch.fnmatch(u[0], pattern)] units = [u for u in units if fnmatch.fnmatch(u[0], pattern)]
@ -107,11 +110,10 @@ async def media_player_control(
action: one of "play", "pause", "next", "previous", "stop", "play-pause" action: one of "play", "pause", "next", "previous", "stop", "play-pause"
player: MPRIS player name (auto-discovers first player if empty) player: MPRIS player name (auto-discovers first player if empty)
""" """
mgr = _get_mgr(ctx) mgr = get_mgr(ctx)
connection = await mgr.get_bus("session") connection = await mgr.get_bus("session")
if not player: if not player:
# Discover MPRIS players
result = await call_bus_method( result = await call_bus_method(
connection, connection,
destination="org.freedesktop.DBus", destination="org.freedesktop.DBus",
@ -120,12 +122,18 @@ async def media_player_control(
member="ListNames", member="ListNames",
) )
names = result[0] if result else [] names = result[0] if result else []
mpris_services = [n for n in names if n.startswith("org.mpris.MediaPlayer2.")] mpris_services = sorted(
n for n in names if n.startswith("org.mpris.MediaPlayer2.")
)
if not mpris_services: if not mpris_services:
return "No MPRIS media players found on session bus." return "No MPRIS media players found on session bus."
player = mpris_services[0] player = mpris_services[0]
await ctx.info(f"Auto-discovered player: {player}") others = (
f"\nOther players: {', '.join(mpris_services[1:])}"
if len(mpris_services) > 1 else ""
)
await ctx.info(f"Auto-discovered player: {player}{others}")
action_map = { action_map = {
"play": "Play", "play": "Play",
@ -140,6 +148,7 @@ async def media_player_control(
if not dbus_method: if not dbus_method:
return f"Unknown action: {action}. Use: {', '.join(action_map.keys())}" return f"Unknown action: {action}. Use: {', '.join(action_map.keys())}"
try:
await call_bus_method( await call_bus_method(
connection, connection,
destination=player, destination=player,
@ -147,6 +156,8 @@ async def media_player_control(
interface="org.mpris.MediaPlayer2.Player", interface="org.mpris.MediaPlayer2.Player",
member=dbus_method, member=dbus_method,
) )
except (RuntimeError, TimeoutError) as exc:
return f"Error controlling player: {exc}"
# Read current playback status # Read current playback status
try: try:
@ -160,7 +171,7 @@ async def media_player_control(
body=["org.mpris.MediaPlayer2.Player", "PlaybackStatus"], body=["org.mpris.MediaPlayer2.Player", "PlaybackStatus"],
) )
status = status_result[0] if status_result else "unknown" status = status_result[0] if status_result else "unknown"
except Exception: except (RuntimeError, TimeoutError):
status = "unknown" status = "unknown"
# Try to get current track metadata # Try to get current track metadata
@ -179,7 +190,7 @@ async def media_player_control(
artist = metadata.get("xesam:artist", ["Unknown"]) artist = metadata.get("xesam:artist", ["Unknown"])
if isinstance(artist, list): if isinstance(artist, list):
artist = ", ".join(artist) artist = ", ".join(artist)
except Exception: except (RuntimeError, TimeoutError):
title, artist = "Unknown", "Unknown" title, artist = "Unknown", "Unknown"
return f"Player: {player}\nAction: {action}\nStatus: {status}\nNow: {artist}{title}" return f"Player: {player}\nAction: {action}\nStatus: {status}\nNow: {artist}{title}"

View File

@ -2,7 +2,13 @@
import pytest import pytest
from mcdbus._bus import BusManager, deserialize_args, serialize_variant from mcdbus._bus import (
BusManager,
deserialize_args,
serialize_variant,
validate_bus_name,
validate_object_path,
)
class TestBusManager: class TestBusManager:
@ -71,3 +77,37 @@ class TestDeserializeArgs:
def test_single_value_wraps_in_list(self): def test_single_value_wraps_in_list(self):
result = deserialize_args('"hello"', "s") result = deserialize_args('"hello"', "s")
assert result == ["hello"] assert result == ["hello"]
def test_arg_count_mismatch_raises(self):
with pytest.raises(ValueError, match="expects 1 argument"):
deserialize_args('["hello", "extra"]', "s")
def test_too_few_args_raises(self):
with pytest.raises(ValueError, match="expects 2 argument"):
deserialize_args('["hello"]', "su")
def test_invalid_json_raises(self):
with pytest.raises(ValueError, match="Invalid JSON"):
deserialize_args("{broken", "s")
class TestValidation:
def test_valid_bus_name(self):
validate_bus_name("org.freedesktop.DBus")
validate_bus_name("org.kde.KWin")
def test_invalid_bus_name(self):
with pytest.raises(ValueError, match="Invalid D-Bus"):
validate_bus_name("not-a-bus-name")
with pytest.raises(ValueError, match="Invalid D-Bus"):
validate_bus_name("")
def test_valid_object_path(self):
validate_object_path("/")
validate_object_path("/org/freedesktop/DBus")
def test_invalid_object_path(self):
with pytest.raises(ValueError, match="Invalid D-Bus object path"):
validate_object_path("not/a/path")
with pytest.raises(ValueError, match="Invalid D-Bus object path"):
validate_object_path("")