Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions python/flink_agents/plan/actions/chat_model_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
import copy
import json
import logging
import re
Expand Down Expand Up @@ -59,6 +58,16 @@
# ============================================================================
# Helper Functions for Tool Call Context Management
# ============================================================================
def _serialize_messages(messages: List[ChatMessage]) -> List[Dict]:
"""Materialize chat messages into JSON-safe dicts for checkpoint-stable memory."""
return [message.model_dump(mode="json") for message in messages]


def _deserialize_messages(messages: List[Dict]) -> List[ChatMessage]:
"""Reconstruct chat messages from their stored JSON-safe dict form."""
return [ChatMessage.model_validate(message) for message in messages]


def _update_tool_call_context(
sensory_memory: MemoryObject,
initial_request_id: UUID,
Expand All @@ -81,12 +90,12 @@ def _update_tool_call_context(
key = str(initial_request_id)
tool_call_context = sensory_memory.get(_TOOL_CALL_CONTEXT) or {}
if key not in tool_call_context and initial_messages is not None:
tool_call_context[key] = copy.deepcopy(initial_messages)
tool_call_context[key] = _serialize_messages(initial_messages)

tool_call_context[key].extend(added_messages)
tool_call_context[key].extend(_serialize_messages(added_messages))

sensory_memory.set(_TOOL_CALL_CONTEXT, tool_call_context)
return tool_call_context[key]
return _deserialize_messages(tool_call_context[key])


def _save_tool_request_event_context(
Expand All @@ -100,10 +109,12 @@ def _save_tool_request_event_context(
"""Save the context for a specific tool request event."""
context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) or {}
context[str(tool_request_event_id)] = {
"initial_request_id": initial_request_id,
"initial_request_id": str(initial_request_id),
"model": model,
_PROMPT_ARGS: prompt_args if prompt_args is not None else {},
"output_schema": output_schema,
"output_schema": output_schema.model_dump()
if output_schema is not None
else None,
}
sensory_memory.set(_TOOL_REQUEST_EVENT_CONTEXT, context)

Expand All @@ -114,6 +125,16 @@ def _get_tool_request_event_context(
"""Get and remove the context for a specific tool request event."""
context = sensory_memory.get(_TOOL_REQUEST_EVENT_CONTEXT) or {}
removed_context = context.pop(str(request_id), {})
if removed_context:
removed_context["initial_request_id"] = UUID(
removed_context["initial_request_id"]
)
output_schema = removed_context["output_schema"]
removed_context["output_schema"] = (
OutputSchema.model_validate(output_schema)
if output_schema is not None
else None
)
return removed_context


Expand Down
157 changes: 156 additions & 1 deletion python/flink_agents/plan/tests/actions/test_chat_model_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
from flink_agents.plan.actions.chat_model_action import _clean_llm_response
from uuid import uuid4

from pydantic import BaseModel
from pyflink.common.typeinfo import BasicTypeInfo, RowTypeInfo

from flink_agents.api.agents.react_agent import OutputSchema
from flink_agents.api.chat_message import ChatMessage, MessageRole
from flink_agents.api.memory_object import MemoryType
from flink_agents.plan.actions.chat_model_action import (
_TOOL_CALL_CONTEXT,
_TOOL_REQUEST_EVENT_CONTEXT,
_clean_llm_response,
_get_tool_request_event_context,
_save_tool_request_event_context,
_update_tool_call_context,
)
from flink_agents.runtime.local_memory_object import LocalMemoryObject


def _memory() -> LocalMemoryObject:
return LocalMemoryObject(MemoryType.SHORT_TERM, {})


def _assert_primitive(obj) -> None:
if obj is None or isinstance(obj, bool | int | float | str | bytes):
return
if isinstance(obj, list):
for item in obj:
_assert_primitive(item)
return
if isinstance(obj, dict):
for k, v in obj.items():
assert isinstance(k, str | int | float | bool), (
f"non-primitive key: {k!r}"
)
_assert_primitive(v)
return
msg = f"non-primitive value of type {type(obj).__name__}: {obj!r}"
raise AssertionError(msg)


class _Result(BaseModel):
result: int


def test_clean_llm_response_with_json_block():
Expand Down Expand Up @@ -52,3 +94,116 @@ def test_clean_llm_response_with_multiple_lines_in_block():
input_str = '```json\n{\n "key": "value"\n}\n```'
expected = '{\n "key": "value"\n}'
assert _clean_llm_response(input_str) == expected


def test_update_tool_call_context_stores_primitive_only():
mem = _memory()
initial = [ChatMessage(role=MessageRole.USER, content="hi")]
added = [ChatMessage(role=MessageRole.ASSISTANT, content="hello")]
_update_tool_call_context(mem, uuid4(), initial, added)
_assert_primitive(mem.get(_TOOL_CALL_CONTEXT))


def test_update_tool_call_context_returns_chat_messages():
mem = _memory()
initial = [ChatMessage(role=MessageRole.USER, content="hi")]
added = [ChatMessage(role=MessageRole.ASSISTANT, content="hello")]
result = _update_tool_call_context(mem, uuid4(), initial, added)
assert all(isinstance(message, ChatMessage) for message in result)
assert [(m.role, m.content) for m in result] == [
(MessageRole.USER, "hi"),
(MessageRole.ASSISTANT, "hello"),
]


def test_tool_request_event_context_stores_primitive_only():
mem = _memory()
_save_tool_request_event_context(
mem,
uuid4(),
uuid4(),
"ollama",
{"k": "v"},
OutputSchema(output_schema=_Result),
)
_assert_primitive(mem.get(_TOOL_REQUEST_EVENT_CONTEXT))


def test_tool_request_event_context_round_trip():
mem = _memory()
event_id = uuid4()
initial_request_id = uuid4()
_save_tool_request_event_context(
mem,
event_id,
initial_request_id,
"ollama",
None,
OutputSchema(output_schema=_Result),
)
context = _get_tool_request_event_context(mem, event_id)
assert context["initial_request_id"] == initial_request_id
assert isinstance(context["initial_request_id"], type(initial_request_id))
assert isinstance(context["output_schema"], OutputSchema)
assert context["output_schema"].output_schema is _Result
assert context["model"] == "ollama"


def test_get_context_none_output_schema():
mem = _memory()
event_id = uuid4()
_save_tool_request_event_context(mem, event_id, uuid4(), "ollama", None, None)
assert mem.get(_TOOL_REQUEST_EVENT_CONTEXT)[str(event_id)]["output_schema"] is None
context = _get_tool_request_event_context(mem, event_id)
assert context["output_schema"] is None


def test_request_event_key_match_after_normalization():
mem = _memory()
event_id = uuid4()
_save_tool_request_event_context(
mem, event_id, uuid4(), "ollama", None, None
)
context = _get_tool_request_event_context(mem, event_id)
assert context != {}
assert context["model"] == "ollama"


def test_tool_call_context_key_match_after_normalization():
mem = _memory()
request_id = uuid4()
initial = [ChatMessage(role=MessageRole.USER, content="hi")]
_update_tool_call_context(mem, request_id, initial, [])
extra = ChatMessage(role=MessageRole.TOOL, content="result")
result = _update_tool_call_context(mem, request_id, None, [extra])
assert len(result) == 2
assert len(mem.get(_TOOL_CALL_CONTEXT)[str(request_id)]) == 2


def test_output_schema_rowtypeinfo_round_trip():
mem = _memory()
event_id = uuid4()
schema = OutputSchema(
output_schema=RowTypeInfo(
[BasicTypeInfo.INT_TYPE_INFO()],
["result"],
)
)
_save_tool_request_event_context(
mem, event_id, uuid4(), "ollama", None, schema
)
context = _get_tool_request_event_context(mem, event_id)
assert isinstance(context["output_schema"], OutputSchema)
assert context["output_schema"].output_schema.get_field_names() == ["result"]


def test_save_get_preserves_model_and_prompt_args():
mem = _memory()
event_id = uuid4()
prompt_args = {"a": 1, "b": "x"}
_save_tool_request_event_context(
mem, event_id, uuid4(), "ollama", prompt_args, None
)
context = _get_tool_request_event_context(mem, event_id)
assert context["model"] == "ollama"
assert context["prompt_args"] == prompt_args
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def mock_chat(messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage:
"_TOOL_REQUEST_EVENT_CONTEXT",
{
str(tool_request_event_id): {
"initial_request_id": initial_request_id,
"initial_request_id": str(initial_request_id),
"model": "test-model",
"prompt_args": saved_prompt_args,
"output_schema": None,
Expand All @@ -325,7 +325,9 @@ def mock_chat(messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage:
"_TOOL_CALL_CONTEXT",
{
str(initial_request_id): [
ChatMessage(role=MessageRole.USER, content="hi")
ChatMessage(
role=MessageRole.USER, content="hi"
).model_dump(mode="json")
]
},
)
Expand Down
Loading