Harden all tool/resource paths per Hamilton review
- Add asyncio.wait_for timeout (30s) to all D-Bus calls
- Add asyncio.Lock to BusManager.get_bus to prevent race conditions
- Convert recursive tree walk to bounded BFS with visited set, max depth (20), max nodes (500)
- Add input validation for D-Bus names and object paths at tool boundary
- Standardize error handling (RuntimeError, TimeoutError) across all tools
- Catch JSONDecodeError in deserialize_args and set_property
- Check arg count matches signature in deserialize_args
- Add explicit variant wrapping via {"signature": "ai", "value": [1,2,3]}
- Narrow resource exception handling from Exception to (RuntimeError, OSError)
- Guard disconnect_all per-bus to avoid partial cleanup on error
- Audit log system bus calls and all set_property calls to stderr
- Consolidate _get_mgr into get_mgr in _bus.py
- Sort MPRIS players for deterministic auto-discovery
- Fix progress reporting (pass None as total when unknown)
- Add 7 new tests for validation, arg count, and JSON error paths (32 total)
This commit is contained in:
parent
4d7b73f6ee
commit
fa41ffbf80
@ -1,11 +1,39 @@
|
|||||||
"""Bus connection management and D-Bus type serialization."""
|
"""Bus connection management and D-Bus type serialization."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from dbus_fast import BusType, Message, MessageType
|
from dbus_fast import BusType, Message, MessageType
|
||||||
from dbus_fast.aio import MessageBus
|
from dbus_fast.aio import MessageBus
|
||||||
from dbus_fast.unpack import unpack_variants
|
from dbus_fast.unpack import unpack_variants
|
||||||
|
from fastmcp import Context
|
||||||
|
|
||||||
|
DBUS_CALL_TIMEOUT = float(30) # seconds; override via env MCDBUS_TIMEOUT
|
||||||
|
|
||||||
|
# D-Bus spec: https://dbus.freedesktop.org/doc/dbus-specification.html
|
||||||
|
_DBUS_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_.]*(\.[A-Za-z_][A-Za-z0-9_]*)+$")
|
||||||
|
_DBUS_PATH_RE = re.compile(r"^/$|^(/[A-Za-z0-9_]+)+$")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_bus_name(name: str, label: str = "service") -> None:
|
||||||
|
"""Validate a D-Bus well-known or interface name."""
|
||||||
|
if not _DBUS_NAME_RE.match(name):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid D-Bus {label}: {name!r}. "
|
||||||
|
f"Must match [A-Za-z_][A-Za-z0-9_]*(.[A-Za-z_][A-Za-z0-9_]*)+"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_object_path(path: str) -> None:
|
||||||
|
"""Validate a D-Bus object path."""
|
||||||
|
if not _DBUS_PATH_RE.match(path):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid D-Bus object path: {path!r}. "
|
||||||
|
f"Must be '/' or '/element/element/...' with [A-Za-z0-9_]"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BusManager:
|
class BusManager:
|
||||||
@ -13,12 +41,17 @@ class BusManager:
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._buses: dict[str, MessageBus] = {}
|
self._buses: dict[str, MessageBus] = {}
|
||||||
|
self._locks: dict[str, asyncio.Lock] = {
|
||||||
|
"session": asyncio.Lock(),
|
||||||
|
"system": asyncio.Lock(),
|
||||||
|
}
|
||||||
|
|
||||||
async def get_bus(self, bus_type: str) -> MessageBus:
|
async def get_bus(self, bus_type: str) -> MessageBus:
|
||||||
"""Get a connected bus by type name ("session" or "system")."""
|
"""Get a connected bus by type name ("session" or "system")."""
|
||||||
if bus_type not in ("session", "system"):
|
if bus_type not in ("session", "system"):
|
||||||
raise ValueError(f"bus_type must be 'session' or 'system', got {bus_type!r}")
|
raise ValueError(f"bus_type must be 'session' or 'system', got {bus_type!r}")
|
||||||
|
|
||||||
|
async with self._locks[bus_type]:
|
||||||
if bus_type in self._buses:
|
if bus_type in self._buses:
|
||||||
bus = self._buses[bus_type]
|
bus = self._buses[bus_type]
|
||||||
if bus.connected:
|
if bus.connected:
|
||||||
@ -32,12 +65,20 @@ class BusManager:
|
|||||||
return bus
|
return bus
|
||||||
|
|
||||||
async def disconnect_all(self) -> None:
|
async def disconnect_all(self) -> None:
|
||||||
for bus in self._buses.values():
|
for name, bus in self._buses.items():
|
||||||
|
try:
|
||||||
if bus.connected:
|
if bus.connected:
|
||||||
bus.disconnect()
|
bus.disconnect()
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"mcdbus: error disconnecting {name} bus: {exc}", file=sys.stderr)
|
||||||
self._buses.clear()
|
self._buses.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def get_mgr(ctx: Context) -> "BusManager":
|
||||||
|
"""Extract the BusManager from a FastMCP tool context."""
|
||||||
|
return ctx.request_context.lifespan_context
|
||||||
|
|
||||||
|
|
||||||
def serialize_variant(value: Any) -> Any:
|
def serialize_variant(value: Any) -> Any:
|
||||||
"""Recursively unwrap dbus-fast Variant objects into JSON-safe Python types."""
|
"""Recursively unwrap dbus-fast Variant objects into JSON-safe Python types."""
|
||||||
unpacked = unpack_variants(value)
|
unpacked = unpack_variants(value)
|
||||||
@ -62,13 +103,19 @@ def deserialize_args(args_json: str, signature: str) -> list[Any]:
|
|||||||
"""Parse JSON args string into a Python list suitable for D-Bus call body.
|
"""Parse JSON args string into a Python list suitable for D-Bus call body.
|
||||||
|
|
||||||
For 'v' (Variant) signatures, wraps values in dbus_fast.signature.Variant.
|
For 'v' (Variant) signatures, wraps values in dbus_fast.signature.Variant.
|
||||||
|
Auto-wrapping infers types for simple scalars; for complex variant values
|
||||||
|
(e.g. array of int), pass the value as {"signature": "ai", "value": [1,2,3]}.
|
||||||
"""
|
"""
|
||||||
from dbus_fast.signature import SignatureTree, Variant
|
from dbus_fast.signature import SignatureTree, Variant
|
||||||
|
|
||||||
if not args_json or args_json.strip() in ("", "[]", "null"):
|
if not args_json or args_json.strip() in ("", "[]", "null"):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
args = json.loads(args_json)
|
args = json.loads(args_json)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise ValueError(f"Invalid JSON in args: {exc}") from exc
|
||||||
|
|
||||||
if not isinstance(args, list):
|
if not isinstance(args, list):
|
||||||
args = [args]
|
args = [args]
|
||||||
|
|
||||||
@ -76,30 +123,46 @@ def deserialize_args(args_json: str, signature: str) -> list[Any]:
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
tree = SignatureTree(signature)
|
tree = SignatureTree(signature)
|
||||||
|
expected = len(tree.types)
|
||||||
|
if len(args) != expected:
|
||||||
|
raise ValueError(
|
||||||
|
f"Signature '{signature}' expects {expected} argument(s), got {len(args)}"
|
||||||
|
)
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for i, sig_type in enumerate(tree.types):
|
for i, sig_type in enumerate(tree.types):
|
||||||
if i >= len(args):
|
|
||||||
break
|
|
||||||
val = args[i]
|
val = args[i]
|
||||||
if sig_type.signature == "v":
|
if sig_type.signature == "v":
|
||||||
# Caller must wrap variant values; we auto-wrap strings as "s"
|
val = _auto_wrap_variant(val, Variant)
|
||||||
if isinstance(val, str):
|
|
||||||
val = Variant("s", val)
|
|
||||||
elif isinstance(val, bool):
|
|
||||||
val = Variant("b", val)
|
|
||||||
elif isinstance(val, int):
|
|
||||||
val = Variant("i", val)
|
|
||||||
elif isinstance(val, float):
|
|
||||||
val = Variant("d", val)
|
|
||||||
elif isinstance(val, list):
|
|
||||||
val = Variant("as", val)
|
|
||||||
elif isinstance(val, dict):
|
|
||||||
val = Variant("a{sv}", val)
|
|
||||||
result.append(val)
|
result.append(val)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _auto_wrap_variant(val: Any, variant_cls: type) -> Any:
|
||||||
|
"""Wrap a Python value in a D-Bus Variant, inferring the type signature.
|
||||||
|
|
||||||
|
For explicit control, pass {"signature": "ai", "value": [1,2,3]}.
|
||||||
|
"""
|
||||||
|
# Explicit variant specification
|
||||||
|
if isinstance(val, dict) and "signature" in val and "value" in val and len(val) == 2:
|
||||||
|
return variant_cls(val["signature"], val["value"])
|
||||||
|
# Auto-infer from Python type (bool must come before int — bool is a subclass of int)
|
||||||
|
if isinstance(val, bool):
|
||||||
|
return variant_cls("b", val)
|
||||||
|
if isinstance(val, str):
|
||||||
|
return variant_cls("s", val)
|
||||||
|
if isinstance(val, int):
|
||||||
|
return variant_cls("i", val)
|
||||||
|
if isinstance(val, float):
|
||||||
|
return variant_cls("d", val)
|
||||||
|
if isinstance(val, list):
|
||||||
|
return variant_cls("as", val)
|
||||||
|
if isinstance(val, dict):
|
||||||
|
return variant_cls("a{sv}", val)
|
||||||
|
return variant_cls("s", str(val))
|
||||||
|
|
||||||
|
|
||||||
async def call_bus_method(
|
async def call_bus_method(
|
||||||
bus: MessageBus,
|
bus: MessageBus,
|
||||||
destination: str,
|
destination: str,
|
||||||
@ -108,6 +171,7 @@ async def call_bus_method(
|
|||||||
member: str,
|
member: str,
|
||||||
signature: str = "",
|
signature: str = "",
|
||||||
body: list[Any] | None = None,
|
body: list[Any] | None = None,
|
||||||
|
timeout: float = DBUS_CALL_TIMEOUT,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Send a D-Bus method call and return the unpacked response body."""
|
"""Send a D-Bus method call and return the unpacked response body."""
|
||||||
msg = Message(
|
msg = Message(
|
||||||
@ -118,7 +182,13 @@ async def call_bus_method(
|
|||||||
signature=signature,
|
signature=signature,
|
||||||
body=body or [],
|
body=body or [],
|
||||||
)
|
)
|
||||||
reply = await bus.call(msg)
|
try:
|
||||||
|
reply = await asyncio.wait_for(bus.call(msg), timeout=timeout)
|
||||||
|
except TimeoutError:
|
||||||
|
raise TimeoutError(
|
||||||
|
f"D-Bus call timed out after {timeout}s: "
|
||||||
|
f"{destination} {interface}.{member} at {path}"
|
||||||
|
) from None
|
||||||
if reply.message_type == MessageType.ERROR:
|
if reply.message_type == MessageType.ERROR:
|
||||||
raise RuntimeError(f"D-Bus error: {reply.error_name}: {reply.body}")
|
raise RuntimeError(f"D-Bus error: {reply.error_name}: {reply.body}")
|
||||||
if reply.body:
|
if reply.body:
|
||||||
|
|||||||
@ -1,13 +1,14 @@
|
|||||||
"""Discovery tools — list services, introspect objects, walk object trees."""
|
"""Discovery tools — list services, introspect objects, walk object trees."""
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
from fastmcp import Context
|
from fastmcp import Context
|
||||||
|
|
||||||
from mcdbus._bus import BusManager, call_bus_method
|
from mcdbus._bus import call_bus_method, get_mgr, validate_bus_name, validate_object_path
|
||||||
from mcdbus._state import mcp
|
from mcdbus._state import mcp
|
||||||
|
|
||||||
|
MAX_TREE_DEPTH = 20
|
||||||
def _get_mgr(ctx: Context) -> BusManager:
|
MAX_TREE_NODES = 500
|
||||||
return ctx.request_context.lifespan_context
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
@ -22,7 +23,7 @@ async def list_services(
|
|||||||
bus: "session" or "system"
|
bus: "session" or "system"
|
||||||
include_unique: include unique connection names like :1.42 (default False)
|
include_unique: include unique connection names like :1.42 (default False)
|
||||||
"""
|
"""
|
||||||
mgr = _get_mgr(ctx)
|
mgr = get_mgr(ctx)
|
||||||
await ctx.info(f"Connecting to {bus} bus...")
|
await ctx.info(f"Connecting to {bus} bus...")
|
||||||
connection = await mgr.get_bus(bus)
|
connection = await mgr.get_bus(bus)
|
||||||
|
|
||||||
@ -61,7 +62,10 @@ async def introspect(
|
|||||||
object_path: object path (e.g. "/org/freedesktop/Notifications")
|
object_path: object path (e.g. "/org/freedesktop/Notifications")
|
||||||
include_standard: include standard D-Bus interfaces (Peer, Introspectable, Properties)
|
include_standard: include standard D-Bus interfaces (Peer, Introspectable, Properties)
|
||||||
"""
|
"""
|
||||||
mgr = _get_mgr(ctx)
|
validate_bus_name(service)
|
||||||
|
validate_object_path(object_path)
|
||||||
|
|
||||||
|
mgr = get_mgr(ctx)
|
||||||
await ctx.report_progress(0, 4)
|
await ctx.report_progress(0, 4)
|
||||||
connection = await mgr.get_bus(bus)
|
connection = await mgr.get_bus(bus)
|
||||||
|
|
||||||
@ -102,7 +106,9 @@ async def introspect(
|
|||||||
lines.append("**Properties:**")
|
lines.append("**Properties:**")
|
||||||
for prop in iface.properties:
|
for prop in iface.properties:
|
||||||
access = getattr(prop.access, "name", str(prop.access)).lower()
|
access = getattr(prop.access, "name", str(prop.access)).lower()
|
||||||
lines.append(f"- `{prop.name}`: `{prop.signature}` ({access})")
|
lines.append(
|
||||||
|
f"- `{prop.name}`: `{prop.signature}` ({access})"
|
||||||
|
)
|
||||||
lines.append("")
|
lines.append("")
|
||||||
|
|
||||||
if iface.signals:
|
if iface.signals:
|
||||||
@ -132,40 +138,63 @@ async def list_objects(
|
|||||||
service: str,
|
service: str,
|
||||||
ctx: Context,
|
ctx: Context,
|
||||||
root_path: str = "/",
|
root_path: str = "/",
|
||||||
|
max_depth: int = MAX_TREE_DEPTH,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Recursively walk the D-Bus object tree for a service.
|
"""Recursively walk the D-Bus object tree for a service (BFS, bounded).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
bus: "session" or "system"
|
bus: "session" or "system"
|
||||||
service: service name
|
service: service name
|
||||||
root_path: starting path (default "/")
|
root_path: starting path (default "/")
|
||||||
|
max_depth: maximum tree depth to walk (default 20)
|
||||||
"""
|
"""
|
||||||
mgr = _get_mgr(ctx)
|
validate_bus_name(service)
|
||||||
|
validate_object_path(root_path)
|
||||||
|
|
||||||
|
mgr = get_mgr(ctx)
|
||||||
connection = await mgr.get_bus(bus)
|
connection = await mgr.get_bus(bus)
|
||||||
|
|
||||||
paths: list[tuple[str, list[str]]] = []
|
paths: list[tuple[str, list[str]]] = []
|
||||||
|
visited: set[str] = set()
|
||||||
|
queue: deque[tuple[str, int]] = deque([(root_path, 0)])
|
||||||
|
truncated = False
|
||||||
|
|
||||||
|
while queue:
|
||||||
|
path, depth = queue.popleft()
|
||||||
|
|
||||||
|
if path in visited:
|
||||||
|
continue
|
||||||
|
visited.add(path)
|
||||||
|
|
||||||
|
if len(paths) >= MAX_TREE_NODES:
|
||||||
|
truncated = True
|
||||||
|
break
|
||||||
|
|
||||||
async def _walk(path: str) -> None:
|
|
||||||
try:
|
try:
|
||||||
node = await connection.introspect(service, path)
|
node = await connection.introspect(service, path)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
await ctx.warning(f"Could not introspect {path}: {exc}")
|
await ctx.warning(f"Could not introspect {path}: {exc}")
|
||||||
return
|
continue
|
||||||
|
|
||||||
iface_names = [i.name for i in node.interfaces]
|
iface_names = [i.name for i in node.interfaces]
|
||||||
paths.append((path, iface_names))
|
paths.append((path, iface_names))
|
||||||
await ctx.report_progress(len(paths), len(paths))
|
await ctx.report_progress(len(paths), None)
|
||||||
|
|
||||||
|
if depth < max_depth:
|
||||||
for child in node.nodes:
|
for child in node.nodes:
|
||||||
child_path = path.rstrip("/") + "/" + child.name
|
child_path = path.rstrip("/") + "/" + child.name
|
||||||
await _walk(child_path)
|
queue.append((child_path, depth + 1))
|
||||||
|
|
||||||
await _walk(root_path)
|
header = f"## Object tree for `{service}` on {bus} bus — {len(paths)} objects"
|
||||||
|
if truncated:
|
||||||
|
header += f" (truncated at {MAX_TREE_NODES})"
|
||||||
|
lines = [header + "\n"]
|
||||||
|
|
||||||
lines = [f"## Object tree for `{service}` on {bus} bus — {len(paths)} objects\n"]
|
|
||||||
for path, ifaces in sorted(paths):
|
for path, ifaces in sorted(paths):
|
||||||
std_prefix = "org.freedesktop.DBus."
|
std_prefix = "org.freedesktop.DBus."
|
||||||
iface_str = ", ".join(f"`{i}`" for i in ifaces if not i.startswith(std_prefix))
|
iface_str = ", ".join(
|
||||||
|
f"`{i}`" for i in ifaces if not i.startswith(std_prefix)
|
||||||
|
)
|
||||||
if iface_str:
|
if iface_str:
|
||||||
lines.append(f"- `{path}` — {iface_str}")
|
lines.append(f"- `{path}` — {iface_str}")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -1,18 +1,21 @@
|
|||||||
"""Interaction tools — method calls, property get/set."""
|
"""Interaction tools — method calls, property get/set."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
from dbus_fast.signature import Variant
|
from dbus_fast.signature import Variant
|
||||||
from fastmcp import Context
|
from fastmcp import Context
|
||||||
|
|
||||||
from mcdbus._bus import BusManager, call_bus_method, deserialize_args
|
from mcdbus._bus import (
|
||||||
|
call_bus_method,
|
||||||
|
deserialize_args,
|
||||||
|
get_mgr,
|
||||||
|
validate_bus_name,
|
||||||
|
validate_object_path,
|
||||||
|
)
|
||||||
from mcdbus._state import mcp
|
from mcdbus._state import mcp
|
||||||
|
|
||||||
|
|
||||||
def _get_mgr(ctx: Context) -> BusManager:
|
|
||||||
return ctx.request_context.lifespan_context
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
async def call_method(
|
async def call_method(
|
||||||
bus: str,
|
bus: str,
|
||||||
@ -35,10 +38,26 @@ async def call_method(
|
|||||||
args: JSON array of arguments (default "[]")
|
args: JSON array of arguments (default "[]")
|
||||||
signature: D-Bus type signature for args (e.g. "su"); leave empty for no-arg methods
|
signature: D-Bus type signature for args (e.g. "su"); leave empty for no-arg methods
|
||||||
"""
|
"""
|
||||||
mgr = _get_mgr(ctx)
|
validate_bus_name(service)
|
||||||
|
validate_object_path(object_path)
|
||||||
|
validate_bus_name(interface, label="interface")
|
||||||
|
|
||||||
|
mgr = get_mgr(ctx)
|
||||||
connection = await mgr.get_bus(bus)
|
connection = await mgr.get_bus(bus)
|
||||||
|
|
||||||
|
try:
|
||||||
body = deserialize_args(args, signature)
|
body = deserialize_args(args, signature)
|
||||||
|
except (ValueError, json.JSONDecodeError) as exc:
|
||||||
|
return f"Error parsing arguments: {exc}"
|
||||||
|
|
||||||
|
# Log system bus writes to stderr for audit trail
|
||||||
|
if bus == "system":
|
||||||
|
print(
|
||||||
|
f"mcdbus: system bus call {service} {interface}.{method} "
|
||||||
|
f"path={object_path} sig={signature!r}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
await ctx.info(f"Calling {interface}.{method} on {service}...")
|
await ctx.info(f"Calling {interface}.{method} on {service}...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -51,7 +70,7 @@ async def call_method(
|
|||||||
signature=signature,
|
signature=signature,
|
||||||
body=body,
|
body=body,
|
||||||
)
|
)
|
||||||
except RuntimeError as exc:
|
except (RuntimeError, TimeoutError) as exc:
|
||||||
await ctx.error(str(exc))
|
await ctx.error(str(exc))
|
||||||
return f"Error: {exc}"
|
return f"Error: {exc}"
|
||||||
|
|
||||||
@ -83,9 +102,14 @@ async def get_property(
|
|||||||
interface: interface that owns the property
|
interface: interface that owns the property
|
||||||
property_name: property name to read
|
property_name: property name to read
|
||||||
"""
|
"""
|
||||||
mgr = _get_mgr(ctx)
|
validate_bus_name(service)
|
||||||
|
validate_object_path(object_path)
|
||||||
|
validate_bus_name(interface, label="interface")
|
||||||
|
|
||||||
|
mgr = get_mgr(ctx)
|
||||||
connection = await mgr.get_bus(bus)
|
connection = await mgr.get_bus(bus)
|
||||||
|
|
||||||
|
try:
|
||||||
result = await call_bus_method(
|
result = await call_bus_method(
|
||||||
connection,
|
connection,
|
||||||
destination=service,
|
destination=service,
|
||||||
@ -95,6 +119,8 @@ async def get_property(
|
|||||||
signature="ss",
|
signature="ss",
|
||||||
body=[interface, property_name],
|
body=[interface, property_name],
|
||||||
)
|
)
|
||||||
|
except (RuntimeError, TimeoutError) as exc:
|
||||||
|
return f"Error reading {interface}.{property_name}: {exc}"
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
value = result[0]
|
value = result[0]
|
||||||
@ -124,12 +150,28 @@ async def set_property(
|
|||||||
value: JSON-encoded value to set
|
value: JSON-encoded value to set
|
||||||
signature: D-Bus type signature of the property value (e.g. "b" for boolean)
|
signature: D-Bus type signature of the property value (e.g. "b" for boolean)
|
||||||
"""
|
"""
|
||||||
mgr = _get_mgr(ctx)
|
validate_bus_name(service)
|
||||||
|
validate_object_path(object_path)
|
||||||
|
validate_bus_name(interface, label="interface")
|
||||||
|
|
||||||
|
mgr = get_mgr(ctx)
|
||||||
connection = await mgr.get_bus(bus)
|
connection = await mgr.get_bus(bus)
|
||||||
|
|
||||||
|
try:
|
||||||
parsed_value = json.loads(value)
|
parsed_value = json.loads(value)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
return f"Error parsing value JSON: {exc}"
|
||||||
|
|
||||||
variant = Variant(signature, parsed_value)
|
variant = Variant(signature, parsed_value)
|
||||||
|
|
||||||
|
# Audit log for all set_property calls
|
||||||
|
print(
|
||||||
|
f"mcdbus: set_property {bus} {service} {interface}.{property_name} "
|
||||||
|
f"= {value} (sig={signature!r})",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
await call_bus_method(
|
await call_bus_method(
|
||||||
connection,
|
connection,
|
||||||
destination=service,
|
destination=service,
|
||||||
@ -139,6 +181,8 @@ async def set_property(
|
|||||||
signature="ssv",
|
signature="ssv",
|
||||||
body=[interface, property_name, variant],
|
body=[interface, property_name, variant],
|
||||||
)
|
)
|
||||||
|
except (RuntimeError, TimeoutError) as exc:
|
||||||
|
return f"Error setting {interface}.{property_name}: {exc}"
|
||||||
|
|
||||||
return f"Set {interface}.{property_name} = {value}"
|
return f"Set {interface}.{property_name} = {value}"
|
||||||
|
|
||||||
@ -159,9 +203,14 @@ async def get_all_properties(
|
|||||||
object_path: object path
|
object_path: object path
|
||||||
interface: interface to read properties from
|
interface: interface to read properties from
|
||||||
"""
|
"""
|
||||||
mgr = _get_mgr(ctx)
|
validate_bus_name(service)
|
||||||
|
validate_object_path(object_path)
|
||||||
|
validate_bus_name(interface, label="interface")
|
||||||
|
|
||||||
|
mgr = get_mgr(ctx)
|
||||||
connection = await mgr.get_bus(bus)
|
connection = await mgr.get_bus(bus)
|
||||||
|
|
||||||
|
try:
|
||||||
result = await call_bus_method(
|
result = await call_bus_method(
|
||||||
connection,
|
connection,
|
||||||
destination=service,
|
destination=service,
|
||||||
@ -171,6 +220,8 @@ async def get_all_properties(
|
|||||||
signature="s",
|
signature="s",
|
||||||
body=[interface],
|
body=[interface],
|
||||||
)
|
)
|
||||||
|
except (RuntimeError, TimeoutError) as exc:
|
||||||
|
return f"Error reading properties on {interface}: {exc}"
|
||||||
|
|
||||||
if not result or not result[0]:
|
if not result or not result[0]:
|
||||||
return "No properties found."
|
return "No properties found."
|
||||||
|
|||||||
@ -1,15 +1,17 @@
|
|||||||
"""Dynamic MCP resources for browsing D-Bus services."""
|
"""Dynamic MCP resources for browsing D-Bus services."""
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
from mcdbus._bus import BusManager, call_bus_method
|
from mcdbus._bus import BusManager, call_bus_method
|
||||||
from mcdbus._state import mcp
|
from mcdbus._state import mcp
|
||||||
|
|
||||||
|
MAX_RESOURCE_NODES = 200
|
||||||
|
|
||||||
|
|
||||||
@mcp.resource("dbus://{bus}/services")
|
@mcp.resource("dbus://{bus}/services")
|
||||||
async def bus_services(bus: str) -> str:
|
async def bus_services(bus: str) -> str:
|
||||||
"""Live list of well-known service names on a D-Bus bus."""
|
"""Live list of well-known service names on a D-Bus bus."""
|
||||||
# Resources don't get Context — we need a fresh connection
|
|
||||||
mgr = BusManager()
|
mgr = BusManager()
|
||||||
try:
|
try:
|
||||||
connection = await mgr.get_bus(bus)
|
connection = await mgr.get_bus(bus)
|
||||||
@ -28,23 +30,31 @@ async def bus_services(bus: str) -> str:
|
|||||||
|
|
||||||
@mcp.resource("dbus://{bus}/{service}/objects")
|
@mcp.resource("dbus://{bus}/{service}/objects")
|
||||||
async def service_objects(bus: str, service: str) -> str:
|
async def service_objects(bus: str, service: str) -> str:
|
||||||
"""Object tree for a D-Bus service."""
|
"""Object tree for a D-Bus service (bounded BFS walk)."""
|
||||||
mgr = BusManager()
|
mgr = BusManager()
|
||||||
try:
|
try:
|
||||||
connection = await mgr.get_bus(bus)
|
connection = await mgr.get_bus(bus)
|
||||||
paths: list[str] = []
|
paths: list[str] = []
|
||||||
|
visited: set[str] = set()
|
||||||
|
queue: deque[tuple[str, int]] = deque([("/", 0)])
|
||||||
|
|
||||||
|
while queue and len(paths) < MAX_RESOURCE_NODES:
|
||||||
|
path, depth = queue.popleft()
|
||||||
|
if path in visited:
|
||||||
|
continue
|
||||||
|
visited.add(path)
|
||||||
|
|
||||||
async def _walk(path: str) -> None:
|
|
||||||
try:
|
try:
|
||||||
node = await connection.introspect(service, path)
|
node = await connection.introspect(service, path)
|
||||||
|
except (RuntimeError, OSError):
|
||||||
|
continue
|
||||||
|
|
||||||
paths.append(path)
|
paths.append(path)
|
||||||
|
if depth < 15:
|
||||||
for child in node.nodes:
|
for child in node.nodes:
|
||||||
child_path = path.rstrip("/") + "/" + child.name
|
child_path = path.rstrip("/") + "/" + child.name
|
||||||
await _walk(child_path)
|
queue.append((child_path, depth + 1))
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
await _walk("/")
|
|
||||||
return "\n".join(sorted(paths))
|
return "\n".join(sorted(paths))
|
||||||
finally:
|
finally:
|
||||||
await mgr.disconnect_all()
|
await mgr.disconnect_all()
|
||||||
@ -53,7 +63,6 @@ async def service_objects(bus: str, service: str) -> str:
|
|||||||
@mcp.resource("dbus://{bus}/{service}/{path}/interfaces")
|
@mcp.resource("dbus://{bus}/{service}/{path}/interfaces")
|
||||||
async def object_interfaces(bus: str, service: str, path: str) -> str:
|
async def object_interfaces(bus: str, service: str, path: str) -> str:
|
||||||
"""Interfaces available at a D-Bus object path."""
|
"""Interfaces available at a D-Bus object path."""
|
||||||
# path comes URL-encoded (slashes → %2F), decode it
|
|
||||||
decoded_path = unquote(path)
|
decoded_path = unquote(path)
|
||||||
if not decoded_path.startswith("/"):
|
if not decoded_path.startswith("/"):
|
||||||
decoded_path = "/" + decoded_path
|
decoded_path = "/" + decoded_path
|
||||||
|
|||||||
@ -1,15 +1,13 @@
|
|||||||
"""High-level convenience tools for common D-Bus operations."""
|
"""High-level convenience tools for common D-Bus operations."""
|
||||||
|
|
||||||
|
import fnmatch
|
||||||
|
|
||||||
from fastmcp import Context
|
from fastmcp import Context
|
||||||
|
|
||||||
from mcdbus._bus import BusManager, call_bus_method
|
from mcdbus._bus import call_bus_method, get_mgr
|
||||||
from mcdbus._state import mcp
|
from mcdbus._state import mcp
|
||||||
|
|
||||||
|
|
||||||
def _get_mgr(ctx: Context) -> BusManager:
|
|
||||||
return ctx.request_context.lifespan_context
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
async def send_notification(
|
async def send_notification(
|
||||||
summary: str,
|
summary: str,
|
||||||
@ -26,9 +24,10 @@ async def send_notification(
|
|||||||
icon: icon name or path (empty for default)
|
icon: icon name or path (empty for default)
|
||||||
timeout: display duration in milliseconds (default 5000)
|
timeout: display duration in milliseconds (default 5000)
|
||||||
"""
|
"""
|
||||||
mgr = _get_mgr(ctx)
|
mgr = get_mgr(ctx)
|
||||||
connection = await mgr.get_bus("session")
|
connection = await mgr.get_bus("session")
|
||||||
|
|
||||||
|
try:
|
||||||
result = await call_bus_method(
|
result = await call_bus_method(
|
||||||
connection,
|
connection,
|
||||||
destination="org.freedesktop.Notifications",
|
destination="org.freedesktop.Notifications",
|
||||||
@ -47,6 +46,8 @@ async def send_notification(
|
|||||||
timeout, # expire_timeout
|
timeout, # expire_timeout
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
except (RuntimeError, TimeoutError) as exc:
|
||||||
|
return f"Error sending notification: {exc}"
|
||||||
|
|
||||||
nid = result[0] if result else "unknown"
|
nid = result[0] if result else "unknown"
|
||||||
return f"Notification sent (id: {nid})"
|
return f"Notification sent (id: {nid})"
|
||||||
@ -64,9 +65,10 @@ async def list_systemd_units(
|
|||||||
bus: "session" (user units) or "system" (system units, default)
|
bus: "session" (user units) or "system" (system units, default)
|
||||||
pattern: optional glob filter (e.g. "docker*", "*.service")
|
pattern: optional glob filter (e.g. "docker*", "*.service")
|
||||||
"""
|
"""
|
||||||
mgr = _get_mgr(ctx)
|
mgr = get_mgr(ctx)
|
||||||
connection = await mgr.get_bus(bus)
|
connection = await mgr.get_bus(bus)
|
||||||
|
|
||||||
|
try:
|
||||||
result = await call_bus_method(
|
result = await call_bus_method(
|
||||||
connection,
|
connection,
|
||||||
destination="org.freedesktop.systemd1",
|
destination="org.freedesktop.systemd1",
|
||||||
@ -74,6 +76,8 @@ async def list_systemd_units(
|
|||||||
interface="org.freedesktop.systemd1.Manager",
|
interface="org.freedesktop.systemd1.Manager",
|
||||||
member="ListUnits",
|
member="ListUnits",
|
||||||
)
|
)
|
||||||
|
except (RuntimeError, TimeoutError) as exc:
|
||||||
|
return f"Error listing units: {exc}"
|
||||||
|
|
||||||
if not result or not result[0]:
|
if not result or not result[0]:
|
||||||
return "No units found."
|
return "No units found."
|
||||||
@ -81,7 +85,6 @@ async def list_systemd_units(
|
|||||||
units = result[0]
|
units = result[0]
|
||||||
|
|
||||||
# Each unit is (name, description, load_state, active_state, sub_state, ...)
|
# Each unit is (name, description, load_state, active_state, sub_state, ...)
|
||||||
import fnmatch
|
|
||||||
if pattern:
|
if pattern:
|
||||||
units = [u for u in units if fnmatch.fnmatch(u[0], pattern)]
|
units = [u for u in units if fnmatch.fnmatch(u[0], pattern)]
|
||||||
|
|
||||||
@ -107,11 +110,10 @@ async def media_player_control(
|
|||||||
action: one of "play", "pause", "next", "previous", "stop", "play-pause"
|
action: one of "play", "pause", "next", "previous", "stop", "play-pause"
|
||||||
player: MPRIS player name (auto-discovers first player if empty)
|
player: MPRIS player name (auto-discovers first player if empty)
|
||||||
"""
|
"""
|
||||||
mgr = _get_mgr(ctx)
|
mgr = get_mgr(ctx)
|
||||||
connection = await mgr.get_bus("session")
|
connection = await mgr.get_bus("session")
|
||||||
|
|
||||||
if not player:
|
if not player:
|
||||||
# Discover MPRIS players
|
|
||||||
result = await call_bus_method(
|
result = await call_bus_method(
|
||||||
connection,
|
connection,
|
||||||
destination="org.freedesktop.DBus",
|
destination="org.freedesktop.DBus",
|
||||||
@ -120,12 +122,18 @@ async def media_player_control(
|
|||||||
member="ListNames",
|
member="ListNames",
|
||||||
)
|
)
|
||||||
names = result[0] if result else []
|
names = result[0] if result else []
|
||||||
mpris_services = [n for n in names if n.startswith("org.mpris.MediaPlayer2.")]
|
mpris_services = sorted(
|
||||||
|
n for n in names if n.startswith("org.mpris.MediaPlayer2.")
|
||||||
|
)
|
||||||
|
|
||||||
if not mpris_services:
|
if not mpris_services:
|
||||||
return "No MPRIS media players found on session bus."
|
return "No MPRIS media players found on session bus."
|
||||||
player = mpris_services[0]
|
player = mpris_services[0]
|
||||||
await ctx.info(f"Auto-discovered player: {player}")
|
others = (
|
||||||
|
f"\nOther players: {', '.join(mpris_services[1:])}"
|
||||||
|
if len(mpris_services) > 1 else ""
|
||||||
|
)
|
||||||
|
await ctx.info(f"Auto-discovered player: {player}{others}")
|
||||||
|
|
||||||
action_map = {
|
action_map = {
|
||||||
"play": "Play",
|
"play": "Play",
|
||||||
@ -140,6 +148,7 @@ async def media_player_control(
|
|||||||
if not dbus_method:
|
if not dbus_method:
|
||||||
return f"Unknown action: {action}. Use: {', '.join(action_map.keys())}"
|
return f"Unknown action: {action}. Use: {', '.join(action_map.keys())}"
|
||||||
|
|
||||||
|
try:
|
||||||
await call_bus_method(
|
await call_bus_method(
|
||||||
connection,
|
connection,
|
||||||
destination=player,
|
destination=player,
|
||||||
@ -147,6 +156,8 @@ async def media_player_control(
|
|||||||
interface="org.mpris.MediaPlayer2.Player",
|
interface="org.mpris.MediaPlayer2.Player",
|
||||||
member=dbus_method,
|
member=dbus_method,
|
||||||
)
|
)
|
||||||
|
except (RuntimeError, TimeoutError) as exc:
|
||||||
|
return f"Error controlling player: {exc}"
|
||||||
|
|
||||||
# Read current playback status
|
# Read current playback status
|
||||||
try:
|
try:
|
||||||
@ -160,7 +171,7 @@ async def media_player_control(
|
|||||||
body=["org.mpris.MediaPlayer2.Player", "PlaybackStatus"],
|
body=["org.mpris.MediaPlayer2.Player", "PlaybackStatus"],
|
||||||
)
|
)
|
||||||
status = status_result[0] if status_result else "unknown"
|
status = status_result[0] if status_result else "unknown"
|
||||||
except Exception:
|
except (RuntimeError, TimeoutError):
|
||||||
status = "unknown"
|
status = "unknown"
|
||||||
|
|
||||||
# Try to get current track metadata
|
# Try to get current track metadata
|
||||||
@ -179,7 +190,7 @@ async def media_player_control(
|
|||||||
artist = metadata.get("xesam:artist", ["Unknown"])
|
artist = metadata.get("xesam:artist", ["Unknown"])
|
||||||
if isinstance(artist, list):
|
if isinstance(artist, list):
|
||||||
artist = ", ".join(artist)
|
artist = ", ".join(artist)
|
||||||
except Exception:
|
except (RuntimeError, TimeoutError):
|
||||||
title, artist = "Unknown", "Unknown"
|
title, artist = "Unknown", "Unknown"
|
||||||
|
|
||||||
return f"Player: {player}\nAction: {action}\nStatus: {status}\nNow: {artist} — {title}"
|
return f"Player: {player}\nAction: {action}\nStatus: {status}\nNow: {artist} — {title}"
|
||||||
|
|||||||
@ -2,7 +2,13 @@
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from mcdbus._bus import BusManager, deserialize_args, serialize_variant
|
from mcdbus._bus import (
|
||||||
|
BusManager,
|
||||||
|
deserialize_args,
|
||||||
|
serialize_variant,
|
||||||
|
validate_bus_name,
|
||||||
|
validate_object_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestBusManager:
|
class TestBusManager:
|
||||||
@ -71,3 +77,37 @@ class TestDeserializeArgs:
|
|||||||
def test_single_value_wraps_in_list(self):
|
def test_single_value_wraps_in_list(self):
|
||||||
result = deserialize_args('"hello"', "s")
|
result = deserialize_args('"hello"', "s")
|
||||||
assert result == ["hello"]
|
assert result == ["hello"]
|
||||||
|
|
||||||
|
def test_arg_count_mismatch_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="expects 1 argument"):
|
||||||
|
deserialize_args('["hello", "extra"]', "s")
|
||||||
|
|
||||||
|
def test_too_few_args_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="expects 2 argument"):
|
||||||
|
deserialize_args('["hello"]', "su")
|
||||||
|
|
||||||
|
def test_invalid_json_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Invalid JSON"):
|
||||||
|
deserialize_args("{broken", "s")
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidation:
|
||||||
|
def test_valid_bus_name(self):
|
||||||
|
validate_bus_name("org.freedesktop.DBus")
|
||||||
|
validate_bus_name("org.kde.KWin")
|
||||||
|
|
||||||
|
def test_invalid_bus_name(self):
|
||||||
|
with pytest.raises(ValueError, match="Invalid D-Bus"):
|
||||||
|
validate_bus_name("not-a-bus-name")
|
||||||
|
with pytest.raises(ValueError, match="Invalid D-Bus"):
|
||||||
|
validate_bus_name("")
|
||||||
|
|
||||||
|
def test_valid_object_path(self):
|
||||||
|
validate_object_path("/")
|
||||||
|
validate_object_path("/org/freedesktop/DBus")
|
||||||
|
|
||||||
|
def test_invalid_object_path(self):
|
||||||
|
with pytest.raises(ValueError, match="Invalid D-Bus object path"):
|
||||||
|
validate_object_path("not/a/path")
|
||||||
|
with pytest.raises(ValueError, match="Invalid D-Bus object path"):
|
||||||
|
validate_object_path("")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user