diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index fe6b1c351d..ff5a7e7314 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -97,7 +97,11 @@ async def _init_or_reload_subagent_orchestrator(self) -> None: except Exception as e: logger.error(f"Subagent orchestrator init failed: {e}", exc_info=True) - async def initialize(self) -> None: + async def initialize( + self, + *, + mcp_init_timeout: float | int | str | None = None, + ) -> None: """初始化 AstrBot 核心生命周期管理类. 负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。 @@ -201,7 +205,7 @@ async def initialize(self) -> None: await self.plugin_manager.reload() # 根据配置实例化各个 Provider - await self.provider_manager.initialize() + await self.provider_manager.initialize(init_timeout=mcp_init_timeout) await self.kb_manager.initialize() diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 22e9a0766c..4239a8b470 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -347,7 +347,10 @@ def _log_safe_mcp_debug_config(cfg: dict) -> None: logger.debug(f" 主机: {scheme}://{host}{port}") async def init_mcp_clients( - self, raise_on_all_failed: bool = False + self, + raise_on_all_failed: bool = False, + *, + init_timeout: float | int | str | None = None, ) -> MCPInitSummary: """从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下: ``` @@ -368,6 +371,7 @@ async def init_mcp_clients( ``` Timeout behavior: + - 显式 `init_timeout` 参数优先(用于测试或调用方覆盖)。 - 初始化超时使用环境变量 ASTRBOT_MCP_INIT_TIMEOUT 或默认值。 - 动态启用超时使用 ASTRBOT_MCP_ENABLE_TIMEOUT(独立于初始化超时)。 """ @@ -393,8 +397,12 @@ async def init_mcp_clients( "mcpServers" ] - init_timeout = self._init_timeout_default - timeout_display = f"{init_timeout:g}" + init_timeout_value = _resolve_timeout( + timeout=init_timeout, + env_name=MCP_INIT_TIMEOUT_ENV, + default=self._init_timeout_default, + ) + timeout_display = f"{init_timeout_value:g}" active_configs: list[tuple[str, dict, asyncio.Event]] = [] for name, cfg in mcp_server_json_obj.items(): @@ -413,7 +421,7 @@ async def init_mcp_clients( name=name, cfg=cfg, shutdown_event=shutdown_event, - timeout=init_timeout, + timeout_seconds=init_timeout_value, ), name=f"mcp-init:{name}", ) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index ef89027e8b..ccdb1f36b3 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -269,7 +269,11 @@ def get_using_provider( return provider - async def initialize(self) -> None: + async def initialize( + self, + *, + init_timeout: float | int | str | None = None, + ) -> None: # 逐个初始化提供商 for provider_config in self.providers_config: try: @@ -338,7 +342,8 @@ async def initialize(self) -> None: "on", } mcp_init_summary = await self.llm_tools.init_mcp_clients( - raise_on_all_failed=strict_mcp_init + raise_on_all_failed=strict_mcp_init, + init_timeout=init_timeout, ) if ( mcp_init_summary.total > 0 diff --git a/tests/unit/test_core_lifecycle.py b/tests/unit/test_core_lifecycle.py index fc8300bf96..52cac816aa 100644 --- a/tests/unit/test_core_lifecycle.py +++ b/tests/unit/test_core_lifecycle.py @@ -373,7 +373,7 @@ async def test_initialize_sets_up_all_components( new_callable=AsyncMock, ), ): - await lifecycle.initialize() + await lifecycle.initialize(mcp_init_timeout=3.5) # Verify database initialized mock_db.initialize.assert_awaited_once() @@ -388,7 +388,7 @@ async def test_initialize_sets_up_all_components( mock_persona_mgr.initialize.assert_awaited_once() # Verify provider manager initialized - mock_provider_manager.initialize.assert_awaited_once() + mock_provider_manager.initialize.assert_awaited_once_with(init_timeout=3.5) # Verify platform manager initialized mock_platform_manager.initialize.assert_awaited_once() diff --git a/tests/unit/test_func_tool_manager.py b/tests/unit/test_func_tool_manager.py new file mode 100644 index 0000000000..8151416c28 --- /dev/null +++ b/tests/unit/test_func_tool_manager.py @@ -0,0 +1,98 @@ +import json + +import pytest + +from astrbot.core.provider import func_tool_manager +from astrbot.core.provider.func_tool_manager import FunctionToolManager + + +@pytest.fixture +def mcp_init_harness( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +): + manager = FunctionToolManager() + data_dir = tmp_path / "data" + data_dir.mkdir() + + (data_dir / "mcp_server.json").write_text( + json.dumps({"mcpServers": {"demo": {"active": True}}}), + encoding="utf-8", + ) + monkeypatch.setattr( + func_tool_manager, + "get_astrbot_data_path", + lambda: data_dir, + ) + + called = {} + + async def fake_start_mcp_server(*, name, cfg, shutdown_event, timeout_seconds): + called[name] = { + "cfg": cfg, + "shutdown_event_type": type(shutdown_event).__name__, + "timeout_seconds": timeout_seconds, + } + + monkeypatch.setattr(manager, "_start_mcp_server", fake_start_mcp_server) + return manager, called + + +def assert_demo_init_result(summary, called, *, timeout_seconds: float) -> None: + assert summary.total == 1 + assert summary.success == 1 + assert summary.failed == [] + assert called["demo"]["cfg"] == {"active": True} + assert called["demo"]["shutdown_event_type"] == "Event" + assert called["demo"]["timeout_seconds"] == timeout_seconds + + +@pytest.mark.asyncio +async def test_init_mcp_clients_passes_timeout_seconds_keyword(mcp_init_harness): + manager, called = mcp_init_harness + + summary = await manager.init_mcp_clients() + + assert_demo_init_result( + summary, + called, + timeout_seconds=manager._init_timeout_default, + ) + + +@pytest.mark.asyncio +async def test_init_mcp_clients_passes_overridden_init_timeout( + mcp_init_harness, +): + manager, called = mcp_init_harness + + summary = await manager.init_mcp_clients(init_timeout=3.5) + + assert_demo_init_result(summary, called, timeout_seconds=3.5) + + +@pytest.mark.asyncio +async def test_init_mcp_clients_reads_env_timeout_when_not_overridden( + mcp_init_harness, + monkeypatch: pytest.MonkeyPatch, +): + manager, called = mcp_init_harness + manager._init_timeout_default = 20.0 # ensure env override is observable + monkeypatch.setenv("ASTRBOT_MCP_INIT_TIMEOUT", "3.5") + + summary = await manager.init_mcp_clients() + + assert_demo_init_result(summary, called, timeout_seconds=3.5) + + +@pytest.mark.asyncio +async def test_init_mcp_clients_prefers_explicit_timeout_over_env( + mcp_init_harness, + monkeypatch: pytest.MonkeyPatch, +): + manager, called = mcp_init_harness + monkeypatch.setenv("ASTRBOT_MCP_INIT_TIMEOUT", "7.0") + + summary = await manager.init_mcp_clients(init_timeout=3.5) + + assert_demo_init_result(summary, called, timeout_seconds=3.5)