diff --git a/python/flink_agents/plan/actions/chat_model_action.py b/python/flink_agents/plan/actions/chat_model_action.py index 188eb9e71..e4572056e 100644 --- a/python/flink_agents/plan/actions/chat_model_action.py +++ b/python/flink_agents/plan/actions/chat_model_action.py @@ -15,7 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -import copy import json import logging import re @@ -59,6 +58,16 @@ # ============================================================================ # Helper Functions for Tool Call Context Management # ============================================================================ +def _serialize_messages(messages: List[ChatMessage]) -> List[Dict]: + """Materialize chat messages into JSON-safe dicts for checkpoint-stable memory.""" + return [message.model_dump(mode="json") for message in messages] + + +def _deserialize_messages(messages: List[Dict]) -> List[ChatMessage]: + """Reconstruct chat messages from their stored JSON-safe dict form.""" + return [ChatMessage.model_validate(message) for message in messages] + + def _update_tool_call_context( sensory_memory: MemoryObject, initial_request_id: UUID, @@ -81,12 +90,12 @@ def _update_tool_call_context( key = str(initial_request_id) tool_call_context = sensory_memory.get(_TOOL_CALL_CONTEXT) or {} if key not in tool_call_context and initial_messages is not None: - tool_call_context[key] = copy.deepcopy(initial_messages) + tool_call_context[key] = _serialize_messages(initial_messages) - tool_call_context[key].extend(added_messages) + tool_call_context[key].extend(_serialize_messages(added_messages)) sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context) - return tool_call_context[key] + return _deserialize_messages(tool_call_context[key]) def _save_tool_request_event_context( @@ -100,10 +109,12 @@ def _save_tool_request_event_context( """Save the context for a specific tool request event.""" context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) or {} context[str(tool_request_event_id)] = { - "initial_request_id": initial_request_id, + "initial_request_id": str(initial_request_id), "model": model, _PROMPT_ARGS: prompt_args if prompt_args is not None else {}, - "output_schema": output_schema, + "output_schema": output_schema.model_dump() + if output_schema is not None + else None, } sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, context) @@ -114,6 +125,16 @@ def _get_tool_request_event_context( """Get and remove the context for a specific tool request event.""" context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) or {} removed_context = context.pop(str(request_id), {}) + if removed_context: + removed_context["initial_request_id"] = UUID( + removed_context["initial_request_id"] + ) + output_schema = removed_context["output_schema"] + removed_context["output_schema"] = ( + OutputSchema.model_validate(output_schema) + if output_schema is not None + else None + ) return removed_context diff --git a/python/flink_agents/plan/tests/actions/test_chat_model_action.py b/python/flink_agents/plan/tests/actions/test_chat_model_action.py index 72b223d28..94ff8ab9c 100644 --- a/python/flink_agents/plan/tests/actions/test_chat_model_action.py +++ b/python/flink_agents/plan/tests/actions/test_chat_model_action.py @@ -15,7 +15,49 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -from flink_agents.plan.actions.chat_model_action import _clean_llm_response +from uuid import uuid4 + +from pydantic import BaseModel +from pyflink.common.typeinfo import BasicTypeInfo, RowTypeInfo + +from flink_agents.api.agents.react_agent import OutputSchema +from flink_agents.api.chat_message import ChatMessage, MessageRole +from flink_agents.api.memory_object import MemoryType +from flink_agents.plan.actions.chat_model_action import ( + _TOOL_CALL_CONTEXT, + _TOOL_REQUEST_EVENT_CONTEXT, + _clean_llm_response, + _get_tool_request_event_context, + _save_tool_request_event_context, + _update_tool_call_context, +) +from flink_agents.runtime.local_memory_object import LocalMemoryObject + + +def _memory() -> LocalMemoryObject: + return LocalMemoryObject(MemoryType.SHORT_TERM, {}) + + +def _assert_primitive(obj) -> None: + if obj is None or isinstance(obj, bool | int | float | str | bytes): + return + if isinstance(obj, list): + for item in obj: + _assert_primitive(item) + return + if isinstance(obj, dict): + for k, v in obj.items(): + assert isinstance(k, str | int | float | bool), ( + f"non-primitive key: {k!r}" + ) + _assert_primitive(v) + return + msg = f"non-primitive value of type {type(obj).__name__}: {obj!r}" + raise AssertionError(msg) + + +class _Result(BaseModel): + result: int def test_clean_llm_response_with_json_block(): @@ -52,3 +94,116 @@ def test_clean_llm_response_with_multiple_lines_in_block(): input_str = '```json\n{\n "key": "value"\n}\n```' expected = '{\n "key": "value"\n}' assert _clean_llm_response(input_str) == expected + + +def test_update_tool_call_context_stores_primitive_only(): + mem = _memory() + initial = [ChatMessage(role=MessageRole.USER, content="hi")] + added = [ChatMessage(role=MessageRole.ASSISTANT, content="hello")] + _update_tool_call_context(mem, uuid4(), initial, added) + _assert_primitive(mem.get(_TOOL_CALL_CONTEXT)) + + +def test_update_tool_call_context_returns_chat_messages(): + mem = _memory() + initial = [ChatMessage(role=MessageRole.USER, content="hi")] + added = [ChatMessage(role=MessageRole.ASSISTANT, content="hello")] + result = _update_tool_call_context(mem, uuid4(), initial, added) + assert all(isinstance(message, ChatMessage) for message in result) + assert [(m.role, m.content) for m in result] == [ + (MessageRole.USER, "hi"), + (MessageRole.ASSISTANT, "hello"), + ] + + +def test_tool_request_event_context_stores_primitive_only(): + mem = _memory() + _save_tool_request_event_context( + mem, + uuid4(), + uuid4(), + "ollama", + {"k": "v"}, + OutputSchema(output_schema=_Result), + ) + _assert_primitive(mem.get(_TOOL_REQUEST_EVENT_CONTEXT)) + + +def test_tool_request_event_context_round_trip(): + mem = _memory() + event_id = uuid4() + initial_request_id = uuid4() + _save_tool_request_event_context( + mem, + event_id, + initial_request_id, + "ollama", + None, + OutputSchema(output_schema=_Result), + ) + context = _get_tool_request_event_context(mem, event_id) + assert context["initial_request_id"] == initial_request_id + assert isinstance(context["initial_request_id"], type(initial_request_id)) + assert isinstance(context["output_schema"], OutputSchema) + assert context["output_schema"].output_schema is _Result + assert context["model"] == "ollama" + + +def test_get_context_none_output_schema(): + mem = _memory() + event_id = uuid4() + _save_tool_request_event_context(mem, event_id, uuid4(), "ollama", None, None) + assert mem.get(_TOOL_REQUEST_EVENT_CONTEXT)[str(event_id)]["output_schema"] is None + context = _get_tool_request_event_context(mem, event_id) + assert context["output_schema"] is None + + +def test_request_event_key_match_after_normalization(): + mem = _memory() + event_id = uuid4() + _save_tool_request_event_context( + mem, event_id, uuid4(), "ollama", None, None + ) + context = _get_tool_request_event_context(mem, event_id) + assert context != {} + assert context["model"] == "ollama" + + +def test_tool_call_context_key_match_after_normalization(): + mem = _memory() + request_id = uuid4() + initial = [ChatMessage(role=MessageRole.USER, content="hi")] + _update_tool_call_context(mem, request_id, initial, []) + extra = ChatMessage(role=MessageRole.TOOL, content="result") + result = _update_tool_call_context(mem, request_id, None, [extra]) + assert len(result) == 2 + assert len(mem.get(_TOOL_CALL_CONTEXT)[str(request_id)]) == 2 + + +def test_output_schema_rowtypeinfo_round_trip(): + mem = _memory() + event_id = uuid4() + schema = OutputSchema( + output_schema=RowTypeInfo( + [BasicTypeInfo.INT_TYPE_INFO()], + ["result"], + ) + ) + _save_tool_request_event_context( + mem, event_id, uuid4(), "ollama", None, schema + ) + context = _get_tool_request_event_context(mem, event_id) + assert isinstance(context["output_schema"], OutputSchema) + assert context["output_schema"].output_schema.get_field_names() == ["result"] + + +def test_save_get_preserves_model_and_prompt_args(): + mem = _memory() + event_id = uuid4() + prompt_args = {"a": 1, "b": "x"} + _save_tool_request_event_context( + mem, event_id, uuid4(), "ollama", prompt_args, None + ) + context = _get_tool_request_event_context(mem, event_id) + assert context["model"] == "ollama" + assert context["prompt_args"] == prompt_args diff --git a/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py b/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py index e1f04f622..e9e799da3 100644 --- a/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py +++ b/python/flink_agents/plan/tests/actions/test_chat_model_action_retry.py @@ -311,7 +311,7 @@ def mock_chat(messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: "_TOOL_REQUEST_EVENT_CONTEXT", { str(tool_request_event_id): { - "initial_request_id": initial_request_id, + "initial_request_id": str(initial_request_id), "model": "test-model", "prompt_args": saved_prompt_args, "output_schema": None, @@ -325,7 +325,9 @@ def mock_chat(messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: "_TOOL_CALL_CONTEXT", { str(initial_request_id): [ - ChatMessage(role=MessageRole.USER, content="hi") + ChatMessage( + role=MessageRole.USER, content="hi" + ).model_dump(mode="json") ] }, )