diff --git a/.semversioner/next-release/patch-20260312083316079043.json b/.semversioner/next-release/patch-20260312083316079043.json new file mode 100644 index 000000000..cf1919fcf --- /dev/null +++ b/.semversioner/next-release/patch-20260312083316079043.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Fix structured response parsing when providers return tool calls with empty message content" +} diff --git a/packages/graphrag-llm/graphrag_llm/types/types.py b/packages/graphrag-llm/graphrag_llm/types/types.py index 0980cba3a..be948e591 100644 --- a/packages/graphrag-llm/graphrag_llm/types/types.py +++ b/packages/graphrag-llm/graphrag_llm/types/types.py @@ -97,8 +97,27 @@ class LLMCompletionResponse(ChatCompletion, Generic[ResponseFormat]): @computed_field @property def content(self) -> str: - """Get the content of the first choice message.""" - return self.choices[0].message.content or "" + """Get the content of the first choice message. + + Falls back to function tool-call arguments when `content` is empty. + Some providers return structured payloads via tool calls instead of + assistant message content. + """ + if not self.choices: + return "" + + message = self.choices[0].message + if message.content: + return message.content + + if message.tool_calls: + for tool_call in message.tool_calls: + if tool_call.type != "function": + continue + if tool_call.function.arguments: + return tool_call.function.arguments + + return "" class LLMCompletionArgs( diff --git a/packages/graphrag-llm/graphrag_llm/utils/gather_completion_response.py b/packages/graphrag-llm/graphrag_llm/utils/gather_completion_response.py index 0722e95ef..040c55147 100644 --- a/packages/graphrag-llm/graphrag_llm/utils/gather_completion_response.py +++ b/packages/graphrag-llm/graphrag_llm/utils/gather_completion_response.py @@ -30,7 +30,7 @@ def gather_completion_response( if isinstance(response, Iterator): return "".join(chunk.choices[0].delta.content or "" for chunk in response) - return response.choices[0].message.content or "" + return response.content async def gather_completion_response_async( @@ -54,4 +54,4 @@ async def gather_completion_response_async( return gathered_content - return response.choices[0].message.content or "" + return response.content diff --git a/tests/unit/utils/test_completion_response.py b/tests/unit/utils/test_completion_response.py new file mode 100644 index 000000000..fbc05cb2c --- /dev/null +++ b/tests/unit/utils/test_completion_response.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +import asyncio + +from graphrag_llm.types import LLMCompletionResponse +from graphrag_llm.utils import ( + gather_completion_response, + gather_completion_response_async, + structure_completion_response, +) +from pydantic import BaseModel + + +class RatingResponse(BaseModel): + rating: int + + +def _create_completion_response( + *, + content: str | None, + tool_call_arguments: str | None = None, +) -> LLMCompletionResponse: + message: dict = { + "role": "assistant", + "content": content, + } + + if tool_call_arguments is not None: + message["tool_calls"] = [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "structured_output", + "arguments": tool_call_arguments, + }, + } + ] + + return LLMCompletionResponse( + id="completion-id", + object="chat.completion", + created=0, + model="mock-model", + choices=[ + { + "index": 0, + "message": message, + "finish_reason": "stop", + } + ], + usage={ + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + }, + formatted_response=None, + ) + + +def test_content_prefers_message_content() -> None: + response = _create_completion_response( + content="plain text", + tool_call_arguments='{"rating": 9}', + ) + + assert response.content == "plain text" + + +def test_content_falls_back_to_function_tool_call_arguments() -> None: + response = _create_completion_response( + content=None, + tool_call_arguments='{"rating": 7}', + ) + + assert response.content == '{"rating": 7}' + + +def test_gather_completion_response_falls_back_to_tool_call_arguments() -> None: + response = _create_completion_response( + content=None, + tool_call_arguments='{"rating": 3}', + ) + + assert gather_completion_response(response) == '{"rating": 3}' + + +def test_gather_completion_response_async_falls_back_to_tool_call_arguments() -> None: + response = _create_completion_response( + content=None, + tool_call_arguments='{"rating": 5}', + ) + + gathered_response = asyncio.run(gather_completion_response_async(response)) + assert gathered_response == '{"rating": 5}' + + +def test_structure_completion_response_uses_tool_call_arguments() -> None: + response = _create_completion_response( + content=None, + tool_call_arguments='{"rating": 11}', + ) + + parsed = structure_completion_response(response.content, RatingResponse) + assert parsed.rating == 11