Harden all tool/resource paths per Hamilton review
- Add asyncio.wait_for timeout (30s) to all D-Bus calls
- Add asyncio.Lock to BusManager.get_bus to prevent race conditions
- Convert recursive tree walk to bounded BFS with visited set, max depth (20), max nodes (500)
- Add input validation for D-Bus names and object paths at tool boundary
- Standardize error handling (RuntimeError, TimeoutError) across all tools
- Catch JSONDecodeError in deserialize_args and set_property
- Check arg count matches signature in deserialize_args
- Add explicit variant wrapping via {"signature": "ai", "value": [1,2,3]}
- Narrow resource exception handling from Exception to (RuntimeError, OSError)
- Guard disconnect_all per-bus to avoid partial cleanup on error
- Audit log system bus calls and all set_property calls to stderr
- Consolidate _get_mgr into get_mgr in _bus.py
- Sort MPRIS players for deterministic auto-discovery
- Fix progress reporting (pass None as total when unknown)
- Add 7 new tests for validation, arg count, and JSON error paths (32 total)
This commit is contained in:
parent
4d7b73f6ee
commit
fa41ffbf80
@ -1,11 +1,39 @@
|
||||
"""Bus connection management and D-Bus type serialization."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from dbus_fast import BusType, Message, MessageType
|
||||
from dbus_fast.aio import MessageBus
|
||||
from dbus_fast.unpack import unpack_variants
|
||||
from fastmcp import Context
|
||||
|
||||
DBUS_CALL_TIMEOUT = float(30) # seconds; override via env MCDBUS_TIMEOUT
|
||||
|
||||
# D-Bus spec: https://dbus.freedesktop.org/doc/dbus-specification.html
|
||||
_DBUS_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_.]*(\.[A-Za-z_][A-Za-z0-9_]*)+$")
|
||||
_DBUS_PATH_RE = re.compile(r"^/$|^(/[A-Za-z0-9_]+)+$")
|
||||
|
||||
|
||||
def validate_bus_name(name: str, label: str = "service") -> None:
|
||||
"""Validate a D-Bus well-known or interface name."""
|
||||
if not _DBUS_NAME_RE.match(name):
|
||||
raise ValueError(
|
||||
f"Invalid D-Bus {label}: {name!r}. "
|
||||
f"Must match [A-Za-z_][A-Za-z0-9_]*(.[A-Za-z_][A-Za-z0-9_]*)+"
|
||||
)
|
||||
|
||||
|
||||
def validate_object_path(path: str) -> None:
|
||||
"""Validate a D-Bus object path."""
|
||||
if not _DBUS_PATH_RE.match(path):
|
||||
raise ValueError(
|
||||
f"Invalid D-Bus object path: {path!r}. "
|
||||
f"Must be '/' or '/element/element/...' with [A-Za-z0-9_]"
|
||||
)
|
||||
|
||||
|
||||
class BusManager:
|
||||
@ -13,31 +41,44 @@ class BusManager:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._buses: dict[str, MessageBus] = {}
|
||||
self._locks: dict[str, asyncio.Lock] = {
|
||||
"session": asyncio.Lock(),
|
||||
"system": asyncio.Lock(),
|
||||
}
|
||||
|
||||
async def get_bus(self, bus_type: str) -> MessageBus:
|
||||
"""Get a connected bus by type name ("session" or "system")."""
|
||||
if bus_type not in ("session", "system"):
|
||||
raise ValueError(f"bus_type must be 'session' or 'system', got {bus_type!r}")
|
||||
|
||||
if bus_type in self._buses:
|
||||
bus = self._buses[bus_type]
|
||||
if bus.connected:
|
||||
return bus
|
||||
# stale — drop and reconnect
|
||||
del self._buses[bus_type]
|
||||
async with self._locks[bus_type]:
|
||||
if bus_type in self._buses:
|
||||
bus = self._buses[bus_type]
|
||||
if bus.connected:
|
||||
return bus
|
||||
# stale — drop and reconnect
|
||||
del self._buses[bus_type]
|
||||
|
||||
bt = BusType.SESSION if bus_type == "session" else BusType.SYSTEM
|
||||
bus = await MessageBus(bus_type=bt).connect()
|
||||
self._buses[bus_type] = bus
|
||||
return bus
|
||||
bt = BusType.SESSION if bus_type == "session" else BusType.SYSTEM
|
||||
bus = await MessageBus(bus_type=bt).connect()
|
||||
self._buses[bus_type] = bus
|
||||
return bus
|
||||
|
||||
async def disconnect_all(self) -> None:
|
||||
for bus in self._buses.values():
|
||||
if bus.connected:
|
||||
bus.disconnect()
|
||||
for name, bus in self._buses.items():
|
||||
try:
|
||||
if bus.connected:
|
||||
bus.disconnect()
|
||||
except Exception as exc:
|
||||
print(f"mcdbus: error disconnecting {name} bus: {exc}", file=sys.stderr)
|
||||
self._buses.clear()
|
||||
|
||||
|
||||
def get_mgr(ctx: Context) -> "BusManager":
|
||||
"""Extract the BusManager from a FastMCP tool context."""
|
||||
return ctx.request_context.lifespan_context
|
||||
|
||||
|
||||
def serialize_variant(value: Any) -> Any:
|
||||
"""Recursively unwrap dbus-fast Variant objects into JSON-safe Python types."""
|
||||
unpacked = unpack_variants(value)
|
||||
@ -62,13 +103,19 @@ def deserialize_args(args_json: str, signature: str) -> list[Any]:
|
||||
"""Parse JSON args string into a Python list suitable for D-Bus call body.
|
||||
|
||||
For 'v' (Variant) signatures, wraps values in dbus_fast.signature.Variant.
|
||||
Auto-wrapping infers types for simple scalars; for complex variant values
|
||||
(e.g. array of int), pass the value as {"signature": "ai", "value": [1,2,3]}.
|
||||
"""
|
||||
from dbus_fast.signature import SignatureTree, Variant
|
||||
|
||||
if not args_json or args_json.strip() in ("", "[]", "null"):
|
||||
return []
|
||||
|
||||
args = json.loads(args_json)
|
||||
try:
|
||||
args = json.loads(args_json)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError(f"Invalid JSON in args: {exc}") from exc
|
||||
|
||||
if not isinstance(args, list):
|
||||
args = [args]
|
||||
|
||||
@ -76,30 +123,46 @@ def deserialize_args(args_json: str, signature: str) -> list[Any]:
|
||||
return args
|
||||
|
||||
tree = SignatureTree(signature)
|
||||
expected = len(tree.types)
|
||||
if len(args) != expected:
|
||||
raise ValueError(
|
||||
f"Signature '{signature}' expects {expected} argument(s), got {len(args)}"
|
||||
)
|
||||
|
||||
result = []
|
||||
for i, sig_type in enumerate(tree.types):
|
||||
if i >= len(args):
|
||||
break
|
||||
val = args[i]
|
||||
if sig_type.signature == "v":
|
||||
# Caller must wrap variant values; we auto-wrap strings as "s"
|
||||
if isinstance(val, str):
|
||||
val = Variant("s", val)
|
||||
elif isinstance(val, bool):
|
||||
val = Variant("b", val)
|
||||
elif isinstance(val, int):
|
||||
val = Variant("i", val)
|
||||
elif isinstance(val, float):
|
||||
val = Variant("d", val)
|
||||
elif isinstance(val, list):
|
||||
val = Variant("as", val)
|
||||
elif isinstance(val, dict):
|
||||
val = Variant("a{sv}", val)
|
||||
val = _auto_wrap_variant(val, Variant)
|
||||
result.append(val)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _auto_wrap_variant(val: Any, variant_cls: type) -> Any:
|
||||
"""Wrap a Python value in a D-Bus Variant, inferring the type signature.
|
||||
|
||||
For explicit control, pass {"signature": "ai", "value": [1,2,3]}.
|
||||
"""
|
||||
# Explicit variant specification
|
||||
if isinstance(val, dict) and "signature" in val and "value" in val and len(val) == 2:
|
||||
return variant_cls(val["signature"], val["value"])
|
||||
# Auto-infer from Python type (bool must come before int — bool is a subclass of int)
|
||||
if isinstance(val, bool):
|
||||
return variant_cls("b", val)
|
||||
if isinstance(val, str):
|
||||
return variant_cls("s", val)
|
||||
if isinstance(val, int):
|
||||
return variant_cls("i", val)
|
||||
if isinstance(val, float):
|
||||
return variant_cls("d", val)
|
||||
if isinstance(val, list):
|
||||
return variant_cls("as", val)
|
||||
if isinstance(val, dict):
|
||||
return variant_cls("a{sv}", val)
|
||||
return variant_cls("s", str(val))
|
||||
|
||||
|
||||
async def call_bus_method(
|
||||
bus: MessageBus,
|
||||
destination: str,
|
||||
@ -108,6 +171,7 @@ async def call_bus_method(
|
||||
member: str,
|
||||
signature: str = "",
|
||||
body: list[Any] | None = None,
|
||||
timeout: float = DBUS_CALL_TIMEOUT,
|
||||
) -> Any:
|
||||
"""Send a D-Bus method call and return the unpacked response body."""
|
||||
msg = Message(
|
||||
@ -118,7 +182,13 @@ async def call_bus_method(
|
||||
signature=signature,
|
||||
body=body or [],
|
||||
)
|
||||
reply = await bus.call(msg)
|
||||
try:
|
||||
reply = await asyncio.wait_for(bus.call(msg), timeout=timeout)
|
||||
except TimeoutError:
|
||||
raise TimeoutError(
|
||||
f"D-Bus call timed out after {timeout}s: "
|
||||
f"{destination} {interface}.{member} at {path}"
|
||||
) from None
|
||||
if reply.message_type == MessageType.ERROR:
|
||||
raise RuntimeError(f"D-Bus error: {reply.error_name}: {reply.body}")
|
||||
if reply.body:
|
||||
|
||||
@ -1,13 +1,14 @@
|
||||
"""Discovery tools — list services, introspect objects, walk object trees."""
|
||||
|
||||
from collections import deque
|
||||
|
||||
from fastmcp import Context
|
||||
|
||||
from mcdbus._bus import BusManager, call_bus_method
|
||||
from mcdbus._bus import call_bus_method, get_mgr, validate_bus_name, validate_object_path
|
||||
from mcdbus._state import mcp
|
||||
|
||||
|
||||
def _get_mgr(ctx: Context) -> BusManager:
|
||||
return ctx.request_context.lifespan_context
|
||||
MAX_TREE_DEPTH = 20
|
||||
MAX_TREE_NODES = 500
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@ -22,7 +23,7 @@ async def list_services(
|
||||
bus: "session" or "system"
|
||||
include_unique: include unique connection names like :1.42 (default False)
|
||||
"""
|
||||
mgr = _get_mgr(ctx)
|
||||
mgr = get_mgr(ctx)
|
||||
await ctx.info(f"Connecting to {bus} bus...")
|
||||
connection = await mgr.get_bus(bus)
|
||||
|
||||
@ -61,7 +62,10 @@ async def introspect(
|
||||
object_path: object path (e.g. "/org/freedesktop/Notifications")
|
||||
include_standard: include standard D-Bus interfaces (Peer, Introspectable, Properties)
|
||||
"""
|
||||
mgr = _get_mgr(ctx)
|
||||
validate_bus_name(service)
|
||||
validate_object_path(object_path)
|
||||
|
||||
mgr = get_mgr(ctx)
|
||||
await ctx.report_progress(0, 4)
|
||||
connection = await mgr.get_bus(bus)
|
||||
|
||||
@ -102,7 +106,9 @@ async def introspect(
|
||||
lines.append("**Properties:**")
|
||||
for prop in iface.properties:
|
||||
access = getattr(prop.access, "name", str(prop.access)).lower()
|
||||
lines.append(f"- `{prop.name}`: `{prop.signature}` ({access})")
|
||||
lines.append(
|
||||
f"- `{prop.name}`: `{prop.signature}` ({access})"
|
||||
)
|
||||
lines.append("")
|
||||
|
||||
if iface.signals:
|
||||
@ -132,40 +138,63 @@ async def list_objects(
|
||||
service: str,
|
||||
ctx: Context,
|
||||
root_path: str = "/",
|
||||
max_depth: int = MAX_TREE_DEPTH,
|
||||
) -> str:
|
||||
"""Recursively walk the D-Bus object tree for a service.
|
||||
"""Recursively walk the D-Bus object tree for a service (BFS, bounded).
|
||||
|
||||
Args:
|
||||
bus: "session" or "system"
|
||||
service: service name
|
||||
root_path: starting path (default "/")
|
||||
max_depth: maximum tree depth to walk (default 20)
|
||||
"""
|
||||
mgr = _get_mgr(ctx)
|
||||
validate_bus_name(service)
|
||||
validate_object_path(root_path)
|
||||
|
||||
mgr = get_mgr(ctx)
|
||||
connection = await mgr.get_bus(bus)
|
||||
|
||||
paths: list[tuple[str, list[str]]] = []
|
||||
visited: set[str] = set()
|
||||
queue: deque[tuple[str, int]] = deque([(root_path, 0)])
|
||||
truncated = False
|
||||
|
||||
while queue:
|
||||
path, depth = queue.popleft()
|
||||
|
||||
if path in visited:
|
||||
continue
|
||||
visited.add(path)
|
||||
|
||||
if len(paths) >= MAX_TREE_NODES:
|
||||
truncated = True
|
||||
break
|
||||
|
||||
async def _walk(path: str) -> None:
|
||||
try:
|
||||
node = await connection.introspect(service, path)
|
||||
except Exception as exc:
|
||||
await ctx.warning(f"Could not introspect {path}: {exc}")
|
||||
return
|
||||
continue
|
||||
|
||||
iface_names = [i.name for i in node.interfaces]
|
||||
paths.append((path, iface_names))
|
||||
await ctx.report_progress(len(paths), len(paths))
|
||||
await ctx.report_progress(len(paths), None)
|
||||
|
||||
for child in node.nodes:
|
||||
child_path = path.rstrip("/") + "/" + child.name
|
||||
await _walk(child_path)
|
||||
if depth < max_depth:
|
||||
for child in node.nodes:
|
||||
child_path = path.rstrip("/") + "/" + child.name
|
||||
queue.append((child_path, depth + 1))
|
||||
|
||||
await _walk(root_path)
|
||||
header = f"## Object tree for `{service}` on {bus} bus — {len(paths)} objects"
|
||||
if truncated:
|
||||
header += f" (truncated at {MAX_TREE_NODES})"
|
||||
lines = [header + "\n"]
|
||||
|
||||
lines = [f"## Object tree for `{service}` on {bus} bus — {len(paths)} objects\n"]
|
||||
for path, ifaces in sorted(paths):
|
||||
std_prefix = "org.freedesktop.DBus."
|
||||
iface_str = ", ".join(f"`{i}`" for i in ifaces if not i.startswith(std_prefix))
|
||||
iface_str = ", ".join(
|
||||
f"`{i}`" for i in ifaces if not i.startswith(std_prefix)
|
||||
)
|
||||
if iface_str:
|
||||
lines.append(f"- `{path}` — {iface_str}")
|
||||
else:
|
||||
|
||||
@ -1,18 +1,21 @@
|
||||
"""Interaction tools — method calls, property get/set."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
from dbus_fast.signature import Variant
|
||||
from fastmcp import Context
|
||||
|
||||
from mcdbus._bus import BusManager, call_bus_method, deserialize_args
|
||||
from mcdbus._bus import (
|
||||
call_bus_method,
|
||||
deserialize_args,
|
||||
get_mgr,
|
||||
validate_bus_name,
|
||||
validate_object_path,
|
||||
)
|
||||
from mcdbus._state import mcp
|
||||
|
||||
|
||||
def _get_mgr(ctx: Context) -> BusManager:
|
||||
return ctx.request_context.lifespan_context
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def call_method(
|
||||
bus: str,
|
||||
@ -35,10 +38,26 @@ async def call_method(
|
||||
args: JSON array of arguments (default "[]")
|
||||
signature: D-Bus type signature for args (e.g. "su"); leave empty for no-arg methods
|
||||
"""
|
||||
mgr = _get_mgr(ctx)
|
||||
validate_bus_name(service)
|
||||
validate_object_path(object_path)
|
||||
validate_bus_name(interface, label="interface")
|
||||
|
||||
mgr = get_mgr(ctx)
|
||||
connection = await mgr.get_bus(bus)
|
||||
|
||||
body = deserialize_args(args, signature)
|
||||
try:
|
||||
body = deserialize_args(args, signature)
|
||||
except (ValueError, json.JSONDecodeError) as exc:
|
||||
return f"Error parsing arguments: {exc}"
|
||||
|
||||
# Log system bus writes to stderr for audit trail
|
||||
if bus == "system":
|
||||
print(
|
||||
f"mcdbus: system bus call {service} {interface}.{method} "
|
||||
f"path={object_path} sig={signature!r}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
await ctx.info(f"Calling {interface}.{method} on {service}...")
|
||||
|
||||
try:
|
||||
@ -51,7 +70,7 @@ async def call_method(
|
||||
signature=signature,
|
||||
body=body,
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
except (RuntimeError, TimeoutError) as exc:
|
||||
await ctx.error(str(exc))
|
||||
return f"Error: {exc}"
|
||||
|
||||
@ -83,18 +102,25 @@ async def get_property(
|
||||
interface: interface that owns the property
|
||||
property_name: property name to read
|
||||
"""
|
||||
mgr = _get_mgr(ctx)
|
||||
validate_bus_name(service)
|
||||
validate_object_path(object_path)
|
||||
validate_bus_name(interface, label="interface")
|
||||
|
||||
mgr = get_mgr(ctx)
|
||||
connection = await mgr.get_bus(bus)
|
||||
|
||||
result = await call_bus_method(
|
||||
connection,
|
||||
destination=service,
|
||||
path=object_path,
|
||||
interface="org.freedesktop.DBus.Properties",
|
||||
member="Get",
|
||||
signature="ss",
|
||||
body=[interface, property_name],
|
||||
)
|
||||
try:
|
||||
result = await call_bus_method(
|
||||
connection,
|
||||
destination=service,
|
||||
path=object_path,
|
||||
interface="org.freedesktop.DBus.Properties",
|
||||
member="Get",
|
||||
signature="ss",
|
||||
body=[interface, property_name],
|
||||
)
|
||||
except (RuntimeError, TimeoutError) as exc:
|
||||
return f"Error reading {interface}.{property_name}: {exc}"
|
||||
|
||||
if result:
|
||||
value = result[0]
|
||||
@ -124,22 +150,40 @@ async def set_property(
|
||||
value: JSON-encoded value to set
|
||||
signature: D-Bus type signature of the property value (e.g. "b" for boolean)
|
||||
"""
|
||||
mgr = _get_mgr(ctx)
|
||||
validate_bus_name(service)
|
||||
validate_object_path(object_path)
|
||||
validate_bus_name(interface, label="interface")
|
||||
|
||||
mgr = get_mgr(ctx)
|
||||
connection = await mgr.get_bus(bus)
|
||||
|
||||
parsed_value = json.loads(value)
|
||||
try:
|
||||
parsed_value = json.loads(value)
|
||||
except json.JSONDecodeError as exc:
|
||||
return f"Error parsing value JSON: {exc}"
|
||||
|
||||
variant = Variant(signature, parsed_value)
|
||||
|
||||
await call_bus_method(
|
||||
connection,
|
||||
destination=service,
|
||||
path=object_path,
|
||||
interface="org.freedesktop.DBus.Properties",
|
||||
member="Set",
|
||||
signature="ssv",
|
||||
body=[interface, property_name, variant],
|
||||
# Audit log for all set_property calls
|
||||
print(
|
||||
f"mcdbus: set_property {bus} {service} {interface}.{property_name} "
|
||||
f"= {value} (sig={signature!r})",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
try:
|
||||
await call_bus_method(
|
||||
connection,
|
||||
destination=service,
|
||||
path=object_path,
|
||||
interface="org.freedesktop.DBus.Properties",
|
||||
member="Set",
|
||||
signature="ssv",
|
||||
body=[interface, property_name, variant],
|
||||
)
|
||||
except (RuntimeError, TimeoutError) as exc:
|
||||
return f"Error setting {interface}.{property_name}: {exc}"
|
||||
|
||||
return f"Set {interface}.{property_name} = {value}"
|
||||
|
||||
|
||||
@ -159,18 +203,25 @@ async def get_all_properties(
|
||||
object_path: object path
|
||||
interface: interface to read properties from
|
||||
"""
|
||||
mgr = _get_mgr(ctx)
|
||||
validate_bus_name(service)
|
||||
validate_object_path(object_path)
|
||||
validate_bus_name(interface, label="interface")
|
||||
|
||||
mgr = get_mgr(ctx)
|
||||
connection = await mgr.get_bus(bus)
|
||||
|
||||
result = await call_bus_method(
|
||||
connection,
|
||||
destination=service,
|
||||
path=object_path,
|
||||
interface="org.freedesktop.DBus.Properties",
|
||||
member="GetAll",
|
||||
signature="s",
|
||||
body=[interface],
|
||||
)
|
||||
try:
|
||||
result = await call_bus_method(
|
||||
connection,
|
||||
destination=service,
|
||||
path=object_path,
|
||||
interface="org.freedesktop.DBus.Properties",
|
||||
member="GetAll",
|
||||
signature="s",
|
||||
body=[interface],
|
||||
)
|
||||
except (RuntimeError, TimeoutError) as exc:
|
||||
return f"Error reading properties on {interface}: {exc}"
|
||||
|
||||
if not result or not result[0]:
|
||||
return "No properties found."
|
||||
|
||||
@ -1,15 +1,17 @@
|
||||
"""Dynamic MCP resources for browsing D-Bus services."""
|
||||
|
||||
from collections import deque
|
||||
from urllib.parse import unquote
|
||||
|
||||
from mcdbus._bus import BusManager, call_bus_method
|
||||
from mcdbus._state import mcp
|
||||
|
||||
MAX_RESOURCE_NODES = 200
|
||||
|
||||
|
||||
@mcp.resource("dbus://{bus}/services")
|
||||
async def bus_services(bus: str) -> str:
|
||||
"""Live list of well-known service names on a D-Bus bus."""
|
||||
# Resources don't get Context — we need a fresh connection
|
||||
mgr = BusManager()
|
||||
try:
|
||||
connection = await mgr.get_bus(bus)
|
||||
@ -28,23 +30,31 @@ async def bus_services(bus: str) -> str:
|
||||
|
||||
@mcp.resource("dbus://{bus}/{service}/objects")
|
||||
async def service_objects(bus: str, service: str) -> str:
|
||||
"""Object tree for a D-Bus service."""
|
||||
"""Object tree for a D-Bus service (bounded BFS walk)."""
|
||||
mgr = BusManager()
|
||||
try:
|
||||
connection = await mgr.get_bus(bus)
|
||||
paths: list[str] = []
|
||||
visited: set[str] = set()
|
||||
queue: deque[tuple[str, int]] = deque([("/", 0)])
|
||||
|
||||
while queue and len(paths) < MAX_RESOURCE_NODES:
|
||||
path, depth = queue.popleft()
|
||||
if path in visited:
|
||||
continue
|
||||
visited.add(path)
|
||||
|
||||
async def _walk(path: str) -> None:
|
||||
try:
|
||||
node = await connection.introspect(service, path)
|
||||
paths.append(path)
|
||||
except (RuntimeError, OSError):
|
||||
continue
|
||||
|
||||
paths.append(path)
|
||||
if depth < 15:
|
||||
for child in node.nodes:
|
||||
child_path = path.rstrip("/") + "/" + child.name
|
||||
await _walk(child_path)
|
||||
except Exception:
|
||||
pass
|
||||
queue.append((child_path, depth + 1))
|
||||
|
||||
await _walk("/")
|
||||
return "\n".join(sorted(paths))
|
||||
finally:
|
||||
await mgr.disconnect_all()
|
||||
@ -53,7 +63,6 @@ async def service_objects(bus: str, service: str) -> str:
|
||||
@mcp.resource("dbus://{bus}/{service}/{path}/interfaces")
|
||||
async def object_interfaces(bus: str, service: str, path: str) -> str:
|
||||
"""Interfaces available at a D-Bus object path."""
|
||||
# path comes URL-encoded (slashes → %2F), decode it
|
||||
decoded_path = unquote(path)
|
||||
if not decoded_path.startswith("/"):
|
||||
decoded_path = "/" + decoded_path
|
||||
|
||||
@ -1,15 +1,13 @@
|
||||
"""High-level convenience tools for common D-Bus operations."""
|
||||
|
||||
import fnmatch
|
||||
|
||||
from fastmcp import Context
|
||||
|
||||
from mcdbus._bus import BusManager, call_bus_method
|
||||
from mcdbus._bus import call_bus_method, get_mgr
|
||||
from mcdbus._state import mcp
|
||||
|
||||
|
||||
def _get_mgr(ctx: Context) -> BusManager:
|
||||
return ctx.request_context.lifespan_context
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def send_notification(
|
||||
summary: str,
|
||||
@ -26,27 +24,30 @@ async def send_notification(
|
||||
icon: icon name or path (empty for default)
|
||||
timeout: display duration in milliseconds (default 5000)
|
||||
"""
|
||||
mgr = _get_mgr(ctx)
|
||||
mgr = get_mgr(ctx)
|
||||
connection = await mgr.get_bus("session")
|
||||
|
||||
result = await call_bus_method(
|
||||
connection,
|
||||
destination="org.freedesktop.Notifications",
|
||||
path="/org/freedesktop/Notifications",
|
||||
interface="org.freedesktop.Notifications",
|
||||
member="Notify",
|
||||
signature="susssasa{sv}i",
|
||||
body=[
|
||||
"mcdbus", # app_name
|
||||
0, # replaces_id (0 = new)
|
||||
icon, # app_icon
|
||||
summary, # summary
|
||||
body, # body
|
||||
[], # actions
|
||||
{}, # hints
|
||||
timeout, # expire_timeout
|
||||
],
|
||||
)
|
||||
try:
|
||||
result = await call_bus_method(
|
||||
connection,
|
||||
destination="org.freedesktop.Notifications",
|
||||
path="/org/freedesktop/Notifications",
|
||||
interface="org.freedesktop.Notifications",
|
||||
member="Notify",
|
||||
signature="susssasa{sv}i",
|
||||
body=[
|
||||
"mcdbus", # app_name
|
||||
0, # replaces_id (0 = new)
|
||||
icon, # app_icon
|
||||
summary, # summary
|
||||
body, # body
|
||||
[], # actions
|
||||
{}, # hints
|
||||
timeout, # expire_timeout
|
||||
],
|
||||
)
|
||||
except (RuntimeError, TimeoutError) as exc:
|
||||
return f"Error sending notification: {exc}"
|
||||
|
||||
nid = result[0] if result else "unknown"
|
||||
return f"Notification sent (id: {nid})"
|
||||
@ -64,16 +65,19 @@ async def list_systemd_units(
|
||||
bus: "session" (user units) or "system" (system units, default)
|
||||
pattern: optional glob filter (e.g. "docker*", "*.service")
|
||||
"""
|
||||
mgr = _get_mgr(ctx)
|
||||
mgr = get_mgr(ctx)
|
||||
connection = await mgr.get_bus(bus)
|
||||
|
||||
result = await call_bus_method(
|
||||
connection,
|
||||
destination="org.freedesktop.systemd1",
|
||||
path="/org/freedesktop/systemd1",
|
||||
interface="org.freedesktop.systemd1.Manager",
|
||||
member="ListUnits",
|
||||
)
|
||||
try:
|
||||
result = await call_bus_method(
|
||||
connection,
|
||||
destination="org.freedesktop.systemd1",
|
||||
path="/org/freedesktop/systemd1",
|
||||
interface="org.freedesktop.systemd1.Manager",
|
||||
member="ListUnits",
|
||||
)
|
||||
except (RuntimeError, TimeoutError) as exc:
|
||||
return f"Error listing units: {exc}"
|
||||
|
||||
if not result or not result[0]:
|
||||
return "No units found."
|
||||
@ -81,7 +85,6 @@ async def list_systemd_units(
|
||||
units = result[0]
|
||||
|
||||
# Each unit is (name, description, load_state, active_state, sub_state, ...)
|
||||
import fnmatch
|
||||
if pattern:
|
||||
units = [u for u in units if fnmatch.fnmatch(u[0], pattern)]
|
||||
|
||||
@ -107,11 +110,10 @@ async def media_player_control(
|
||||
action: one of "play", "pause", "next", "previous", "stop", "play-pause"
|
||||
player: MPRIS player name (auto-discovers first player if empty)
|
||||
"""
|
||||
mgr = _get_mgr(ctx)
|
||||
mgr = get_mgr(ctx)
|
||||
connection = await mgr.get_bus("session")
|
||||
|
||||
if not player:
|
||||
# Discover MPRIS players
|
||||
result = await call_bus_method(
|
||||
connection,
|
||||
destination="org.freedesktop.DBus",
|
||||
@ -120,12 +122,18 @@ async def media_player_control(
|
||||
member="ListNames",
|
||||
)
|
||||
names = result[0] if result else []
|
||||
mpris_services = [n for n in names if n.startswith("org.mpris.MediaPlayer2.")]
|
||||
mpris_services = sorted(
|
||||
n for n in names if n.startswith("org.mpris.MediaPlayer2.")
|
||||
)
|
||||
|
||||
if not mpris_services:
|
||||
return "No MPRIS media players found on session bus."
|
||||
player = mpris_services[0]
|
||||
await ctx.info(f"Auto-discovered player: {player}")
|
||||
others = (
|
||||
f"\nOther players: {', '.join(mpris_services[1:])}"
|
||||
if len(mpris_services) > 1 else ""
|
||||
)
|
||||
await ctx.info(f"Auto-discovered player: {player}{others}")
|
||||
|
||||
action_map = {
|
||||
"play": "Play",
|
||||
@ -140,13 +148,16 @@ async def media_player_control(
|
||||
if not dbus_method:
|
||||
return f"Unknown action: {action}. Use: {', '.join(action_map.keys())}"
|
||||
|
||||
await call_bus_method(
|
||||
connection,
|
||||
destination=player,
|
||||
path="/org/mpris/MediaPlayer2",
|
||||
interface="org.mpris.MediaPlayer2.Player",
|
||||
member=dbus_method,
|
||||
)
|
||||
try:
|
||||
await call_bus_method(
|
||||
connection,
|
||||
destination=player,
|
||||
path="/org/mpris/MediaPlayer2",
|
||||
interface="org.mpris.MediaPlayer2.Player",
|
||||
member=dbus_method,
|
||||
)
|
||||
except (RuntimeError, TimeoutError) as exc:
|
||||
return f"Error controlling player: {exc}"
|
||||
|
||||
# Read current playback status
|
||||
try:
|
||||
@ -160,7 +171,7 @@ async def media_player_control(
|
||||
body=["org.mpris.MediaPlayer2.Player", "PlaybackStatus"],
|
||||
)
|
||||
status = status_result[0] if status_result else "unknown"
|
||||
except Exception:
|
||||
except (RuntimeError, TimeoutError):
|
||||
status = "unknown"
|
||||
|
||||
# Try to get current track metadata
|
||||
@ -179,7 +190,7 @@ async def media_player_control(
|
||||
artist = metadata.get("xesam:artist", ["Unknown"])
|
||||
if isinstance(artist, list):
|
||||
artist = ", ".join(artist)
|
||||
except Exception:
|
||||
except (RuntimeError, TimeoutError):
|
||||
title, artist = "Unknown", "Unknown"
|
||||
|
||||
return f"Player: {player}\nAction: {action}\nStatus: {status}\nNow: {artist} — {title}"
|
||||
|
||||
@ -2,7 +2,13 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from mcdbus._bus import BusManager, deserialize_args, serialize_variant
|
||||
from mcdbus._bus import (
|
||||
BusManager,
|
||||
deserialize_args,
|
||||
serialize_variant,
|
||||
validate_bus_name,
|
||||
validate_object_path,
|
||||
)
|
||||
|
||||
|
||||
class TestBusManager:
|
||||
@ -71,3 +77,37 @@ class TestDeserializeArgs:
|
||||
def test_single_value_wraps_in_list(self):
|
||||
result = deserialize_args('"hello"', "s")
|
||||
assert result == ["hello"]
|
||||
|
||||
def test_arg_count_mismatch_raises(self):
|
||||
with pytest.raises(ValueError, match="expects 1 argument"):
|
||||
deserialize_args('["hello", "extra"]', "s")
|
||||
|
||||
def test_too_few_args_raises(self):
|
||||
with pytest.raises(ValueError, match="expects 2 argument"):
|
||||
deserialize_args('["hello"]', "su")
|
||||
|
||||
def test_invalid_json_raises(self):
|
||||
with pytest.raises(ValueError, match="Invalid JSON"):
|
||||
deserialize_args("{broken", "s")
|
||||
|
||||
|
||||
class TestValidation:
|
||||
def test_valid_bus_name(self):
|
||||
validate_bus_name("org.freedesktop.DBus")
|
||||
validate_bus_name("org.kde.KWin")
|
||||
|
||||
def test_invalid_bus_name(self):
|
||||
with pytest.raises(ValueError, match="Invalid D-Bus"):
|
||||
validate_bus_name("not-a-bus-name")
|
||||
with pytest.raises(ValueError, match="Invalid D-Bus"):
|
||||
validate_bus_name("")
|
||||
|
||||
def test_valid_object_path(self):
|
||||
validate_object_path("/")
|
||||
validate_object_path("/org/freedesktop/DBus")
|
||||
|
||||
def test_invalid_object_path(self):
|
||||
with pytest.raises(ValueError, match="Invalid D-Bus object path"):
|
||||
validate_object_path("not/a/path")
|
||||
with pytest.raises(ValueError, match="Invalid D-Bus object path"):
|
||||
validate_object_path("")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user