diff --git a/src/mcdbus/_interaction.py b/src/mcdbus/_interaction.py index c2a427b..5eb727a 100644 --- a/src/mcdbus/_interaction.py +++ b/src/mcdbus/_interaction.py @@ -25,13 +25,19 @@ from mcdbus._state import mcp async def _confirm_or_abort(ctx: Context, message: str, operation: str) -> None: """Elicit user confirmation; raise ToolError or return silently to proceed.""" - result = await ctx.elicit( - message, - { - "confirm": {"title": "Yes, proceed"}, - "deny": {"title": "No, cancel"}, - }, - ) + try: + result = await ctx.elicit( + message, + { + "confirm": {"title": "Yes, proceed"}, + "deny": {"title": "No, cancel"}, + }, + ) + except Exception: + # Client doesn't implement elicitation at the protocol level — + # ctx.elicit() throws instead of returning CancelledElicitation. + result = CancelledElicitation() + if isinstance(result, AcceptedElicitation) and result.data == "confirm": return if isinstance(result, CancelledElicitation): diff --git a/tests/test_elicitation.py b/tests/test_elicitation.py index 36f383d..56afbc4 100644 --- a/tests/test_elicitation.py +++ b/tests/test_elicitation.py @@ -48,6 +48,21 @@ class TestConfirmOrAbort: with pytest.raises(ToolError, match="Elicitation required"): await _confirm_or_abort(ctx, "Test message", "test_op") + async def test_elicit_exception_treated_as_cancelled(self): + """When ctx.elicit() throws (client lacks protocol support), proceed.""" + ctx = AsyncMock() + ctx.elicit = AsyncMock(side_effect=Exception("Method not found")) + # Should not raise — exception is treated as CancelledElicitation + await _confirm_or_abort(ctx, "Test message", "test_op") + + async def test_elicit_exception_hard_fails_when_required(self): + """When ctx.elicit() throws and MCDBUS_REQUIRE_ELICITATION is set, fail.""" + ctx = AsyncMock() + ctx.elicit = AsyncMock(side_effect=Exception("Method not found")) + with patch.dict(os.environ, {"MCDBUS_REQUIRE_ELICITATION": "1"}): + with pytest.raises(ToolError, match="Elicitation required"): + await _confirm_or_abort(ctx, "Test message", "test_op") + class TestCallMethodElicitation: """Verify call_method triggers elicitation on system bus but not session bus."""