diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py
index f5d00a201..1898b9bb7 100644
--- a/src/strands/event_loop/event_loop.py
+++ b/src/strands/event_loop/event_loop.py
@@ -152,6 +152,14 @@ async def event_loop_cycle(
             agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context
         )
         async for model_event in model_events:
+            # A hook requested termination: close out this cycle and stop iterating.
+            if isinstance(model_event, EventLoopStopEvent):
+                agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
+                # Explicitly close the inner generator (avoids a CancelledError on early exit).
+                # Done before the final yield so it still runs if the consumer stops here.
+                await model_events.aclose()
+                yield model_event
+                return
             if not isinstance(model_event, ModelStopReason):
                 yield model_event
 
@@ -368,6 +376,19 @@ async def _handle_model_execution(
                         stop_reason,
                     )
                     continue  # Retry the model call
+                elif after_model_call_event.terminate:
+                    # Hook requested termination: flag the request state and emit the stop event.
+                    logger.debug(
+                        "stop_reason=<%s>, termination_requested=<True> | hook requested agent termination",
+                        stop_reason,
+                    )
+                    invocation_state["request_state"]["stop_event_loop"] = True
+                    yield EventLoopStopEvent(
+                        stop_reason,
+                        message,
+                        agent.event_loop_metrics,
+                        invocation_state["request_state"],
+                    )
 
                 if stop_reason == "max_tokens":
                     message = recover_message_on_max_tokens_reached(message)
diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py
index 1faa8a917..3f9ea323c 100644
--- a/src/strands/hooks/events.py
+++ b/src/strands/hooks/events.py
@@ -245,9 +245,14 @@ class ModelStopResponse:
     stop_response: ModelStopResponse | None = None
     exception: Exception | None = None
     retry: bool = False
+    # Writable by hooks: set to True to ask the event loop to stop the agent.
+    terminate: bool = False
 
     def _can_write(self, name: str) -> bool:
-        return name == "retry"
+        return name in (
+            "retry",
+            "terminate",
+        )
 
     @property
     def should_reverse_callbacks(self) -> bool:
diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py
index e8b7e5077..7db9a937a 100644
--- a/tests/strands/agent/test_agent_hooks.py
+++ b/tests/strands/agent/test_agent_hooks.py
@@ -686,3 +686,104 @@ async def capture_messages_hook(event: BeforeInvocationEvent):
 
     # structured_output_async uses deprecated path that doesn't pass messages
     assert received_messages is None
+
+
+@pytest.mark.asyncio
+async def test_hook_terminate_on_successful_call():
+    """Test that hooks can terminate even on successful model calls based on response content."""
+
+    mock_provider = MockedModelProvider(
+        [
+            {
+                "role": "assistant",
+                "content": [{"text": "First conversation successful"}],
+            },
+            {
+                "role": "assistant",
+                "content": [{"text": "Unnecessary follow-up conversation"}],
+            },
+        ]
+    )
+
+    # Hook that terminates if the response is favorable
+    class SuccessfulTerminateHook:
+        def __init__(self, end_marker="success"):
+            self.end_marker = end_marker
+            self.call_count = 0
+
+        def register_hooks(self, registry):
+            registry.add_callback(strands.hooks.AfterModelCallEvent, self.handle_after_model_call)
+
+        async def handle_after_model_call(self, event):
+            self.call_count += 1
+
+            # Check successful responses for favorable markers
+            if event.stop_response:
+                message = event.stop_response.message
+                text_content = "".join(block.get("text", "") for block in message.get("content", []))
+
+                if self.end_marker in text_content:
+                    event.terminate = True
+
+    terminate_hook = SuccessfulTerminateHook(end_marker="success")
+    agent = Agent(model=mock_provider, hooks=[terminate_hook])
+
+    result = agent("Generate a response")
+
+    # Verify hook was called only once (for the first favorable response)
+    assert terminate_hook.call_count == 1
+
+    # Verify final result is the favorable response
+    assert result.message["content"][0]["text"] == "First conversation successful"
+
+
+@pytest.mark.asyncio
+async def test_hook_terminate_gracefully_on_limits(agent_tool, tool_use):
+    """Test that hooks can terminate agent gracefully after maximum counts reached."""
+
+    mock_provider = MockedModelProvider(
+        [
+            {
+                "role": "assistant",
+                "content": [{"text": "First tool-use"}, {"toolUse": tool_use}],
+            },
+            {
+                "role": "assistant",
+                "content": [{"text": "Second tool-use"}, {"toolUse": tool_use}],
+            },
+            {
+                "role": "assistant",
+                "content": [{"text": "Third tool-use"}, {"toolUse": tool_use}],
+            },
+        ]
+    )
+
+    # Hook that counts number of calls
+    class GracefulTerminateHook:
+        def __init__(self, max_counts):
+            self.max_counts = max_counts
+            self.call_count = 0
+
+        def register_hooks(self, registry):
+            registry.add_callback(strands.hooks.AfterModelCallEvent, self.handle_after_model_call)
+
+        async def handle_after_model_call(self, event):
+            self.call_count += 1
+
+            if self.call_count > self.max_counts - 1:
+                event.terminate = True
+
+    terminate_hook = GracefulTerminateHook(max_counts=2)
+    agent = Agent(
+        model=mock_provider,
+        tools=[agent_tool],
+        hooks=[terminate_hook],
+    )
+
+    result = agent("Generate a response")
+
+    # Verify hook was called two times
+    assert terminate_hook.call_count == 2
+
+    # Verify final result is the second tool-use
+    assert result.message["content"][0]["text"] == "Second tool-use"