diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py index d0a17875875e..ce7b5e032d64 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/config/__init__.py @@ -1,4 +1,4 @@ -from typing import Awaitable, Callable, Dict, List, Literal, Optional, Union +from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Union from autogen_core import ComponentModel from autogen_core.models import ModelCapabilities, ModelInfo # type: ignore @@ -57,6 +57,9 @@ class CreateArguments(TypedDict, total=False): - 'low': Faster responses with less reasoning - 'medium': Balanced reasoning and speed - 'high': More thorough reasoning, may take longer""" + extra_body: Optional[Dict[str, Any]] + """Extra JSON body fields to include in the API request. + Useful for provider-specific parameters (e.g., ``enable_thinking``).""" AsyncAzureADTokenProvider = Callable[[], Union[str, Awaitable[str]]] @@ -108,6 +111,8 @@ class CreateArgumentsConfigModel(BaseModel): parallel_tool_calls: bool | None = None # Controls the amount of effort the model uses for reasoning (reasoning models only) reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None + # Extra JSON body fields to include in the API request (e.g., enable_thinking) + extra_body: Dict[str, Any] | None = None class BaseOpenAIClientConfigurationConfigModel(CreateArgumentsConfigModel): diff --git a/python/packages/autogen-ext/tests/models/test_openai_model_client.py b/python/packages/autogen-ext/tests/models/test_openai_model_client.py index ba79795d1ed7..a3c888eaed70 100644 --- a/python/packages/autogen-ext/tests/models/test_openai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_openai_model_client.py @@ -3378,3 +3378,50 @@ async def test_reasoning_effort_validation() -> None: } ChatCompletionClient.load_component(config) + + +@pytest.mark.asyncio +async def test_extra_body_load_component() -> None: + """Test that extra_body survives load_component() round-trip (issue #7418).""" + from autogen_core.models import ChatCompletionClient + + # Direct instantiation should work (baseline) + client = OpenAIChatCompletionClient( + model="gpt-4o", + api_key="fake_key", + extra_body={"enable_thinking": False}, + ) + assert client._create_args["extra_body"] == {"enable_thinking": False} # pyright: ignore[reportPrivateUsage] + + # load_component should preserve extra_body + openai_config = { + "provider": "OpenAIChatCompletionClient", + "config": { + "model": "gpt-4o", + "api_key": "fake_key", + "extra_body": {"enable_thinking": False}, + }, + } + loaded_client = ChatCompletionClient.load_component(openai_config) + assert loaded_client._create_args["extra_body"] == {"enable_thinking": False} # type: ignore[attr-defined] # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAttributeAccessIssue] + assert loaded_client._raw_config["extra_body"] == {"enable_thinking": False} # type: ignore[attr-defined] # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAttributeAccessIssue] + + # dump_component -> load_component round-trip + config_dict = client.dump_component() + reloaded_client = OpenAIChatCompletionClient.load_component(config_dict) + assert reloaded_client._create_args["extra_body"] == {"enable_thinking": False} # pyright: ignore[reportPrivateUsage] + + # Azure variant + azure_config = { + "provider": "AzureOpenAIChatCompletionClient", + "config": { + "model": "gpt-4o", + "azure_endpoint": "https://fake.openai.azure.com", + "azure_deployment": "gpt-4o", + "api_version": "2024-06-01", + "api_key": "fake_key", + "extra_body": {"enable_thinking": False}, + }, + } + loaded_azure = ChatCompletionClient.load_component(azure_config) + assert loaded_azure._create_args["extra_body"] == {"enable_thinking": False} # type: ignore[attr-defined] # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportAttributeAccessIssue]