diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 3113ddb79..d102fc1e6 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -15,7 +15,7 @@ from opentelemetry import trace as trace_api -from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent +from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, BeforeStreamChunkEvent, MessageAddedEvent from ..telemetry.metrics import Trace from ..telemetry.tracer import Tracer, get_tracer from ..tools._validator import validate_and_prepare_tools @@ -39,7 +39,7 @@ MaxTokensReachedException, StructuredOutputException, ) -from ..types.streaming import StopReason +from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolResult, ToolUse from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached from ._retry import ModelRetryStrategy @@ -327,6 +327,18 @@ async def _handle_model_execution( tool_specs = [tool_spec] if tool_spec else [] else: tool_specs = agent.tool_registry.get_all_tool_specs() + + # Create chunk interceptor that invokes BeforeStreamChunkEvent hook + async def chunk_interceptor(chunk: StreamEvent) -> tuple[StreamEvent, bool]: + """Intercept chunks and invoke BeforeStreamChunkEvent hook.""" + stream_chunk_event = BeforeStreamChunkEvent( + agent=agent, + chunk=chunk, + invocation_state=invocation_state, + ) + await agent.hooks.invoke_callbacks_async(stream_chunk_event) + return stream_chunk_event.chunk, stream_chunk_event.skip + try: async for event in stream_messages( agent.model, @@ -336,6 +348,7 @@ async def _handle_model_execution( system_prompt_content=agent._system_prompt_content, tool_choice=structured_output_context.tool_choice, invocation_state=invocation_state, + chunk_interceptor=chunk_interceptor, ): yield event diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index b157f740e..76ee6156f 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -4,7 +4,7 @@ import logging import time import warnings -from collections.abc import AsyncGenerator, AsyncIterable +from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable from typing import Any from ..models.model import Model @@ -41,6 +41,10 @@ logger = logging.getLogger(__name__) +# Type for chunk interceptor callback +# Takes a chunk and returns (modified_chunk, skip) where skip=True means don't process this chunk +ChunkInterceptor = Callable[[StreamEvent], Awaitable[tuple[StreamEvent, bool]]] + def _normalize_messages(messages: Messages) -> Messages: """Remove or replace blank text in message content. @@ -368,13 +372,18 @@ def extract_usage_metrics(event: MetadataEvent, time_to_first_byte_ms: int | Non async def process_stream( - chunks: AsyncIterable[StreamEvent], start_time: float | None = None + chunks: AsyncIterable[StreamEvent], + start_time: float | None = None, + chunk_interceptor: ChunkInterceptor | None = None, ) -> AsyncGenerator[TypedEvent, None]: """Processes the response stream from the API, constructing the final message and extracting usage metrics. Args: chunks: The chunks of the response stream from the model. start_time: Time when the model request is initiated + chunk_interceptor: Optional callback to intercept and modify chunks before processing. + The callback receives a chunk and returns (modified_chunk, skip). If skip is True, + the chunk is not processed or yielded. Yields: The reason for stopping, the constructed message, and the usage metrics. @@ -395,6 +404,12 @@ async def process_stream( metrics: Metrics = Metrics(latencyMs=0, timeToFirstByteMs=0) async for chunk in chunks: + # Invoke chunk interceptor BEFORE processing if provided + if chunk_interceptor is not None: + chunk, skip = await chunk_interceptor(chunk) + if skip: + continue + # Track first byte time when we get first content if first_byte_time is None and ("contentBlockDelta" in chunk or "contentBlockStart" in chunk): first_byte_time = time.time() @@ -431,6 +446,7 @@ async def stream_messages( tool_choice: Any | None = None, system_prompt_content: list[SystemContentBlock] | None = None, invocation_state: dict[str, Any] | None = None, + chunk_interceptor: ChunkInterceptor | None = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. @@ -444,6 +460,9 @@ async def stream_messages( system_prompt_content: The authoritative system prompt content blocks that always contains the system prompt data. invocation_state: Caller-provided state/context that was passed to the agent when it was invoked. + chunk_interceptor: Optional callback to intercept and modify chunks before processing. + The callback receives a chunk and returns (modified_chunk, skip). If skip is True, + the chunk is not processed or yielded. **kwargs: Additional keyword arguments for future extensibility. Yields: @@ -463,5 +482,5 @@ async def stream_messages( invocation_state=invocation_state, ) - async for event in process_stream(chunks, start_time): + async for event in process_stream(chunks, start_time, chunk_interceptor): yield event diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py index 96c7f577b..6b21390aa 100644 --- a/src/strands/hooks/__init__.py +++ b/src/strands/hooks/__init__.py @@ -41,6 +41,7 @@ def log_end(self, event: AfterInvocationEvent) -> None: BeforeModelCallEvent, BeforeMultiAgentInvocationEvent, BeforeNodeCallEvent, + BeforeStreamChunkEvent, BeforeToolCallEvent, MessageAddedEvent, MultiAgentInitializedEvent, @@ -50,6 +51,7 @@ def log_end(self, event: AfterInvocationEvent) -> None: __all__ = [ "AgentInitializedEvent", "BeforeInvocationEvent", + "BeforeStreamChunkEvent", "BeforeToolCallEvent", "AfterToolCallEvent", "BeforeModelCallEvent", diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 8d3e5d280..75356bff0 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -14,7 +14,7 @@ from ..types.content import Message, Messages from ..types.interrupt import _Interruptible -from ..types.streaming import StopReason +from ..types.streaming import StopReason, StreamEvent from ..types.tools import AgentTool, ToolResult, ToolUse from .registry import BaseHookEvent, HookEvent @@ -288,6 +288,48 @@ def should_reverse_callbacks(self) -> bool: return True +@dataclass +class BeforeStreamChunkEvent(HookEvent): + """Event triggered before each stream chunk is processed. + + This event is fired for each chunk received from the model BEFORE the chunk + is processed for message building or yielded as stream events. Hook providers + can use this event to: + + - Monitor streaming progress in real-time + - Modify chunk content before processing (affects final message and all events) + - Filter/skip chunks entirely by setting skip=True + - Implement content transformation (e.g., redaction, translation) + + When skip=True: + - The chunk is not processed at all + - No events (ModelStreamChunkEvent, TextStreamEvent, etc.) are yielded + - The chunk does not contribute to the final message + + When chunk is modified: + - The modified chunk is used for all downstream processing + - TextStreamEvent will contain the modified text + - The final message will contain the modified content + + Performance Note: + This event fires for every stream chunk, so callbacks should execute + quickly to avoid impacting streaming latency. + + Attributes: + chunk: The raw stream event from the model. Can be modified by hooks + to transform content before processing. + skip: When True, the chunk is skipped entirely (not processed or yielded). + invocation_state: State passed through agent invocation. + """ + + chunk: StreamEvent + invocation_state: dict[str, Any] = field(default_factory=dict) + skip: bool = False + + def _can_write(self, name: str) -> bool: + return name in ["chunk", "skip"] + + # Multiagent hook events start here @dataclass class MultiAgentInitializedEvent(BaseHookEvent): diff --git a/tests/strands/agent/hooks/test_events.py b/tests/strands/agent/hooks/test_events.py index de551d137..14c179bf3 100644 --- a/tests/strands/agent/hooks/test_events.py +++ b/tests/strands/agent/hooks/test_events.py @@ -10,6 +10,7 @@ AgentInitializedEvent, BeforeInvocationEvent, BeforeModelCallEvent, + BeforeStreamChunkEvent, BeforeToolCallEvent, MessageAddedEvent, ) @@ -230,3 +231,58 @@ def test_before_invocation_event_agent_not_writable(start_request_event_with_mes """Test that BeforeInvocationEvent.agent is not writable.""" with pytest.raises(AttributeError, match="Property agent is not writable"): start_request_event_with_messages.agent = Mock() + + +@pytest.fixture +def before_stream_chunk_event(agent): + chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}} + return BeforeStreamChunkEvent( + agent=agent, + chunk=chunk, + invocation_state={"test": "state"}, + ) + + +def test_before_stream_chunk_event_should_not_reverse_callbacks(before_stream_chunk_event): + """Test that BeforeStreamChunkEvent does not reverse callbacks.""" + assert before_stream_chunk_event.should_reverse_callbacks is False + + +def test_before_stream_chunk_event_can_write_chunk(before_stream_chunk_event): + """Test that BeforeStreamChunkEvent.chunk is writable.""" + new_chunk = {"contentBlockDelta": {"delta": {"text": "Modified"}}} + before_stream_chunk_event.chunk = new_chunk + assert before_stream_chunk_event.chunk == new_chunk + + +def test_before_stream_chunk_event_can_write_skip(before_stream_chunk_event): + """Test that BeforeStreamChunkEvent.skip is writable.""" + assert before_stream_chunk_event.skip is False + before_stream_chunk_event.skip = True + assert before_stream_chunk_event.skip is True + + +def test_before_stream_chunk_event_cannot_write_agent(before_stream_chunk_event): + """Test that BeforeStreamChunkEvent.agent is not writable.""" + with pytest.raises(AttributeError, match="Property agent is not writable"): + before_stream_chunk_event.agent = Mock() + + +def test_before_stream_chunk_event_cannot_write_invocation_state(before_stream_chunk_event): + """Test that BeforeStreamChunkEvent.invocation_state is not writable.""" + with pytest.raises(AttributeError, match="Property invocation_state is not writable"): + before_stream_chunk_event.invocation_state = {} + + +def test_before_stream_chunk_event_skip_defaults_to_false(agent): + """Test that BeforeStreamChunkEvent.skip defaults to False.""" + chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}} + event = BeforeStreamChunkEvent(agent=agent, chunk=chunk) + assert event.skip is False + + +def test_before_stream_chunk_event_invocation_state_defaults_to_empty(agent): + """Test that BeforeStreamChunkEvent.invocation_state defaults to empty dict.""" + chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}} + event = BeforeStreamChunkEvent(agent=agent, chunk=chunk) + assert event.invocation_state == {} diff --git a/tests/strands/agent/test_agent_hooks.py b/tests/strands/agent/test_agent_hooks.py index 4397b9628..25317fc44 100644 --- a/tests/strands/agent/test_agent_hooks.py +++ b/tests/strands/agent/test_agent_hooks.py @@ -12,6 +12,7 @@ AgentInitializedEvent, BeforeInvocationEvent, BeforeModelCallEvent, + BeforeStreamChunkEvent, BeforeToolCallEvent, MessageAddedEvent, ) @@ -694,3 +695,177 @@ async def capture_messages_hook(event: BeforeInvocationEvent): # structured_output_async uses deprecated path that doesn't pass messages assert received_messages is None + + +def test_before_stream_chunk_event_fires_for_each_chunk(): + """Test that BeforeStreamChunkEvent fires for each stream chunk.""" + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "Hello world"}], + }, + ] + ) + + chunk_events = [] + + async def capture_stream_chunks(event: BeforeStreamChunkEvent): + chunk_events.append(event.chunk) + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeStreamChunkEvent, capture_stream_chunks) + + agent("Test message") + + # Should have received multiple chunk events (messageStart, contentBlockStart, delta, stop, etc.) + assert len(chunk_events) > 0 + # Should include content block delta with text + text_deltas = [c for c in chunk_events if "contentBlockDelta" in c] + assert len(text_deltas) == 1 + assert text_deltas[0]["contentBlockDelta"]["delta"]["text"] == "Hello world" + + +@pytest.mark.asyncio +async def test_before_stream_chunk_event_can_skip_chunks(): + """Test that setting skip=True prevents chunks from being processed entirely. + + When skip=True, the chunk is not processed at all: + - No ModelStreamChunkEvent is yielded + - No TextStreamEvent is yielded + - The chunk does not contribute to the final message + """ + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "Hello world"}], + }, + ] + ) + + async def skip_text_chunks(event: BeforeStreamChunkEvent): + # Skip only contentBlockDelta chunks (text content) + if "contentBlockDelta" in event.chunk: + event.skip = True + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeStreamChunkEvent, skip_text_chunks) + + # Collect all yielded events + text_events = [] + result = None + async for event in agent.stream_async("Test message"): + if "data" in event: # TextStreamEvent + text_events.append(event["data"]) + if "result" in event: + result = event["result"] + + # Verify no text events were yielded (because we skipped content deltas) + assert len(text_events) == 0 + + # Verify final message has no text content (skipped chunks don't contribute) + # The message will have empty content since we skipped the text blocks + assert result.message["content"] == [] + + +@pytest.mark.asyncio +async def test_before_stream_chunk_event_can_modify_chunks(): + """Test that chunk modifications affect both stream events and final message. + + The BeforeStreamChunkEvent hook intercepts chunks BEFORE processing, so + modifications affect: + - The yielded ModelStreamChunkEvent (raw chunk) + - The yielded TextStreamEvent (processed text) + - The final message content + """ + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "secret"}], + }, + ] + ) + + async def redact_text_chunks(event: BeforeStreamChunkEvent): + # Modify text delta chunks + if "contentBlockDelta" in event.chunk: + delta = event.chunk["contentBlockDelta"]["delta"] + if "text" in delta: + # Create modified chunk with redacted text + event.chunk = {"contentBlockDelta": {"delta": {"text": "[REDACTED]"}}} + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeStreamChunkEvent, redact_text_chunks) + + # Collect yielded events + text_events = [] + result = None + async for event in agent.stream_async("Test message"): + if "data" in event: # TextStreamEvent + text_events.append(event["data"]) + if "result" in event: + result = event["result"] + + # Verify TextStreamEvent contains modified text + assert len(text_events) == 1 + assert text_events[0] == "[REDACTED]" + + # Verify final message contains modified text + assert result.message["content"][0]["text"] == "[REDACTED]" + + +@pytest.mark.asyncio +async def test_before_stream_chunk_event_with_stream_async(): + """Test that BeforeStreamChunkEvent works with stream_async.""" + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "Hello"}], + }, + ] + ) + + chunk_events = [] + + async def capture_stream_chunks(event: BeforeStreamChunkEvent): + chunk_events.append(event.chunk) + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeStreamChunkEvent, capture_stream_chunks) + + async for _ in agent.stream_async("Test message"): + pass + + # Should have received chunk events + assert len(chunk_events) > 0 + + +def test_before_stream_chunk_event_has_invocation_state(): + """Test that BeforeStreamChunkEvent includes invocation_state.""" + mock_provider = MockedModelProvider( + [ + { + "role": "assistant", + "content": [{"text": "Hello"}], + }, + ] + ) + + received_states = [] + + async def capture_invocation_state(event: BeforeStreamChunkEvent): + received_states.append(event.invocation_state.copy()) + + agent = Agent(model=mock_provider) + agent.hooks.add_callback(BeforeStreamChunkEvent, capture_invocation_state) + + agent("Test message", invocation_state={"custom_key": "custom_value"}) + + # All captured states should have the custom key + assert len(received_states) > 0 + for state in received_states: + assert "custom_key" in state + assert state["custom_key"] == "custom_value" diff --git a/tests_integ/hooks/test_stream_chunk_event.py b/tests_integ/hooks/test_stream_chunk_event.py new file mode 100644 index 000000000..60e485e98 --- /dev/null +++ b/tests_integ/hooks/test_stream_chunk_event.py @@ -0,0 +1,122 @@ +"""Integration tests for BeforeStreamChunkEvent hook.""" + +import pytest + +from strands import Agent +from strands.hooks import BeforeStreamChunkEvent + + +@pytest.fixture +def chunks_intercepted(): + return [] + + +@pytest.fixture +def agent_with_stream_hook(chunks_intercepted): + """Create an agent with BeforeStreamChunkEvent hook registered.""" + + async def intercept_chunks(event: BeforeStreamChunkEvent): + chunks_intercepted.append(event.chunk.copy()) + + agent = Agent(system_prompt="Be very brief. Reply with one word only.") + agent.hooks.add_callback(BeforeStreamChunkEvent, intercept_chunks) + return agent + + +def test_before_stream_chunk_event_fires(agent_with_stream_hook, chunks_intercepted): + """Test that BeforeStreamChunkEvent fires for each stream chunk.""" + agent_with_stream_hook("Say hello") + + # Should have intercepted multiple chunks + assert len(chunks_intercepted) > 0 + + # Should have message start, content blocks, and message stop + chunk_types = set() + for chunk in chunks_intercepted: + chunk_types.update(chunk.keys()) + + assert "messageStart" in chunk_types + assert "messageStop" in chunk_types + + +@pytest.mark.asyncio +async def test_before_stream_chunk_event_modification(): + """Test that chunk modifications affect both stream events and final message.""" + modified_text = "[REDACTED]" + + async def redact_chunks(event: BeforeStreamChunkEvent): + if "contentBlockDelta" in event.chunk: + delta = event.chunk.get("contentBlockDelta", {}).get("delta", {}) + if "text" in delta: + event.chunk = {"contentBlockDelta": {"delta": {"text": modified_text}}} + + agent = Agent(system_prompt="Say exactly: secret123") + agent.hooks.add_callback(BeforeStreamChunkEvent, redact_chunks) + + text_events = [] + result = None + + async for event in agent.stream_async("go"): + if "data" in event: + text_events.append(event["data"]) + if "result" in event: + result = event["result"] + + # All text events should be the modified text + assert all(text == modified_text for text in text_events) + + # Final message should only contain modified text + final_text = result.message["content"][0].get("text", "") + assert modified_text in final_text + assert "secret" not in final_text.lower() + + +@pytest.mark.asyncio +async def test_before_stream_chunk_event_skip(): + """Test that skip=True excludes chunks from processing and final message.""" + + async def skip_content_deltas(event: BeforeStreamChunkEvent): + # Skip all content block deltas (text content) + if "contentBlockDelta" in event.chunk: + event.skip = True + + agent = Agent(system_prompt="Say hello") + agent.hooks.add_callback(BeforeStreamChunkEvent, skip_content_deltas) + + text_events = [] + result = None + + async for event in agent.stream_async("go"): + if "data" in event: + text_events.append(event["data"]) + if "result" in event: + result = event["result"] + + # No text events should be yielded + assert len(text_events) == 0 + + # Final message should have no content (all text was skipped) + assert result.message["content"] == [] + + +@pytest.mark.asyncio +async def test_before_stream_chunk_event_has_invocation_state(): + """Test that invocation_state is accessible in BeforeStreamChunkEvent.""" + received_states = [] + + async def capture_state(event: BeforeStreamChunkEvent): + received_states.append(event.invocation_state.copy()) + + agent = Agent(system_prompt="Be brief") + agent.hooks.add_callback(BeforeStreamChunkEvent, capture_state) + + custom_state = {"session_id": "test-123", "user_id": "user-456"} + + async for _ in agent.stream_async("hi", invocation_state=custom_state): + pass + + # All captured states should have our custom keys + assert len(received_states) > 0 + for state in received_states: + assert state.get("session_id") == "test-123" + assert state.get("user_id") == "user-456"