diff --git a/changelog/1001.fixed.md b/changelog/1001.fixed.md new file mode 100644 index 00000000..bb9c57f4 --- /dev/null +++ b/changelog/1001.fixed.md @@ -0,0 +1 @@ +Fix `InfrahubBatch.execute()` orphaning in-flight tasks when one task raises with `return_exceptions=False`. diff --git a/infrahub_sdk/batch.py b/infrahub_sdk/batch.py index 6e9cc1cb..5166ef86 100644 --- a/infrahub_sdk/batch.py +++ b/infrahub_sdk/batch.py @@ -80,11 +80,20 @@ async def execute(self) -> AsyncGenerator: for batch_task in self._tasks ] - for completed_task in asyncio.as_completed(tasks): - node, result = await completed_task - if isinstance(result, Exception) and not self.return_exceptions: - raise result - yield node, result + try: + for completed_task in asyncio.as_completed(tasks): + node, result = await completed_task + if isinstance(result, Exception) and not self.return_exceptions: + raise result + yield node, result + finally: + # Ensure no task created here outlives execute(). Cancel any still + # running, then drain so their exceptions are retrieved instead of + # surfacing later as "Task exception was never retrieved". + for t in tasks: + if not t.done(): + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) class InfrahubBatchSync: diff --git a/tests/unit/sdk/test_batch.py b/tests/unit/sdk/test_batch.py index 7bdf00ad..85ff445e 100644 --- a/tests/unit/sdk/test_batch.py +++ b/tests/unit/sdk/test_batch.py @@ -1,9 +1,11 @@ from __future__ import annotations +import asyncio from typing import TYPE_CHECKING import pytest +from infrahub_sdk.batch import InfrahubBatch from infrahub_sdk.exceptions import GraphQLError if TYPE_CHECKING: @@ -125,3 +127,91 @@ async def test_batch_exception( for _, _ in batch.execute(): pass assert "An error occurred while executing the GraphQL Query" in str(exc.value) + + +async def test_execute_does_not_orphan_inflight_tasks_when_raising() -> None: + """When one batch task raises and return_exceptions=False, sibling tasks + that are still in flight must not be left running. If they are orphaned, + their work-in-progress side effects continue after the caller has been + told the batch failed, and unretrieved exceptions surface later as + "Task exception was never retrieved" in the asyncio log. + """ + side_effects: list[str] = [] + + async def raise_fast() -> None: + # Yield once so siblings get scheduled, then fail before they finish. + await asyncio.sleep(0) + raise RuntimeError("fast failure") + + async def slow_side_effect(name: str) -> str: + await asyncio.sleep(0.1) + side_effects.append(name) + return name + + batch = InfrahubBatch(max_concurrent_execution=10, return_exceptions=False) + batch.add(task=raise_fast) + for i in range(5): + batch.add(task=slow_side_effect, name=f"slow-{i}") + + with pytest.raises(RuntimeError, match="fast failure"): + async for _ in batch.execute(): + pass + + # Wait long enough that any orphan would have completed. + await asyncio.sleep(0.3) + + assert side_effects == [], ( + f"sibling tasks were orphaned and ran to completion after execute() raised: {side_effects}" + ) + + +async def test_return_exceptions_yields_exceptions_indistinguishably_from_successes() -> None: + """Pins down the current contract of ``execute()`` with ``return_exceptions=True``. + + Failures are yielded as ``(node, ExceptionInstance)`` using the same tuple + shape as successes ``(node, result)``. The yielded ``node`` is whatever the + caller passed via ``batch.add(..., node=...)`` regardless of outcome, so a + consumer that does not ``isinstance``-check ``result`` cannot tell a failed + task from a successful one and will silently treat both as "created". + + This test is expected to change when the API shape is reworked (e.g., a + ``BatchResult`` dataclass with separate ``result``/``exception`` fields, or + a split API where successes and failures are surfaced on different paths). + """ + sentinel_a = object() + sentinel_b = object() + + async def succeed() -> str: + return "ok" + + async def fail() -> None: + raise RuntimeError("boom") + + batch = InfrahubBatch(max_concurrent_execution=10, return_exceptions=True) + batch.add(task=succeed, node=sentinel_a) + batch.add(task=fail, node=sentinel_b) + + yielded: list[tuple[object, object]] = [] + async for node, result in batch.execute(): + yielded.append((node, result)) + + by_node = {id(n): r for n, r in yielded} + + # Both tasks yield, and the tuple shape is identical. + assert len(yielded) == 2 + assert {id(sentinel_a), id(sentinel_b)} == set(by_node.keys()) + + # Successful yield: result is the task's return value. + assert by_node[id(sentinel_a)] == "ok" + + # Failed yield: result is the exception instance, in the same slot. The + # only way to distinguish failure from success is an isinstance check on + # the result. The node slot is unchanged from what the caller supplied. + failed_result = by_node[id(sentinel_b)] + assert isinstance(failed_result, RuntimeError) + assert str(failed_result) == "boom" + + # Demonstrate the silent-data-loss pitfall: a naive caller that records + # ``node`` per yield treats the failed task as if it succeeded. + naive_created = [n for n, _ in yielded] + assert sentinel_b in naive_created # node retained despite the underlying failure