diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index db1878108..4a48d7229 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -363,6 +363,23 @@ def _inject_cache_point(self, messages: list[dict[str, Any]]) -> None: messages[last_assistant_idx]["content"].append({"cachePoint": {"type": "default"}}) logger.debug("msg_idx=<%s> | added cache point to last assistant message", last_assistant_idx) + def _find_last_user_text_message_index(self, messages: Messages) -> int | None: + """Find the index of the last user message containing text or image content. + + This is used for guardrail_latest_message to ensure that guardContent wrapping + targets the correct message even when toolResult messages follow. + + Args: + messages: List of messages to search + + Returns: + Index of the last user message with text/image content, or None if not found + """ + for idx, msg in reversed(list(enumerate(messages))): + if msg["role"] == "user" and any("text" in cb or "image" in cb for cb in msg.get("content", [])): + return idx + return None + def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: """Format messages for Bedrock API compatibility. @@ -391,7 +408,12 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: filtered_unknown_members = False dropped_deepseek_reasoning_content = False - guardrail_latest_message = self.config.get("guardrail_latest_message", False) + # Pre-compute the index of the last user message containing text or image content. + # This ensures guardContent wrapping is maintained across tool execution cycles, where + # the final message in the list is a toolResult (role=user) rather than text/image content. + last_user_text_idx = None + if self.config.get("guardrail_latest_message", False): + last_user_text_idx = self._find_last_user_text_message_index(messages) for idx, message in enumerate(messages): cleaned_content: list[dict[str, Any]] = [] @@ -413,13 +435,8 @@ def _format_bedrock_messages(self, messages: Messages) -> list[dict[str, Any]]: if formatted_content is None: continue - # Wrap text or image content in guardrailContent if this is the last user message - if ( - guardrail_latest_message - and idx == len(messages) - 1 - and message["role"] == "user" - and ("text" in formatted_content or "image" in formatted_content) - ): + # Wrap text or image content in guardContent if this is the last user text/image message + if idx == last_user_text_idx and ("text" in formatted_content or "image" in formatted_content): if "text" in formatted_content: formatted_content = {"guardContent": {"text": {"text": formatted_content["text"]}}} elif "image" in formatted_content: diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 228d6c138..9dae16be7 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -2405,6 +2405,183 @@ async def test_format_request_with_guardrail_latest_message(model): assert formatted_messages[2]["content"][1]["guardContent"]["image"]["format"] == "png" +@pytest.mark.asyncio +async def test_format_request_with_guardrail_latest_message_after_tool_use(model): + """Test that guardContent wraps the last user text message even when a toolResult follows it.""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"text": "First message"}]}, + {"role": "assistant", "content": [{"text": "First response"}]}, + {"role": "user", "content": [{"text": "what is the standard deduction?"}]}, + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool-1", + "name": "knowledge_base", + "input": {"query": "standard deduction"}, + } + } + ], + }, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "tool-1", + "content": [{"text": "The standard deduction for 2024 is $14,600."}], + "status": "success", + } + } + ], + }, + ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + assert len(formatted_messages) == 5 + + # Earlier user message should NOT be wrapped + assert "text" in formatted_messages[0]["content"][0] + assert formatted_messages[0]["content"][0]["text"] == "First message" + + # Last user message with text content should be wrapped, even though a toolResult comes after + assert "guardContent" in formatted_messages[2]["content"][0] + assert formatted_messages[2]["content"][0]["guardContent"]["text"]["text"] == "what is the standard deduction?" + + # toolResult-only user message should NOT be wrapped + assert "toolResult" in formatted_messages[4]["content"][0] + assert "guardContent" not in formatted_messages[4]["content"][0] + + +@pytest.mark.asyncio +async def test_format_request_with_guardrail_latest_message_wraps_final_user_text(model): + """Test that guardContent wraps the last user message when it contains text content.""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"text": "First message"}]}, + {"role": "assistant", "content": [{"text": "First response"}]}, + {"role": "user", "content": [{"text": "Tell me about taxes"}]}, + ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + assert "guardContent" in formatted_messages[2]["content"][0] + assert formatted_messages[2]["content"][0]["guardContent"]["text"]["text"] == "Tell me about taxes" + + +@pytest.mark.asyncio +async def test_format_request_with_guardrail_multiple_sequential_tool_calls(model): + """Test guardContent with multiple tool calls in sequence (no new user input between).""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"text": "First question"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "tool1", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result 1"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t2", "name": "tool2", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t2", "content": [{"text": "Result 2"}], "status": "success"}}], + }, + ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + # Should wrap the first user text message, not the toolResults + assert "guardContent" in formatted_messages[0]["content"][0] + assert formatted_messages[0]["content"][0]["guardContent"]["text"]["text"] == "First question" + + # toolResults should not be wrapped + assert "toolResult" in formatted_messages[2]["content"][0] + assert "guardContent" not in formatted_messages[2]["content"][0] + assert "toolResult" in formatted_messages[4]["content"][0] + assert "guardContent" not in formatted_messages[4]["content"][0] + + +@pytest.mark.asyncio +async def test_format_request_with_guardrail_image_before_tool_result(model): + """Test guardContent wraps image content even when toolResult follows.""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"image": {"format": "png", "source": {"bytes": b"fake"}}}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "vision", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "I see a cat"}], "status": "success"}}], + }, + ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + # Image should be wrapped even though toolResult comes after + assert "guardContent" in formatted_messages[0]["content"][0] + assert "image" in formatted_messages[0]["content"][0]["guardContent"] + + +@pytest.mark.asyncio +async def test_format_request_with_guardrail_multiple_tool_results_same_message(model): + """Test guardContent with multiple parallel tool calls (multiple toolResults in one message).""" + model.update_config( + guardrail_id="test-guardrail", + guardrail_version="DRAFT", + guardrail_latest_message=True, + ) + + messages = [ + {"role": "user", "content": [{"text": "Question requiring multiple tools"}]}, + { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "t1", "name": "tool1", "input": {}}}, + {"toolUse": {"toolUseId": "t2", "name": "tool2", "input": {}}}, + ], + }, + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "t1", "content": [{"text": "Result 1"}], "status": "success"}}, + {"toolResult": {"toolUseId": "t2", "content": [{"text": "Result 2"}], "status": "success"}}, + ], + }, + ] + + request = model._format_request(messages) + formatted_messages = request["messages"] + + # Should wrap the question + assert "guardContent" in formatted_messages[0]["content"][0] + assert formatted_messages[0]["content"][0]["guardContent"]["text"]["text"] == "Question requiring multiple tools" + + def test_supports_caching_true_for_claude(bedrock_client): """Test that supports_caching returns True for Claude models.""" model = BedrockModel(model_id="us.anthropic.claude-sonnet-4-20250514-v1:0") @@ -2514,3 +2691,93 @@ def test_inject_cache_point_strips_existing_cache_points(bedrock_client): # New cache point should be at end of last assistant message assert len(cleaned_messages[3]["content"]) == 2 assert "cachePoint" in cleaned_messages[3]["content"][-1] + + +def test_find_last_user_text_message_index_no_user_messages(bedrock_client): + """Test _find_last_user_text_message_index returns None when no user text messages exist.""" + model = BedrockModel(model_id="test-model") + + messages = [ + {"role": "assistant", "content": [{"text": "hello"}]}, + ] + + assert model._find_last_user_text_message_index(messages) is None + + +def test_find_last_user_text_message_index_only_tool_results(bedrock_client): + """Test _find_last_user_text_message_index returns None when user messages only have toolResult.""" + model = BedrockModel(model_id="test-model") + + messages = [ + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "result"}]}}], + }, + ] + + assert model._find_last_user_text_message_index(messages) is None + + +def test_find_last_user_text_message_index_returns_last_text_message(bedrock_client): + """Test _find_last_user_text_message_index returns the index of the last user message with text.""" + model = BedrockModel(model_id="test-model") + + messages = [ + {"role": "user", "content": [{"text": "First question"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + {"role": "user", "content": [{"text": "Second question"}]}, + ] + + assert model._find_last_user_text_message_index(messages) == 2 + + +def test_find_last_user_text_message_index_skips_tool_result_messages(bedrock_client): + """Test _find_last_user_text_message_index skips toolResult-only user messages.""" + model = BedrockModel(model_id="test-model") + + messages = [ + {"role": "user", "content": [{"text": "Question"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "tool", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result"}]}}], + }, + ] + + assert model._find_last_user_text_message_index(messages) == 0 + + +def test_find_last_user_text_message_index_finds_image_message(bedrock_client): + """Test _find_last_user_text_message_index finds user messages with image content.""" + model = BedrockModel(model_id="test-model") + + messages = [ + {"role": "user", "content": [{"image": {"format": "png", "source": {"bytes": b"fake"}}}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "t1", "name": "vision", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "t1", "content": [{"text": "Result"}]}}], + }, + ] + + assert model._find_last_user_text_message_index(messages) == 0 + + +def test_find_last_user_text_message_index_empty_messages(bedrock_client): + """Test _find_last_user_text_message_index returns None for empty message list.""" + model = BedrockModel(model_id="test-model") + + assert model._find_last_user_text_message_index([]) is None + + +def test_guardrail_latest_message_disabled_does_not_wrap(model): + """Test that guardContent wrapping is skipped when guardrail_latest_message is not set.""" + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + ] + + request = model._format_request(messages) + formatted = request["messages"][0]["content"][0] + + assert "text" in formatted + assert "guardContent" not in formatted