Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/1001.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix `InfrahubBatch.execute()` orphaning in-flight tasks when one task raises with `return_exceptions=False`.
19 changes: 14 additions & 5 deletions infrahub_sdk/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,20 @@ async def execute(self) -> AsyncGenerator:
for batch_task in self._tasks
]

for completed_task in asyncio.as_completed(tasks):
node, result = await completed_task
if isinstance(result, Exception) and not self.return_exceptions:
raise result
yield node, result
try:
for completed_task in asyncio.as_completed(tasks):
node, result = await completed_task
if isinstance(result, Exception) and not self.return_exceptions:
raise result
yield node, result
finally:
# Ensure no task created here outlives execute(). Cancel any still
# running, then drain so their exceptions are retrieved instead of
# surfacing later as "Task exception was never retrieved".
for t in tasks:
if not t.done():
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)


class InfrahubBatchSync:
Expand Down
90 changes: 90 additions & 0 deletions tests/unit/sdk/test_batch.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

import pytest

from infrahub_sdk.batch import InfrahubBatch
from infrahub_sdk.exceptions import GraphQLError

if TYPE_CHECKING:
Expand Down Expand Up @@ -125,3 +127,91 @@ async def test_batch_exception(
for _, _ in batch.execute():
pass
assert "An error occurred while executing the GraphQL Query" in str(exc.value)


async def test_execute_does_not_orphan_inflight_tasks_when_raising() -> None:
    """A failing task with ``return_exceptions=False`` must not leave siblings running.

    If in-flight sibling tasks are orphaned when ``execute()`` raises, their
    side effects keep landing after the caller has been told the batch failed,
    and their unretrieved exceptions later surface in the asyncio log as
    "Task exception was never retrieved".
    """
    finished: list[str] = []

    async def raise_fast() -> None:
        # One checkpoint so the slow siblings get scheduled, then fail long
        # before any of them can complete.
        await asyncio.sleep(0)
        raise RuntimeError("fast failure")

    async def slow_side_effect(name: str) -> str:
        await asyncio.sleep(0.1)
        finished.append(name)
        return name

    batch = InfrahubBatch(max_concurrent_execution=10, return_exceptions=False)
    batch.add(task=raise_fast)
    for idx in range(5):
        batch.add(task=slow_side_effect, name=f"slow-{idx}")

    with pytest.raises(RuntimeError, match="fast failure"):
        async for _ in batch.execute():
            pass

    # Give any orphaned sibling more than enough time to run to completion.
    await asyncio.sleep(0.3)

    assert finished == [], (
        f"sibling tasks were orphaned and ran to completion after execute() raised: {finished}"
    )


async def test_return_exceptions_yields_exceptions_indistinguishably_from_successes() -> None:
    """Pins down the current ``execute()`` contract when ``return_exceptions=True``.

    A failure is yielded as ``(node, ExceptionInstance)`` in exactly the same
    tuple shape as a success ``(node, result)``. The ``node`` slot always holds
    whatever the caller passed via ``batch.add(..., node=...)``, success or not,
    so a consumer that never ``isinstance``-checks ``result`` cannot tell a
    failed task apart from a successful one and silently counts both as
    "created".

    Expected to change when the API shape is reworked (e.g. a ``BatchResult``
    dataclass with distinct ``result``/``exception`` fields, or a split API
    that surfaces successes and failures on separate paths).
    """
    marker_ok = object()
    marker_fail = object()

    async def succeed() -> str:
        return "ok"

    async def fail() -> None:
        raise RuntimeError("boom")

    batch = InfrahubBatch(max_concurrent_execution=10, return_exceptions=True)
    batch.add(task=succeed, node=marker_ok)
    batch.add(task=fail, node=marker_fail)

    emitted: list[tuple[object, object]] = [
        (node, result) async for node, result in batch.execute()
    ]

    results_by_node = {id(node): result for node, result in emitted}

    # Both tasks yield, with an identical tuple shape.
    assert len(emitted) == 2
    assert set(results_by_node) == {id(marker_ok), id(marker_fail)}

    # Success path: result is the task's return value.
    assert results_by_node[id(marker_ok)] == "ok"

    # Failure path: result is the exception instance, in the same slot. An
    # isinstance check on the result is the only way to distinguish the two;
    # the node slot is unchanged from what the caller supplied.
    failure = results_by_node[id(marker_fail)]
    assert isinstance(failure, RuntimeError)
    assert str(failure) == "boom"

    # Demonstrate the silent-data-loss pitfall: a naive caller recording
    # ``node`` per yield treats the failed task as if it succeeded.
    naive_created = [node for node, _ in emitted]
    assert marker_fail in naive_created  # node retained despite the underlying failure