13 changes: 6 additions & 7 deletions py_src/taskito/workflows/mixins.py
@@ -4,8 +4,10 @@
 
 from typing import TYPE_CHECKING, Any
 
-from .builder import Workflow, WorkflowProxy
-from .run import WorkflowRun
+from taskito.workflows.builder import Workflow, WorkflowProxy
+from taskito.workflows.incremental import compute_dirty_set
+from taskito.workflows.run import WorkflowRun
+from taskito.workflows.tracker.dag import build_dag_maps
 
 if TYPE_CHECKING:
     from collections.abc import Callable
@@ -55,12 +57,9 @@ def submit_workflow(
         # Compute cache hits for incremental runs.
         cache_hit_nodes: dict[str, str] | None = None
         if incremental and base_run:
-            from .incremental import compute_dirty_set
-            from .tracker import _build_dag_maps
-
             base_nodes = self._inner.get_base_run_node_data(base_run)
-            _, preds = _build_dag_maps(dag_bytes)
-            succs, _ = _build_dag_maps(dag_bytes)
+            _, preds = build_dag_maps(dag_bytes)
+            succs, _ = build_dag_maps(dag_bytes)
 
             # Get base run completion time for TTL check.
             base_status = self._inner.get_workflow_run_status(base_run)
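
For orientation, the helper that replaces the old `_build_dag_maps` is a pure function: it parses the serialized DAG and returns both adjacency maps in one call, so the two unpacking lines above could in principle share a single call. A minimal sketch, assuming only the JSON shape that dag.py parses below (nodes carrying a "name", edges carrying "from"/"to"); the three-node chain is hypothetical:

from taskito.workflows.tracker.dag import build_dag_maps

# Hypothetical chain a -> b -> c in the shape build_dag_maps expects.
dag_bytes = (
    b'{"nodes": [{"name": "a"}, {"name": "b"}, {"name": "c"}], '
    b'"edges": [{"from": "a", "to": "b"}, {"from": "b", "to": "c"}]}'
)

succs, preds = build_dag_maps(dag_bytes)
assert succs == {"a": ["b"], "b": ["c"], "c": []}
assert preds == {"a": [], "b": ["a"], "c": ["b"]}
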
875 changes: 0 additions & 875 deletions py_src/taskito/workflows/tracker.py

This file was deleted.

5 changes: 5 additions & 0 deletions py_src/taskito/workflows/tracker/__init__.py
@@ -0,0 +1,5 @@
"""Workflow tracker package — orchestrates dynamic node execution."""

from taskito.workflows.tracker.tracker import WorkflowTracker

__all__ = ["WorkflowTracker"]
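
A small compatibility note on the re-export above: because the package `__init__` re-exports `WorkflowTracker`, the public import path survives the module-to-package split, while the relocated helpers are reached through the new submodule. A sketch of both paths, assuming no other public names:

# Unchanged public path: resolves via the package __init__ re-export.
from taskito.workflows.tracker import WorkflowTracker

# New internal path for the relocated DAG helpers.
from taskito.workflows.tracker.dag import build_dag_maps
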
188 changes: 188 additions & 0 deletions py_src/taskito/workflows/tracker/dag.py
@@ -0,0 +1,188 @@
"""Pure DAG helpers and node-action implementations.

Functions here are leaf helpers — they accept the tracker as the first
argument so they can read state and call back into tracker methods, but
they never import other tracker submodules. The tracker class wires them
together in its orchestrator methods.
"""

from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING, Any

from taskito.events import EventType
from taskito.workflows.context import WorkflowContext

if TYPE_CHECKING:
from taskito.workflows.tracker.tracker import WorkflowTracker
from taskito.workflows.tracker.types import _RunConfig


logger = logging.getLogger("taskito.workflows")


def build_dag_maps(
dag_bytes: bytes | list[int],
) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
"""Parse DAG JSON to build successor and predecessor maps."""
raw = bytes(dag_bytes) if isinstance(dag_bytes, list) else dag_bytes
dag_json = json.loads(raw)
successors: dict[str, list[str]] = {}
predecessors: dict[str, list[str]] = {}
for node in dag_json.get("nodes", []):
name = node["name"]
successors.setdefault(name, [])
predecessors.setdefault(name, [])
for edge in dag_json.get("edges", []):
successors.setdefault(edge["from"], []).append(edge["to"])
predecessors.setdefault(edge["to"], []).append(edge["from"])
return successors, predecessors


def int_or(value: Any, default: int) -> int:
return value if value is not None else default


def final_state_to_event(state: str) -> EventType | None:
if state == "completed":
return EventType.WORKFLOW_COMPLETED
if state == "failed":
return EventType.WORKFLOW_FAILED
if state == "cancelled":
return EventType.WORKFLOW_CANCELLED
return None


def all_predecessors_terminal(
tracker: WorkflowTracker, run_id: str, node_name: str, config: _RunConfig
) -> bool:
"""Check whether all predecessors have a terminal status."""
raw = tracker._queue._inner.get_workflow_run_status(run_id)
node_statuses = raw.node_statuses()
for pred in config.predecessors.get(node_name, []):
info = node_statuses.get(pred)
if info is None:
return False
status = info["status"]
if status not in ("completed", "failed", "skipped"):
return False
return True


def get_predecessor_statuses(
tracker: WorkflowTracker, run_id: str, node_name: str, config: _RunConfig
) -> dict[str, str]:
"""Return ``{pred_name: status_str}`` for all predecessors."""
raw = tracker._queue._inner.get_workflow_run_status(run_id)
node_statuses = raw.node_statuses()
result: dict[str, str] = {}
for pred in config.predecessors.get(node_name, []):
info = node_statuses.get(pred)
result[pred] = info["status"] if info else "pending"
return result


def build_workflow_context(
tracker: WorkflowTracker, run_id: str, config: _RunConfig
) -> WorkflowContext:
"""Build a :class:`WorkflowContext` from the current run state."""
raw = tracker._queue._inner.get_workflow_run_status(run_id)
node_statuses = raw.node_statuses()

results: dict[str, Any] = {}
statuses: dict[str, str] = {}
failure_count = 0
success_count = 0

for name, info in node_statuses.items():
status = info["status"]
statuses[name] = status
if status == "completed":
success_count += 1
jid = info.get("job_id")
if jid:
results[name] = tracker._fetch_result(jid)
elif status == "failed":
failure_count += 1

return WorkflowContext(
run_id=run_id,
results=results,
statuses=statuses,
params=None,
failure_count=failure_count,
success_count=success_count,
)


def should_execute(
tracker: WorkflowTracker, run_id: str, node_name: str, config: _RunConfig
) -> bool:
"""Decide whether a deferred node should execute based on its condition."""
callable_cond = config.callable_conditions.get(node_name)
if callable_cond is not None:
ctx = build_workflow_context(tracker, run_id, config)
try:
return bool(callable_cond(ctx))
except Exception:
logger.exception("callable condition failed for %s", node_name)
return False

meta = config.step_metadata.get(node_name, {})
condition = meta.get("condition")
pred_statuses = get_predecessor_statuses(tracker, run_id, node_name, config)

if condition is None or condition == "on_success":
return all(s == "completed" for s in pred_statuses.values())
if condition == "on_failure":
return any(s == "failed" for s in pred_statuses.values())
return bool(condition == "always")


def skip_and_propagate(
tracker: WorkflowTracker, run_id: str, node_name: str, config: _RunConfig
) -> None:
"""Mark a node as SKIPPED and recursively evaluate its successors."""
try:
tracker._queue._inner.skip_workflow_node(run_id, node_name)
except (RuntimeError, ValueError):
logger.exception("skip_workflow_node failed for %s", node_name)
return
config.deferred_nodes.discard(node_name)
tracker._evaluate_successors(run_id, node_name, config)


def create_deferred_job_for_node(
tracker: WorkflowTracker, run_id: str, node_name: str, config: _RunConfig
) -> None:
"""Create a job for a deferred node and record the mapping."""
payload = config.deferred_payloads.get(node_name)
if payload is None: # pragma: no cover
logger.error("no cached payload for deferred node %s", node_name)
return
meta = config.step_metadata.get(node_name, {})
task_name = meta["task_name"]
queue_name = meta.get("queue") or "default"
max_retries = int_or(meta.get("max_retries"), 3)
timeout_ms = int_or(meta.get("timeout_ms"), 300_000)
priority = int_or(meta.get("priority"), 0)

try:
job_id = tracker._queue._inner.create_deferred_job(
run_id,
node_name,
payload,
task_name,
queue_name,
max_retries,
timeout_ms,
priority,
)
except (RuntimeError, ValueError):
logger.exception("create_deferred_job failed for %s", node_name)
return
with tracker._state_lock:
tracker._job_to_run[job_id] = run_id
config.deferred_nodes.discard(node_name)
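
The branchy middle of should_execute reduces to three predicates over the predecessor statuses. A condensed restatement, with a hypothetical status snapshot standing in for what get_predecessor_statuses returns; the node names are illustrative only:

# Hypothetical snapshot: one predecessor succeeded, one failed.
pred_statuses = {"extract": "completed", "validate": "failed"}

# condition is None or "on_success": every predecessor must have completed.
on_success = all(s == "completed" for s in pred_statuses.values())  # False

# condition == "on_failure": at least one predecessor must have failed.
on_failure = any(s == "failed" for s in pred_statuses.values())  # True

# condition == "always": runs regardless of outcomes. Any other string
# falls through to bool(condition == "always") and evaluates False, so an
# unrecognized condition means the node does not execute.

When the decision comes out False, skip_and_propagate is presumably the counterpart the tracker's orchestrator invokes: it marks the node SKIPPED and re-evaluates successors, so downstream conditions are checked against the updated run state.
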