From b330186b97b734efcc5608d114fc007358ebac64 Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 18 May 2026 15:47:34 -0500 Subject: [PATCH 01/15] init --- .../providers/databricks/exceptions.py | 16 + .../databricks/operators/databricks.py | 241 +++++++++++- .../operators/databricks_workflow.py | 270 ++++++++++++- .../databricks/triggers/databricks.py | 359 +++++++++++++++++- .../providers/databricks/utils/databricks.py | 49 +++ .../databricks/operators/test_databricks.py | 184 +++++++++ .../operators/test_databricks_workflow.py | 170 ++++++++- .../databricks/triggers/test_databricks.py | 276 ++++++++++++++ .../unit/databricks/utils/test_databricks.py | 38 ++ 9 files changed, 1578 insertions(+), 25 deletions(-) diff --git a/providers/databricks/src/airflow/providers/databricks/exceptions.py b/providers/databricks/src/airflow/providers/databricks/exceptions.py index f384552a34a6e..424ee8b7e45b3 100644 --- a/providers/databricks/src/airflow/providers/databricks/exceptions.py +++ b/providers/databricks/src/airflow/providers/databricks/exceptions.py @@ -30,3 +30,19 @@ class DatabricksSqlExecutionError(AirflowException): class DatabricksSqlExecutionTimeout(DatabricksSqlExecutionError): """Raised when a sql execution times out.""" + + +class DatabricksWorkflowRepairError(AirflowException): + """Raised when Databricks Workflow repair coordination fails.""" + + +class DatabricksWorkflowRepairMetadataError(DatabricksWorkflowRepairError): + """Raised when workflow repair metadata is missing or invalid.""" + + +class DatabricksWorkflowRepairBudgetExhausted(DatabricksWorkflowRepairError): + """Raised when a Databricks Workflow run fails after exhausting repair attempts.""" + + +class DatabricksWorkflowRepairTriggerError(DatabricksWorkflowRepairError): + """Raised when a Databricks Workflow repair trigger emits an invalid event.""" diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py 
b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index 9898993d4147e..4bdb5702b9c00 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -27,6 +27,10 @@ from typing import TYPE_CHECKING, Any from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, BaseOperatorLink, XCom, conf +from airflow.providers.databricks.exceptions import ( + DatabricksWorkflowRepairBudgetExhausted, + DatabricksWorkflowRepairTriggerError, +) from airflow.providers.databricks.hooks.databricks import ( DatabricksHook, RunLifeCycleState, @@ -43,9 +47,11 @@ ) from airflow.providers.databricks.triggers.databricks import ( DatabricksExecutionTrigger, + DatabricksWorkflowRepairWaitTrigger, ) from airflow.providers.databricks.utils.databricks import ( extract_failed_task_errors, + find_new_workflow_task_attempt, normalise_json_content, validate_trigger_event, ) @@ -1491,19 +1497,64 @@ def monitor_databricks_job(self) -> None: self.databricks_task_key, run_state.life_cycle_state, ) - if self.deferrable and not run_state.is_terminal: - self.defer( - trigger=DatabricksExecutionTrigger( - run_id=current_task_run_id, - databricks_conn_id=self.databricks_conn_id, - polling_period_seconds=self.polling_period_seconds, - retry_limit=self.databricks_retry_limit, - retry_delay=self.databricks_retry_delay, - retry_args=self.databricks_retry_args, - caller=self.caller, - ), - method_name=DEFER_METHOD_NAME, - ) + if self.deferrable: + if not run_state.is_terminal: + self.defer( + trigger=DatabricksExecutionTrigger( + run_id=current_task_run_id, + databricks_conn_id=self.databricks_conn_id, + polling_period_seconds=self.polling_period_seconds, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + caller=self.caller, + ), + method_name=DEFER_METHOD_NAME, + ) + elif 
not run_state.is_successful: + tg = self._workflow_task_group_with_repair() + if tg is not None: + self._defer_to_workflow_repair_wait( + original_sub_run_id=current_task_run_id, + original_start_time=run.get("start_time"), + tg=tg, + ) + else: + tg = self._workflow_task_group_with_repair() + if tg is not None: + while True: + while not run_state.is_terminal: + time.sleep(self.polling_period_seconds) + run = self._hook.get_run(current_task_run_id) + run_state = RunState(**run["state"]) + self.log.info( + "Current state of the databricks task %s is %s", + self.databricks_task_key, + run_state.life_cycle_state, + ) + if run_state.is_successful: + break + new_sub_run_id = self._sync_wait_for_new_sub_run_attempt( + original_sub_run_id=current_task_run_id, + original_start_time=run.get("start_time"), + tg=tg, + ) + if new_sub_run_id is None: + break + self.log.info( + "Workflow coordinator produced a new attempt for task_key %s (sub-run %s).", + self.databricks_task_key, + new_sub_run_id, + ) + current_task_run_id = new_sub_run_id + run = self._hook.get_run(current_task_run_id) + run_state = RunState(**run["state"]) + self.log.info( + "Current state of the databricks task %s is %s", + self.databricks_task_key, + run_state.life_cycle_state, + ) + while not run_state.is_terminal: time.sleep(self.polling_period_seconds) run = self._hook.get_run(current_task_run_id) @@ -1523,13 +1574,7 @@ def monitor_databricks_job(self) -> None: def execute(self, context: Context) -> None: """Execute the operator. Launch the job and monitor it if wait_for_termination is set to True.""" if self._databricks_workflow_task_group: - # If we are in a DatabricksWorkflowTaskGroup, we should have an upstream task launched. 
- if not self.workflow_run_metadata: - launch_task_id = next(task for task in self.upstream_task_ids if task.endswith(".launch")) - self.workflow_run_metadata = context["ti"].xcom_pull(task_ids=launch_task_id) - workflow_run_metadata = WorkflowRunMetadata(**self.workflow_run_metadata) - self.databricks_run_id = workflow_run_metadata.run_id - self.databricks_conn_id = workflow_run_metadata.conn_id + workflow_run_metadata = self._resolve_workflow_run_metadata(context) # Store operator links in XCom for Airflow 3 compatibility if AIRFLOW_V_3_0_PLUS: @@ -1544,11 +1589,167 @@ def execute(self, context: Context) -> None: if self.wait_for_termination: self.monitor_databricks_job() + def _resolve_workflow_run_metadata(self, context: Context | dict | None) -> WorkflowRunMetadata: + """ + Populate ``databricks_run_id`` / ``databricks_conn_id`` from ``workflow_run_metadata``. + + Resolves both the standard ``execute`` path and the deferrable-resume path: when a task + resumes via ``execute_complete``, the operator is freshly re-instantiated, so any + attributes set during ``execute`` (including ``databricks_run_id``) are lost. The + templated ``workflow_run_metadata`` field is rendered from the upstream ``.launch`` + task's XCom on every task-instance run, so it is the canonical source for the parent + workflow run id across both entry points. 
+ """ + if not self.workflow_run_metadata and context is not None: + launch_task_id = next((task for task in self.upstream_task_ids if task.endswith(".launch")), None) + ti = context.get("ti") if isinstance(context, dict) else context["ti"] + if launch_task_id is not None and ti is not None: + self.workflow_run_metadata = ti.xcom_pull(task_ids=launch_task_id) + if not self.workflow_run_metadata: + raise ValueError("workflow_run_metadata is required to resolve the parent workflow run") + workflow_run_metadata = WorkflowRunMetadata(**self.workflow_run_metadata) + self.databricks_run_id = workflow_run_metadata.run_id + self.databricks_conn_id = workflow_run_metadata.conn_id + return workflow_run_metadata + def execute_complete(self, context: dict | None, event: dict) -> None: run_state = RunState.from_json(event["run_state"]) errors = event.get("errors", []) + + if not run_state.is_successful: + tg = self._workflow_task_group_with_repair() + if tg is not None: + self._resolve_workflow_run_metadata(context) + self._defer_to_workflow_repair_wait( + original_sub_run_id=event["run_id"], + original_start_time=event.get("run_start_time"), + tg=tg, + ) + self._handle_terminal_run_state(run_state, errors) + def _workflow_task_group_with_repair(self) -> DatabricksWorkflowTaskGroup | None: + if not AIRFLOW_V_3_0_PLUS: + return None + tg = self._databricks_workflow_task_group + if tg is None or getattr(tg, "max_full_run_repairs", 0) <= 0: + return None + return tg + + def _sync_wait_for_new_sub_run_attempt( + self, + original_sub_run_id: int, + original_start_time: int | None, + tg: DatabricksWorkflowTaskGroup, + ) -> int | None: + """ + Sync equivalent of :class:`DatabricksWorkflowRepairWaitTrigger`. + + Polls the parent run for a new attempt of ``self.databricks_task_key`` after a + sub-run reaches terminal failure, then lets the caller switch to that attempt and + continue polling. 
Returns the new sub-run id, or ``None`` if the parent run + terminates without producing a new attempt within the grace window. + """ + self.log.info( + "Sub-run %s for task_key %s reached terminal failure; waiting for a repair " + "attempt issued by the workflow coordinator.", + original_sub_run_id, + self.databricks_task_key, + ) + polling_period_seconds = tg.repair_polling_period_seconds + terminal_grace_polls = 3 + terminal_observations = 0 + while True: + run_info = self._hook.get_run(self.databricks_run_id) # type: ignore[arg-type] + parent_run_state = RunState(**run_info["state"]) + tasks = run_info.get("tasks", []) + new_attempt = find_new_workflow_task_attempt( + tasks=tasks, + task_key=self.databricks_task_key, + original_sub_run_id=original_sub_run_id, + original_start_time=original_start_time, + ) + if new_attempt is not None: + return new_attempt["run_id"] + if parent_run_state.is_terminal: + terminal_observations += 1 + if terminal_observations >= terminal_grace_polls: + self.log.info( + "Parent run %s reached terminal state %s without a new attempt for " + "task_key %s after %s grace polls.", + self.databricks_run_id, + parent_run_state.result_state, + self.databricks_task_key, + terminal_grace_polls, + ) + return None + else: + terminal_observations = 0 + time.sleep(polling_period_seconds) + + def _defer_to_workflow_repair_wait( + self, + original_sub_run_id: int, + original_start_time: int | None, + tg: DatabricksWorkflowTaskGroup, + ) -> None: + self.log.info( + "Sub-run %s for task_key %s reached terminal failure; deferring to wait for a repair " + "attempt issued by the workflow coordinator.", + original_sub_run_id, + self.databricks_task_key, + ) + self.defer( + trigger=DatabricksWorkflowRepairWaitTrigger( + run_id=self.databricks_run_id, # type: ignore[arg-type] + databricks_conn_id=self.databricks_conn_id, + databricks_task_key=self.databricks_task_key, + original_sub_run_id=original_sub_run_id, + original_start_time=original_start_time, + 
polling_period_seconds=tg.repair_polling_period_seconds, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + caller=self.caller, + ), + method_name="execute_complete_after_repair_wait", + ) + + def execute_complete_after_repair_wait(self, context: dict | None, event: dict) -> None: + status = event.get("status") + if status == "new_attempt": + new_sub_run_id = event["new_sub_run_id"] + self.log.info( + "Workflow coordinator produced a new attempt for task_key %s (sub-run %s); " + "deferring on a fresh DatabricksExecutionTrigger to monitor it.", + self.databricks_task_key, + new_sub_run_id, + ) + self.defer( + trigger=DatabricksExecutionTrigger( + run_id=new_sub_run_id, + databricks_conn_id=self.databricks_conn_id, + polling_period_seconds=self.polling_period_seconds, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + caller=self.caller, + ), + method_name=DEFER_METHOD_NAME, + ) + elif status == "parent_failed": + parent_state = RunState.from_json(event["parent_run_state"]) + raise DatabricksWorkflowRepairBudgetExhausted( + f"Databricks workflow run {event.get('parent_run_id')} reached terminal failure " + f"({parent_state.result_state}) without producing a new attempt for task_key " + f"{self.databricks_task_key!r}; repair budget is exhausted or the coordinator " + f"did not issue a repair." 
+ ) + else: + raise DatabricksWorkflowRepairTriggerError( + f"DatabricksWorkflowRepairWaitTrigger emitted unexpected status {status!r}: {event}" + ) + class DatabricksNotebookOperator(DatabricksTaskBaseOperator): """ diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py index 779c2fc9f1528..2e6c5a1c66b72 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py @@ -19,19 +19,27 @@ import json import time +import warnings from dataclasses import dataclass from functools import cached_property from typing import TYPE_CHECKING, Any from mergedeep import merge -from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, TaskGroup +from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, TaskGroup, conf +from airflow.providers.databricks.exceptions import ( + DatabricksWorkflowRepairBudgetExhausted, + DatabricksWorkflowRepairMetadataError, + DatabricksWorkflowRepairTriggerError, +) from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState from airflow.providers.databricks.plugins.databricks_workflow import ( WorkflowJobRepairAllFailedLink, WorkflowJobRunLink, store_databricks_job_run_link, ) +from airflow.providers.databricks.triggers.databricks import DatabricksWorkflowRepairCoordinatorTrigger +from airflow.providers.databricks.utils.databricks import build_repair_run_json, extract_failed_task_errors from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: @@ -300,6 +308,224 @@ def on_kill(self) -> None: ) +class _DatabricksFullRunRepairCoordinatorOperator(BaseOperator): + """ + Watch a Databricks Workflow run and trigger ``rerun_all_failed_tasks`` repairs. 
+ + Runs as a sibling of the downstream Databricks task monitors inside a + :class:`DatabricksWorkflowTaskGroup`. The ``launch`` task remains responsible for creating the + job and returning metadata immediately so downstream tasks can fan out; this operator owns the + long-lived defer cycle that watches the whole run and issues one repair call per failure batch, + so the original job cluster is reused. + + Downstream task monitors observe repairs independently by deferring on + :class:`~airflow.providers.databricks.triggers.databricks.DatabricksWorkflowRepairWaitTrigger` + when their sub-run hits a terminal failure; that trigger polls the Databricks API for the next + attempt of the same ``task_key`` rather than reading any inter-task XCom. The coordinator's + final return value carries ``{run_id, repair_attempts, latest_repair_id}`` for any user code + that wants a post-run summary. + + :param task_id: The task id of the operator (typically ``"full_run_repair_coordinator"``). + :param databricks_conn_id: Connection id used by the coordinator trigger and repair calls. + :param launch_task_id: The full task id of the workflow ``launch`` task whose return value + carries ``{conn_id, job_id, run_id}``. + :param max_full_run_repairs: Total repair attempts allowed across the run. + :param repair_polling_period_seconds: Poll interval forwarded to the coordinator trigger. + :param databricks_retry_limit: Hook retry limit for transient API failures. + :param databricks_retry_delay: Hook retry delay (seconds). + :param databricks_retry_args: Optional ``tenacity.Retrying`` kwargs forwarded to the hook. + :param deferrable: If ``True``, watch the run by deferring on + :class:`DatabricksWorkflowRepairCoordinatorTrigger`. If ``False``, watch it via a + synchronous poll loop that runs the same state machine inline. 
+ """ + + caller = "_DatabricksFullRunRepairCoordinatorOperator" + + def __init__( + self, + task_id: str, + databricks_conn_id: str, + launch_task_id: str, + max_full_run_repairs: int, + repair_polling_period_seconds: int = 30, + databricks_retry_limit: int = 3, + databricks_retry_delay: int = 10, + databricks_retry_args: dict[Any, Any] | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + if max_full_run_repairs < 1: + raise ValueError( + f"max_full_run_repairs must be >= 1 for the workflow coordinator task, got {max_full_run_repairs}" + ) + super().__init__(task_id=task_id, **kwargs) + self.databricks_conn_id = databricks_conn_id + self.launch_task_id = launch_task_id + self.max_full_run_repairs = max_full_run_repairs + self.repair_polling_period_seconds = repair_polling_period_seconds + self.databricks_retry_limit = databricks_retry_limit + self.databricks_retry_delay = databricks_retry_delay + self.databricks_retry_args = databricks_retry_args + self.deferrable = deferrable + + @cached_property + def _hook(self) -> DatabricksHook: + return DatabricksHook( + self.databricks_conn_id, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + caller=self.caller, + ) + + def _make_trigger( + self, + run_id: int, + repair_attempts: int, + latest_repair_id: int | None, + ) -> DatabricksWorkflowRepairCoordinatorTrigger: + return DatabricksWorkflowRepairCoordinatorTrigger( + run_id=run_id, + databricks_conn_id=self.databricks_conn_id, + max_full_run_repairs=self.max_full_run_repairs, + repair_attempts=repair_attempts, + latest_repair_id=latest_repair_id, + polling_period_seconds=self.repair_polling_period_seconds, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + caller=self.caller, + ) + + def execute(self, context: Context) -> Any: + launch_value = 
context["ti"].xcom_pull(task_ids=self.launch_task_id) + if not launch_value: + raise DatabricksWorkflowRepairMetadataError( + f"Launch task {self.launch_task_id!r} did not publish workflow run metadata; " + "cannot coordinate repairs." + ) + metadata = WorkflowRunMetadata(**launch_value) + + if self.deferrable: + self.defer( + trigger=self._make_trigger( + run_id=metadata.run_id, + repair_attempts=0, + latest_repair_id=None, + ), + method_name="execute_complete", + ) + return self._run_sync(metadata.run_id) + + def _run_sync(self, run_id: int) -> dict[str, Any]: + """Sync equivalent of :class:`DatabricksWorkflowRepairCoordinatorTrigger`'s state machine.""" + repair_attempts = 0 + latest_repair_id: int | None = None + while True: + run_state = self._hook.get_run_state(run_id) + while not run_state.is_terminal: + self.log.info( + "Databricks run %s in state %s. Sleeping for %s seconds.", + run_id, + run_state, + self.repair_polling_period_seconds, + ) + time.sleep(self.repair_polling_period_seconds) + run_state = self._hook.get_run_state(run_id) + + if run_state.is_successful: + self.log.info( + "Databricks workflow run %s completed (repair_attempts=%s).", + run_id, + repair_attempts, + ) + return { + "run_id": run_id, + "repair_attempts": repair_attempts, + "latest_repair_id": latest_repair_id, + } + + run_info = self._hook.get_run(run_id) + errors = extract_failed_task_errors(self._hook, run_info, run_state) + + if repair_attempts >= self.max_full_run_repairs: + raise DatabricksWorkflowRepairBudgetExhausted( + f"Databricks workflow run {run_id} failed after {repair_attempts} repair " + f"attempt(s); repair budget exhausted (max_full_run_repairs={self.max_full_run_repairs}). " + f"Errors: {errors}" + ) + + self.log.info( + "Databricks run %s reached terminal failure state %s. 
Repairing all failed " + "tasks (attempt %s of %s, latest_repair_id=%s).", + run_id, + run_state.result_state, + repair_attempts + 1, + self.max_full_run_repairs, + latest_repair_id, + ) + repair_json = build_repair_run_json( + run_id=run_id, + latest_repair_id=latest_repair_id, + overriding_parameters=run_info.get("overriding_parameters"), + ) + latest_repair_id = self._hook.repair_run(repair_json) + repair_attempts += 1 + self.log.info( + "Databricks repair_run accepted for run %s; new repair_id=%s.", + run_id, + latest_repair_id, + ) + + def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: + status = event.get("status") + run_id = event["run_id"] + repair_attempts = event["repair_attempts"] + latest_repair_id = event.get("latest_repair_id") + + if status == "completed": + self.log.info( + "Databricks workflow run %s completed (repair_attempts=%s).", + run_id, + repair_attempts, + ) + return { + "run_id": run_id, + "repair_attempts": repair_attempts, + "latest_repair_id": latest_repair_id, + } + + if status == "repaired": + self.log.info( + "Databricks workflow run %s repaired (repair_attempts=%s, latest_repair_id=%s); " + "re-deferring to monitor the repaired run.", + run_id, + repair_attempts, + latest_repair_id, + ) + self.defer( + trigger=self._make_trigger( + run_id=run_id, + repair_attempts=repair_attempts, + latest_repair_id=latest_repair_id, + ), + method_name="execute_complete", + ) + return None + + if status == "failed": + errors = event.get("errors", []) + raise DatabricksWorkflowRepairBudgetExhausted( + f"Databricks workflow run {run_id} failed after {repair_attempts} repair " + f"attempt(s); repair budget exhausted (max_full_run_repairs={self.max_full_run_repairs}). 
" + f"Errors: {errors}" + ) + + raise DatabricksWorkflowRepairTriggerError( + f"DatabricksWorkflowRepairCoordinatorTrigger emitted unexpected status {status!r}: {event}" + ) + + class DatabricksWorkflowTaskGroup(TaskGroup): """ A task group that takes a list of tasks and creates a databricks workflow. @@ -338,6 +564,12 @@ class DatabricksWorkflowTaskGroup(TaskGroup): all python tasks in the workflow. :param spark_submit_params: A list of spark submit parameters to pass to the workflow. These parameters will be passed to all spark submit tasks. + :param max_full_run_repairs: Maximum number of automatic ``rerun_all_failed_tasks`` repair attempts to + issue against the Databricks run when downstream tasks fail. Each repair reuses the + original job cluster. Set to ``0`` to disable auto-repair (current behavior). Only takes + effect on Airflow 3+; ignored on Airflow 2.x. Defaults to ``0``. + :param repair_polling_period_seconds: How often the repair coordinator polls the + Databricks run state. Only used when ``max_full_run_repairs > 0``. 
""" is_databricks = True @@ -355,8 +587,12 @@ def __init__( notebook_params: dict | None = None, python_params: list | None = None, spark_submit_params: list | None = None, + max_full_run_repairs: int = 0, + repair_polling_period_seconds: int = 30, **kwargs, ): + if max_full_run_repairs < 0: + raise ValueError(f"max_full_run_repairs must be >= 0, got {max_full_run_repairs}") self.databricks_conn_id = databricks_conn_id self.access_control_list = access_control_list self.existing_clusters = existing_clusters or [] @@ -368,6 +604,8 @@ def __init__( self.notebook_params = notebook_params or {} self.python_params = python_params or [] self.spark_submit_params = spark_submit_params or [] + self.max_full_run_repairs = max_full_run_repairs + self.repair_polling_period_seconds = repair_polling_period_seconds super().__init__(**kwargs) def __exit__( @@ -403,11 +641,41 @@ def __exit__( f"Task {task.task_id} does not support conversion to databricks workflow task." ) + if AIRFLOW_V_3_0_PLUS and self.max_full_run_repairs > 0: + task_retries = getattr(task, "retries", 0) + if isinstance(task_retries, int) and task_retries > 0: + warnings.warn( + f"Task {task.task_id!r} in DatabricksWorkflowTaskGroup " + f"{self.group_id!r} has retries={task_retries} while " + f"max_full_run_repairs={self.max_full_run_repairs}. Databricks-side repair supersedes " + "task-level retries for sub-run failures: a failed sub-run defers the " + "monitor on the repair-wait trigger and resumes in the same Airflow " + "attempt. Retries on the monitor will not trigger additional repairs " + "and only add cost on non-repair-related transient failures. 
Consider " + "setting retries=0 on workflow tasks when max_full_run_repairs > 0.", + UserWarning, + stacklevel=2, + ) task.workflow_run_metadata = create_databricks_workflow_task.output create_databricks_workflow_task.relevant_upstreams.append(task.task_id) create_databricks_workflow_task.add_task(task.task_id, task) for root_task in roots: root_task.set_upstream(create_databricks_workflow_task) + + if AIRFLOW_V_3_0_PLUS and self.max_full_run_repairs > 0: + repair_coordinator_task = _DatabricksFullRunRepairCoordinatorOperator( + dag=self.dag, + task_group=self, + task_id="full_run_repair_coordinator", + databricks_conn_id=self.databricks_conn_id, + launch_task_id=create_databricks_workflow_task.task_id, + max_full_run_repairs=self.max_full_run_repairs, + repair_polling_period_seconds=self.repair_polling_period_seconds, + # Retrying the coordinator would re-enter execute() with repair_attempts=0 + # and start the budget over. + retries=0, + ) + repair_coordinator_task.set_upstream(create_databricks_workflow_task) finally: super().__exit__(_type, _value, _tb) diff --git a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py index 2bb626b911418..1f3b6a7f24ec5 100644 --- a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py @@ -21,8 +21,12 @@ import time from typing import Any -from airflow.providers.databricks.hooks.databricks import DatabricksHook -from airflow.providers.databricks.utils.databricks import extract_failed_task_errors_async +from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState +from airflow.providers.databricks.utils.databricks import ( + build_repair_run_json, + extract_failed_task_errors_async, + find_new_workflow_task_attempt, +) from airflow.providers.databricks.utils.retry import 
validate_deferrable_databricks_retry_args from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -121,6 +125,7 @@ async def run(self): "run_id": self.run_id, "run_page_url": self.run_page_url, "run_state": run_state.to_json(), + "run_start_time": run_info.get("start_time"), "repair_run": self.repair_run, "errors": failed_tasks, } @@ -128,6 +133,356 @@ async def run(self): return +class DatabricksWorkflowRepairCoordinatorTrigger(BaseTrigger): + """ + Coordinate whole-run polling and ``rerun_all_failed_tasks`` repair for a Databricks Workflow run. + + Owned by the ``coordinator`` sibling task that + :class:`~airflow.providers.databricks.operators.databricks_workflow.DatabricksWorkflowTaskGroup` + injects when ``max_full_run_repairs > 0`` on Airflow 3+. Keeps a single Databricks job run alive across + repair attempts so the same job cluster is reused. Each defer/resume cycle of the coordinator + task corresponds to one iteration: + + 1. Poll the run until it reaches a terminal state. + 2. On terminal success, yield ``status="completed"``. + 3. On terminal failure with repair budget remaining, call + :meth:`~airflow.providers.databricks.hooks.databricks.DatabricksHook.repair_run` with + ``rerun_all_failed_tasks=True`` and yield ``status="repaired"`` along with the new + ``latest_repair_id`` and bumped ``repair_attempts``; the coordinator task then re-defers on + a fresh trigger instance with the new state. Downstream task monitors observe the new sub-run + attempt by polling the Databricks API directly (via + :class:`DatabricksWorkflowRepairWaitTrigger`), not via any inter-task XCom. + 4. On terminal failure with the budget exhausted, yield ``status="failed"``. + + The Databricks ``run_id`` is stable across repair attempts; only ``latest_repair_id`` changes. + + :param run_id: The Databricks run id to coordinate. + :param databricks_conn_id: Airflow connection id for the Databricks hook. + :param max_full_run_repairs: Total repair attempts allowed for this run. 
+ :param repair_attempts: Repair attempts already performed (defaults to 0 on the first defer). + :param latest_repair_id: Repair id of the most recent repair attempt, or ``None`` on the first + defer. Forwarded to ``repair_run`` so Databricks knows which attempt is the latest. + :param polling_period_seconds: How often to poll the run state. + :param retry_limit: Hook retry limit for transient Databricks API failures. + :param retry_delay: Hook retry delay (seconds). + :param retry_args: Optional tenacity ``Retrying`` kwargs forwarded to the hook. + :param run_page_url: The Databricks UI URL for this run, surfaced in events for logging. + :param caller: Caller label forwarded to the hook for diagnostics. + """ + + def __init__( + self, + run_id: int, + databricks_conn_id: str, + max_full_run_repairs: int, + repair_attempts: int = 0, + latest_repair_id: int | None = None, + polling_period_seconds: int = 30, + retry_limit: int = 3, + retry_delay: int = 10, + retry_args: dict[Any, Any] | None = None, + run_page_url: str | None = None, + caller: str = "DatabricksWorkflowRepairCoordinatorTrigger", + ) -> None: + super().__init__() + self.run_id = run_id + self.databricks_conn_id = databricks_conn_id + self.max_full_run_repairs = max_full_run_repairs + self.repair_attempts = repair_attempts + self.latest_repair_id = latest_repair_id + self.polling_period_seconds = polling_period_seconds + self.retry_limit = retry_limit + self.retry_delay = retry_delay + self.retry_args = retry_args + self.run_page_url = run_page_url + self.caller = caller + self.hook = DatabricksHook( + databricks_conn_id, + retry_limit=self.retry_limit, + retry_delay=self.retry_delay, + retry_args=retry_args, + caller=caller, + ) + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.databricks.triggers.databricks.DatabricksWorkflowRepairCoordinatorTrigger", + { + "run_id": self.run_id, + "databricks_conn_id": self.databricks_conn_id, + "max_full_run_repairs": 
self.max_full_run_repairs, + "repair_attempts": self.repair_attempts, + "latest_repair_id": self.latest_repair_id, + "polling_period_seconds": self.polling_period_seconds, + "retry_limit": self.retry_limit, + "retry_delay": self.retry_delay, + "retry_args": self.retry_args, + "run_page_url": self.run_page_url, + "caller": self.caller, + }, + ) + + async def on_kill(self) -> None: + """Cancel the Databricks run when the trigger is cancelled by a user action.""" + if self.run_id: + from asgiref.sync import sync_to_async + + self.log.info("Cancelling Databricks run %s.", self.run_id) + await sync_to_async(self.hook.cancel_run)(self.run_id) + + async def run(self): + from asgiref.sync import sync_to_async + + async with self.hook: + while True: + run_state = await self.hook.a_get_run_state(self.run_id) + if not run_state.is_terminal: + self.log.info( + "run-id %s in run state %s. sleeping for %s seconds", + self.run_id, + run_state, + self.polling_period_seconds, + ) + await asyncio.sleep(self.polling_period_seconds) + continue + + run_info = await self.hook.a_get_run(self.run_id) + errors = await extract_failed_task_errors_async(self.hook, run_info, run_state) + + if run_state.is_successful: + self.log.info("Databricks run %s completed successfully.", self.run_id) + yield TriggerEvent( + { + "status": "completed", + "run_id": self.run_id, + "run_page_url": self.run_page_url, + "run_state": run_state.to_json(), + "repair_attempts": self.repair_attempts, + "latest_repair_id": self.latest_repair_id, + "errors": errors, + } + ) + return + + if self.repair_attempts >= self.max_full_run_repairs: + self.log.info( + "Databricks run %s reached terminal failure state %s and repair budget " + "is exhausted (max_full_run_repairs=%s).", + self.run_id, + run_state.result_state, + self.max_full_run_repairs, + ) + yield TriggerEvent( + { + "status": "failed", + "run_id": self.run_id, + "run_page_url": self.run_page_url, + "run_state": run_state.to_json(), + "repair_attempts": 
self.repair_attempts, + "latest_repair_id": self.latest_repair_id, + "errors": errors, + } + ) + return + + self.log.info( + "Databricks run %s reached terminal failure state %s. Repairing all failed " + "tasks (attempt %s of %s, latest_repair_id=%s).", + self.run_id, + run_state.result_state, + self.repair_attempts + 1, + self.max_full_run_repairs, + self.latest_repair_id, + ) + + repair_json = build_repair_run_json( + run_id=self.run_id, + latest_repair_id=self.latest_repair_id, + overriding_parameters=run_info.get("overriding_parameters"), + ) + + new_repair_id = await sync_to_async(self.hook.repair_run)(repair_json) + self.log.info( + "Databricks repair_run accepted for run %s; new repair_id=%s.", + self.run_id, + new_repair_id, + ) + + yield TriggerEvent( + { + "status": "repaired", + "run_id": self.run_id, + "run_page_url": self.run_page_url, + "run_state": run_state.to_json(), + "repair_attempts": self.repair_attempts + 1, + "latest_repair_id": new_repair_id, + "errors": errors, + } + ) + return + + +class DatabricksWorkflowRepairWaitTrigger(BaseTrigger): + """ + Wait for the next attempt of a single Databricks Workflow task after its sub-run failed. + + Used by Databricks task monitors inside a + :class:`~airflow.providers.databricks.operators.databricks_workflow.DatabricksWorkflowTaskGroup` + when ``max_full_run_repairs > 0`` on Airflow 3+. A monitor whose sub-run reaches terminal failure defers + on this trigger; the trigger polls the parent run's task list and yields when a new attempt of + the same ``task_key`` appears (issued by the sibling ``coordinator`` task via + ``rerun_all_failed_tasks``), so the monitor can then defer on a fresh + :class:`DatabricksExecutionTrigger` watching the new sub-run id. + + Each poll cycle: + + 1. If a Databricks task with our ``databricks_task_key`` exists whose ``run_id`` differs from + ``original_sub_run_id`` and whose ``start_time`` is newer, yield ``status="new_attempt"`` + with the new sub-run id. + 2. 
Otherwise, if the parent run is in a terminal failure state, count one "grace" observation. + After ``terminal_grace_polls`` consecutive terminal observations without a new attempt, + yield ``status="parent_failed"``. This avoids racing the coordinator: the parent run is + briefly terminal between sub-run failure and the coordinator issuing ``repair_run``. + 3. Otherwise (parent still running, or terminal but inside the grace window), sleep and poll + again. + + :param run_id: Parent workflow run id (stable across repairs). + :param databricks_conn_id: Airflow connection id for the Databricks hook. + :param databricks_task_key: The ``task_key`` of the Databricks task to watch for a new attempt. + :param original_sub_run_id: The sub-run id of the attempt that just failed; the trigger only + yields ``new_attempt`` for a sub-run id different from this one. + :param polling_period_seconds: How often to poll the parent run. + :param terminal_grace_polls: Number of consecutive terminal-with-no-new-attempt observations + required before yielding ``status="parent_failed"``. Bounds how long we wait for the + coordinator to issue a repair after observing terminal failure. + :param retry_limit: Hook retry limit for transient Databricks API failures. + :param retry_delay: Hook retry delay (seconds). + :param retry_args: Optional tenacity ``Retrying`` kwargs forwarded to the hook. + :param run_page_url: The Databricks UI URL for the parent run, surfaced in events for logging. + :param caller: Caller label forwarded to the hook for diagnostics. 
+ """ + + def __init__( + self, + run_id: int, + databricks_conn_id: str, + databricks_task_key: str, + original_sub_run_id: int, + original_start_time: int | None = None, + polling_period_seconds: int = 30, + terminal_grace_polls: int = 3, + retry_limit: int = 3, + retry_delay: int = 10, + retry_args: dict[Any, Any] | None = None, + run_page_url: str | None = None, + caller: str = "DatabricksWorkflowRepairWaitTrigger", + ) -> None: + super().__init__() + if terminal_grace_polls < 1: + raise ValueError(f"terminal_grace_polls must be >= 1, got {terminal_grace_polls}") + self.run_id = run_id + self.databricks_conn_id = databricks_conn_id + self.databricks_task_key = databricks_task_key + self.original_sub_run_id = original_sub_run_id + self.original_start_time = original_start_time + self.polling_period_seconds = polling_period_seconds + self.terminal_grace_polls = terminal_grace_polls + self.retry_limit = retry_limit + self.retry_delay = retry_delay + self.retry_args = retry_args + self.run_page_url = run_page_url + self.caller = caller + self.hook = DatabricksHook( + databricks_conn_id, + retry_limit=self.retry_limit, + retry_delay=self.retry_delay, + retry_args=retry_args, + caller=caller, + ) + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.databricks.triggers.databricks.DatabricksWorkflowRepairWaitTrigger", + { + "run_id": self.run_id, + "databricks_conn_id": self.databricks_conn_id, + "databricks_task_key": self.databricks_task_key, + "original_sub_run_id": self.original_sub_run_id, + "original_start_time": self.original_start_time, + "polling_period_seconds": self.polling_period_seconds, + "terminal_grace_polls": self.terminal_grace_polls, + "retry_limit": self.retry_limit, + "retry_delay": self.retry_delay, + "retry_args": self.retry_args, + "run_page_url": self.run_page_url, + "caller": self.caller, + }, + ) + + def _find_new_attempt(self, tasks: list[dict[str, Any]]) -> dict[str, Any] | None: + return 
find_new_workflow_task_attempt( + tasks=tasks, + task_key=self.databricks_task_key, + original_sub_run_id=self.original_sub_run_id, + original_start_time=self.original_start_time, + ) + + async def run(self): + terminal_observations = 0 + async with self.hook: + while True: + run_info = await self.hook.a_get_run(self.run_id) + run_state = RunState(**run_info["state"]) + tasks = run_info.get("tasks", []) + + new_attempt = self._find_new_attempt(tasks) + if new_attempt is not None: + self.log.info( + "Databricks workflow run %s produced a new attempt for task_key %s " + "(new sub-run id %s).", + self.run_id, + self.databricks_task_key, + new_attempt["run_id"], + ) + yield TriggerEvent( + { + "status": "new_attempt", + "parent_run_id": self.run_id, + "databricks_task_key": self.databricks_task_key, + "new_sub_run_id": new_attempt["run_id"], + "run_page_url": self.run_page_url, + } + ) + return + + if run_state.is_terminal and not run_state.is_successful: + terminal_observations += 1 + self.log.info( + "Databricks workflow run %s is in terminal failure state %s with no new " + "attempt for task_key %s (grace %s of %s).", + self.run_id, + run_state.result_state, + self.databricks_task_key, + terminal_observations, + self.terminal_grace_polls, + ) + if terminal_observations >= self.terminal_grace_polls: + yield TriggerEvent( + { + "status": "parent_failed", + "parent_run_id": self.run_id, + "databricks_task_key": self.databricks_task_key, + "parent_run_state": run_state.to_json(), + "run_page_url": self.run_page_url, + } + ) + return + else: + terminal_observations = 0 + + await asyncio.sleep(self.polling_period_seconds) + + class DatabricksSQLStatementExecutionTrigger(BaseTrigger): """ The trigger handles the logic of async communication with DataBricks SQL Statements API. 
diff --git a/providers/databricks/src/airflow/providers/databricks/utils/databricks.py b/providers/databricks/src/airflow/providers/databricks/utils/databricks.py index 05b6b17710ef8..4c1f557b408be 100644 --- a/providers/databricks/src/airflow/providers/databricks/utils/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/utils/databricks.py @@ -17,6 +17,8 @@ # under the License. from __future__ import annotations +from typing import Any + from airflow.providers.common.compat.sdk import AirflowException, XComArg from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState @@ -105,6 +107,53 @@ async def extract_failed_task_errors_async( return failed_tasks +def find_new_workflow_task_attempt( + tasks: list[dict[str, Any]], + task_key: str, + original_sub_run_id: int, + original_start_time: int | None, +) -> dict[str, Any] | None: + """ + Return the newest task entry matching ``task_key`` that is not the original sub-run. + + Used by the repair-wait trigger and its sync counterpart to detect a new attempt of + a Databricks Workflow task after the prior sub-run reached terminal failure. + """ + candidates = [ + task + for task in tasks + if task.get("task_key") == task_key + and task.get("run_id") != original_sub_run_id + and (original_start_time is None or task.get("start_time", 0) > original_start_time) + ] + if not candidates: + return None + return max(candidates, key=lambda task: task.get("start_time", 0)) + + +def build_repair_run_json( + run_id: int, + latest_repair_id: int | None, + overriding_parameters: Any = None, +) -> dict[str, Any]: + """ + Build the ``DatabricksHook.repair_run`` payload for ``rerun_all_failed_tasks`` repair. + + Used by the coordinator trigger and its sync counterpart to keep the repair payload + shape in lock-step. 
+ """ + repair_json: dict[str, Any] = { + "run_id": run_id, + "rerun_all_failed_tasks": True, + "rerun_dependent_tasks": True, + } + if latest_repair_id is not None: + repair_json["latest_repair_id"] = latest_repair_id + if overriding_parameters is not None: + repair_json["overriding_parameters"] = overriding_parameters + return repair_json + + def validate_trigger_event(event: dict): """ Validate correctness of the event received from DatabricksExecutionTrigger. diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks.py b/providers/databricks/tests/unit/databricks/operators/test_databricks.py index 4684b14282c4e..43786558647d6 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks.py @@ -49,9 +49,12 @@ from airflow.providers.databricks.triggers.databricks import ( DatabricksExecutionTrigger, DatabricksSQLStatementExecutionTrigger, + DatabricksWorkflowRepairWaitTrigger, ) from airflow.providers.databricks.utils import databricks as utils +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + DATE = "2017-04-20" TASK_ID = "databricks-operator" DEFAULT_CONN_ID = "databricks_default" @@ -2988,3 +2991,184 @@ def test_user_databricks_task_key(self): expected_task_key = "test_task_key" assert expected_task_key == operator.databricks_task_key + + +@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Workflow repair flow is Airflow 3+ only") +class TestExecuteCompleteWorkflowRepair: + PARENT_RUN_ID = 100 + ORIGINAL_SUB_RUN_ID = 500 + NEW_SUB_RUN_ID = 700 + + @staticmethod + def _terminal_failure_event(sub_run_id: int) -> dict[str, Any]: + return { + "run_id": sub_run_id, + "run_page_url": "https://example", + "run_state": RunState( + life_cycle_state="TERMINATED", + state_message="boom", + result_state="FAILED", + ).to_json(), + "errors": [{"task_key": "tk", "run_id": sub_run_id, "error": "boom"}], + } + + def 
_operator_with_workflow_tg(self, max_full_run_repairs: int) -> DatabricksNotebookOperator: + # Pass workflow_run_metadata (templated field) rather than setting + # databricks_run_id directly. Airflow re-instantiates the operator at deferrable + # resume time, so execute_complete must rederive the parent run id from the + # rendered template, not from any attribute set during execute(). + operator = DatabricksNotebookOperator( + task_id="task1", + notebook_path="path", + source="WORKSPACE", + databricks_conn_id="databricks_default", + workflow_run_metadata={ + "conn_id": "databricks_default", + "job_id": 1, + "run_id": self.PARENT_RUN_ID, + }, + ) + + # _databricks_workflow_task_group walks up `task_group` looking for is_databricks=True. + # Hand it a single-hop chain so it finds our mocked group directly. + tg = MagicMock() + tg.is_databricks = True + tg.max_full_run_repairs = max_full_run_repairs + tg.repair_polling_period_seconds = 15 + tg.task_group = None + operator.task_group = tg + return operator + + def test_execute_complete_failure_defers_on_wait_trigger_when_max_full_run_repairs_set(self): + operator = self._operator_with_workflow_tg(max_full_run_repairs=2) + + with pytest.raises(TaskDeferred) as exc: + operator.execute_complete( + context=None, + event=self._terminal_failure_event(self.ORIGINAL_SUB_RUN_ID), + ) + + trigger = exc.value.trigger + assert isinstance(trigger, DatabricksWorkflowRepairWaitTrigger) + assert trigger.run_id == self.PARENT_RUN_ID + assert trigger.databricks_task_key == operator.databricks_task_key + assert trigger.original_sub_run_id == self.ORIGINAL_SUB_RUN_ID + assert trigger.polling_period_seconds == 15 + assert exc.value.method_name == "execute_complete_after_repair_wait" + + def test_execute_complete_failure_pulls_workflow_metadata_from_xcom(self): + # Mirrors the deferrable-resume path where execute() never ran (so + # workflow_run_metadata was not pre-populated): the resolver must fall back + # to xcom_pull from the upstream 
.launch task to recover the parent run id. + operator = DatabricksNotebookOperator( + task_id="task1", + notebook_path="path", + source="WORKSPACE", + databricks_conn_id="databricks_default", + ) + tg = MagicMock() + tg.is_databricks = True + tg.max_full_run_repairs = 2 + tg.repair_polling_period_seconds = 15 + tg.task_group = None + operator.task_group = tg + operator.upstream_task_ids = {"workflow.launch"} + + ti = MagicMock() + ti.xcom_pull.return_value = { + "conn_id": "databricks_default", + "job_id": 1, + "run_id": self.PARENT_RUN_ID, + } + with pytest.raises(TaskDeferred) as exc: + operator.execute_complete( + context={"ti": ti}, + event=self._terminal_failure_event(self.ORIGINAL_SUB_RUN_ID), + ) + + ti.xcom_pull.assert_called_once_with(task_ids="workflow.launch") + trigger = exc.value.trigger + assert isinstance(trigger, DatabricksWorkflowRepairWaitTrigger) + assert trigger.run_id == self.PARENT_RUN_ID + + def test_execute_complete_after_repair_wait_new_attempt_defers_on_execution_trigger(self): + operator = self._operator_with_workflow_tg(max_full_run_repairs=2) + + with pytest.raises(TaskDeferred) as exc: + operator.execute_complete_after_repair_wait( + context=None, + event={ + "status": "new_attempt", + "parent_run_id": self.PARENT_RUN_ID, + "databricks_task_key": operator.databricks_task_key, + "new_sub_run_id": self.NEW_SUB_RUN_ID, + "run_page_url": "https://example", + }, + ) + + trigger = exc.value.trigger + assert isinstance(trigger, DatabricksExecutionTrigger) + assert trigger.run_id == self.NEW_SUB_RUN_ID + assert exc.value.method_name == "execute_complete" + + def test_execute_complete_after_repair_wait_parent_failed_raises(self): + operator = self._operator_with_workflow_tg(max_full_run_repairs=2) + + with pytest.raises(AirflowException, match="repair budget is exhausted"): + operator.execute_complete_after_repair_wait( + context=None, + event={ + "status": "parent_failed", + "parent_run_id": self.PARENT_RUN_ID, + "databricks_task_key": 
operator.databricks_task_key, + "parent_run_state": RunState( + life_cycle_state="TERMINATED", + state_message=None, + result_state="FAILED", + ).to_json(), + "run_page_url": "https://example", + }, + ) + + @mock.patch("airflow.providers.databricks.operators.databricks.time.sleep") + def test_sync_wait_for_new_sub_run_attempt_returns_new_attempt(self, mock_sleep): + operator = self._operator_with_workflow_tg(max_full_run_repairs=2) + hook = MagicMock() + operator.__dict__["_hook"] = hook + operator.databricks_run_id = self.PARENT_RUN_ID + hook.get_run.side_effect = [ + { + "state": {"life_cycle_state": "RUNNING", "result_state": None, "state_message": None}, + "tasks": [ + { + "run_id": self.ORIGINAL_SUB_RUN_ID, + "task_key": operator.databricks_task_key, + "start_time": 1000, + }, + ], + }, + { + "state": {"life_cycle_state": "RUNNING", "result_state": None, "state_message": None}, + "tasks": [ + { + "run_id": self.ORIGINAL_SUB_RUN_ID, + "task_key": operator.databricks_task_key, + "start_time": 1000, + }, + { + "run_id": self.NEW_SUB_RUN_ID, + "task_key": operator.databricks_task_key, + "start_time": 2000, + }, + ], + }, + ] + + result = operator._sync_wait_for_new_sub_run_attempt( + original_sub_run_id=self.ORIGINAL_SUB_RUN_ID, + original_start_time=1000, + tg=operator.task_group, + ) + + assert result == self.NEW_SUB_RUN_ID + mock_sleep.assert_called_once_with(15) diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py index 84069ee0ff7f8..85d244f3edb98 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py @@ -29,15 +29,17 @@ from airflow import DAG from airflow.models.baseoperator import BaseOperator -from airflow.providers.common.compat.sdk import AirflowException, timezone -from 
airflow.providers.databricks.hooks.databricks import RunLifeCycleState +from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred, timezone +from airflow.providers.databricks.hooks.databricks import RunLifeCycleState, RunState from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator from airflow.providers.databricks.operators.databricks_workflow import ( DatabricksWorkflowTaskGroup, WorkflowRunMetadata, _CreateDatabricksWorkflowOperator, + _DatabricksFullRunRepairCoordinatorOperator, _flatten_node, ) +from airflow.providers.databricks.triggers.databricks import DatabricksWorkflowRepairCoordinatorTrigger from airflow.providers.standard.operators.empty import EmptyOperator DEFAULT_DATE = timezone.datetime(2021, 1, 1) @@ -571,3 +573,167 @@ def test_reset_job_payload_carries_parent_depends_on(self, mock_databricks_hook) job_id, job_spec = launch_task._hook.reset_job.call_args.args assert job_id == 42 self._assert_parent_depends_on(job_spec) + + +class TestDatabricksFullRunRepairCoordinatorOperator: + LAUNCH_TASK_ID = "wf.launch" + LAUNCH_RETURN = {"conn_id": "databricks_default", "job_id": 42, "run_id": 100} + + def _make_operator( + self, + max_full_run_repairs: int = 2, + deferrable: bool = True, + ) -> _DatabricksFullRunRepairCoordinatorOperator: + return _DatabricksFullRunRepairCoordinatorOperator( + task_id="full_run_repair_coordinator", + databricks_conn_id="databricks_default", + launch_task_id=self.LAUNCH_TASK_ID, + max_full_run_repairs=max_full_run_repairs, + repair_polling_period_seconds=10, + deferrable=deferrable, + ) + + def test_execute_raises_when_launch_xcom_missing(self): + operator = self._make_operator() + ctx = {"ti": MagicMock()} + ctx["ti"].xcom_pull.return_value = None + + with pytest.raises(AirflowException, match="did not publish workflow run metadata"): + operator.execute(ctx) + + def test_execute_defers_on_coordinator_trigger(self): + operator = self._make_operator(max_full_run_repairs=3) + 
ctx = {"ti": MagicMock()} + ctx["ti"].xcom_pull.return_value = self.LAUNCH_RETURN + + with pytest.raises(TaskDeferred) as exc: + operator.execute(ctx) + + ctx["ti"].xcom_push.assert_not_called() + assert exc.value.method_name == "execute_complete" + trigger = exc.value.trigger + assert isinstance(trigger, DatabricksWorkflowRepairCoordinatorTrigger) + assert trigger.run_id == self.LAUNCH_RETURN["run_id"] + assert trigger.max_full_run_repairs == 3 + assert trigger.repair_attempts == 0 + assert trigger.latest_repair_id is None + assert trigger.polling_period_seconds == 10 + + def test_execute_complete_repaired_redefers_without_xcom_push(self): + operator = self._make_operator(max_full_run_repairs=3) + ctx = {"ti": MagicMock()} + + with pytest.raises(TaskDeferred) as exc: + operator.execute_complete( + ctx, + event={ + "status": "repaired", + "run_id": 100, + "repair_attempts": 1, + "latest_repair_id": 555, + }, + ) + + ctx["ti"].xcom_push.assert_not_called() + trigger = exc.value.trigger + assert isinstance(trigger, DatabricksWorkflowRepairCoordinatorTrigger) + assert trigger.run_id == 100 + assert trigger.repair_attempts == 1 + assert trigger.latest_repair_id == 555 + assert trigger.max_full_run_repairs == 3 + + def test_execute_complete_failed_raises_with_errors_in_message(self): + operator = self._make_operator(max_full_run_repairs=2) + ctx = {"ti": MagicMock()} + errors = [{"task_key": "t1", "run_id": 11, "error": "boom"}] + + with pytest.raises(AirflowException) as exc: + operator.execute_complete( + ctx, + event={ + "status": "failed", + "run_id": 100, + "repair_attempts": 2, + "latest_repair_id": 999, + "errors": errors, + }, + ) + + message = str(exc.value) + assert "100" in message + assert "max_full_run_repairs=2" in message + assert "boom" in message + ctx["ti"].xcom_push.assert_not_called() + + @patch("airflow.providers.databricks.operators.databricks_workflow.time.sleep") + def test_sync_run_repairs_failed_run_and_returns_success(self, mock_sleep): + 
operator = self._make_operator(max_full_run_repairs=2, deferrable=False) + hook = MagicMock() + operator.__dict__["_hook"] = hook + hook.get_run_state.side_effect = [ + RunState("TERMINATED", "FAILED", ""), + RunState("TERMINATED", "SUCCESS", ""), + ] + hook.get_run.return_value = { + "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": ""}, + "tasks": [], + "overriding_parameters": {"notebook_params": {"date": "2024-01-01"}}, + } + hook.repair_run.return_value = 555 + + result = operator._run_sync(run_id=100) + + hook.repair_run.assert_called_once_with( + { + "run_id": 100, + "rerun_all_failed_tasks": True, + "rerun_dependent_tasks": True, + "overriding_parameters": {"notebook_params": {"date": "2024-01-01"}}, + } + ) + assert result == {"run_id": 100, "repair_attempts": 1, "latest_repair_id": 555} + mock_sleep.assert_not_called() + + +class TestDatabricksWorkflowTaskGroupCoordinatorInjection: + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Coordinator task is only injected on Airflow 3+") + def test_max_full_run_repairs_positive_injects_coordinator_with_launch_upstream(self): + with DAG(dag_id="dwf_with_coord", schedule=None, start_date=DEFAULT_DATE): + with DatabricksWorkflowTaskGroup( + group_id="wf", + databricks_conn_id="databricks_conn", + max_full_run_repairs=2, + repair_polling_period_seconds=15, + ) as tg: + task = MagicMock(task_id="task1") + task._convert_to_databricks_workflow_task = MagicMock(return_value={}) + tg.add(task) + + coordinator = tg.children["wf.full_run_repair_coordinator"] + assert isinstance(coordinator, _DatabricksFullRunRepairCoordinatorOperator) + assert coordinator.max_full_run_repairs == 2 + assert coordinator.repair_polling_period_seconds == 15 + assert coordinator.launch_task_id == "wf.launch" + assert "wf.launch" in coordinator.upstream_task_ids + + def test_negative_max_full_run_repairs_rejected(self): + with pytest.raises(ValueError, match="max_full_run_repairs must be >= 0"): + 
DatabricksWorkflowTaskGroup( + group_id="wf_invalid", + databricks_conn_id="databricks_conn", + max_full_run_repairs=-1, + ) + + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Repair behavior only applies on Airflow 3+") + def test_warns_when_user_task_has_retries_and_max_full_run_repairs_positive(self): + with DAG(dag_id="dwf_warn_retries", schedule=None, start_date=DEFAULT_DATE): + tg = DatabricksWorkflowTaskGroup( + group_id="wf", + databricks_conn_id="databricks_conn", + max_full_run_repairs=1, + ) + tg.__enter__() + task = EmptyOperator(task_id="task1", retries=2) + task._convert_to_databricks_workflow_task = MagicMock(return_value={}) + with pytest.warns(UserWarning, match=r"retries=2 while max_full_run_repairs=1"): + tg.__exit__(None, None, None) diff --git a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py index 8854eb03fb5bc..12b0b4c60e691 100644 --- a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py @@ -27,6 +27,8 @@ from airflow.providers.databricks.triggers.databricks import ( DatabricksExecutionTrigger, DatabricksSQLStatementExecutionTrigger, + DatabricksWorkflowRepairCoordinatorTrigger, + DatabricksWorkflowRepairWaitTrigger, ) from airflow.triggers.base import TriggerEvent @@ -228,6 +230,7 @@ async def test_run_return_success( life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="SUCCESS" ).to_json(), "run_page_url": RUN_PAGE_URL, + "run_start_time": None, "repair_run": False, "errors": [], } @@ -259,6 +262,7 @@ async def test_run_return_failure( life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="FAILED" ).to_json(), "run_page_url": RUN_PAGE_URL, + "run_start_time": None, "repair_run": False, "errors": [ {"task_key": TASK_RUN_ID1_KEY, "run_id": TASK_RUN_ID1, "error": ERROR_MESSAGE}, @@ -299,6 +303,7 @@ 
async def test_sleep_between_retries( life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="SUCCESS" ).to_json(), "run_page_url": RUN_PAGE_URL, + "run_start_time": None, "repair_run": False, "errors": [], } @@ -432,3 +437,274 @@ async def test_sleep_between_retries(self, mock_a_get_sql_statement_state, mock_ async def test_on_kill_cancels_statement(self, mock_cancel_sql_statement): await self.trigger.on_kill() mock_cancel_sql_statement.assert_called_once_with(STATEMENT_ID) + + +class TestDatabricksWorkflowRepairCoordinatorTrigger: + @pytest.fixture(autouse=True) + def setup_connections(self, create_connection_without_db): + create_connection_without_db( + Connection( + conn_id=DEFAULT_CONN_ID, + conn_type="databricks", + host=HOST, + login=LOGIN, + password=PASSWORD, + extra=None, + ) + ) + + def _make_trigger( + self, + max_full_run_repairs: int = 2, + repair_attempts: int = 0, + latest_repair_id: int | None = None, + ) -> DatabricksWorkflowRepairCoordinatorTrigger: + return DatabricksWorkflowRepairCoordinatorTrigger( + run_id=RUN_ID, + databricks_conn_id=DEFAULT_CONN_ID, + max_full_run_repairs=max_full_run_repairs, + repair_attempts=repair_attempts, + latest_repair_id=latest_repair_id, + polling_period_seconds=POLLING_INTERVAL_SECONDS, + run_page_url=RUN_PAGE_URL, + ) + + def test_serialize_round_trips_state(self): + trigger = self._make_trigger(max_full_run_repairs=3, repair_attempts=1, latest_repair_id=42) + path, kwargs = trigger.serialize() + + assert ( + path + == "airflow.providers.databricks.triggers.databricks.DatabricksWorkflowRepairCoordinatorTrigger" + ) + assert kwargs == { + "run_id": RUN_ID, + "databricks_conn_id": DEFAULT_CONN_ID, + "max_full_run_repairs": 3, + "repair_attempts": 1, + "latest_repair_id": 42, + "polling_period_seconds": POLLING_INTERVAL_SECONDS, + "retry_limit": RETRY_LIMIT, + "retry_delay": RETRY_DELAY, + "retry_args": None, + "run_page_url": RUN_PAGE_URL, + "caller": 
"DatabricksWorkflowRepairCoordinatorTrigger", + } + + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state") + async def test_emits_completed_when_run_succeeds(self, mock_get_run_state, mock_get_run): + mock_get_run_state.return_value = RunState( + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + state_message="", + result_state="SUCCESS", + ) + mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED + + trigger = self._make_trigger(max_full_run_repairs=2, repair_attempts=0, latest_repair_id=None) + events = [event async for event in trigger.run()] + + assert len(events) == 1 + assert events[0].payload["status"] == "completed" + assert events[0].payload["run_id"] == RUN_ID + assert events[0].payload["repair_attempts"] == 0 + assert events[0].payload["latest_repair_id"] is None + assert events[0].payload["errors"] == [] + + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.repair_run") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state") + async def test_first_failure_within_budget_calls_repair_and_emits_repaired( + self, mock_get_run_state, mock_get_run, mock_get_run_output, mock_repair_run + ): + mock_get_run_state.return_value = RunState( + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + state_message="", + result_state="FAILED", + ) + mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED_WITH_FAILED + mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE + mock_repair_run.return_value = 101 + + trigger = self._make_trigger(max_full_run_repairs=2, repair_attempts=0, latest_repair_id=None) + events = [event async for event in trigger.run()] + + 
assert len(events) == 1 + assert events[0].payload["status"] == "repaired" + assert events[0].payload["run_id"] == RUN_ID + assert events[0].payload["repair_attempts"] == 1 + assert events[0].payload["latest_repair_id"] == 101 + assert events[0].payload["errors"] == [ + {"task_key": TASK_RUN_ID1_KEY, "run_id": TASK_RUN_ID1, "error": ERROR_MESSAGE}, + {"task_key": TASK_RUN_ID3_KEY, "run_id": TASK_RUN_ID3, "error": ERROR_MESSAGE}, + ] + + mock_repair_run.assert_called_once() + repair_json = mock_repair_run.call_args.args[0] + assert repair_json["run_id"] == RUN_ID + assert repair_json["rerun_all_failed_tasks"] is True + # First repair: latest_repair_id was None, so the field must be omitted from the payload + assert "latest_repair_id" not in repair_json + + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.repair_run") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state") + async def test_emits_failed_when_budget_exhausted( + self, mock_get_run_state, mock_get_run, mock_get_run_output, mock_repair_run + ): + mock_get_run_state.return_value = RunState( + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + state_message="", + result_state="FAILED", + ) + mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED_WITH_FAILED + mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE + + trigger = self._make_trigger(max_full_run_repairs=2, repair_attempts=2, latest_repair_id=202) + events = [event async for event in trigger.run()] + + assert len(events) == 1 + assert events[0].payload["status"] == "failed" + assert events[0].payload["repair_attempts"] == 2 + assert events[0].payload["latest_repair_id"] == 202 + assert events[0].payload["errors"] == [ + {"task_key": TASK_RUN_ID1_KEY, "run_id": TASK_RUN_ID1, 
"error": ERROR_MESSAGE}, + {"task_key": TASK_RUN_ID3_KEY, "run_id": TASK_RUN_ID3, "error": ERROR_MESSAGE}, + ] + mock_repair_run.assert_not_called() + + +class TestDatabricksWorkflowRepairWaitTrigger: + PARENT_RUN_ID = 100 + TASK_KEY = "monitored_task" + ORIGINAL_SUB_RUN_ID = 500 + NEW_SUB_RUN_ID = 700 + + @pytest.fixture(autouse=True) + def setup_connections(self, create_connection_without_db): + create_connection_without_db( + Connection( + conn_id=DEFAULT_CONN_ID, + conn_type="databricks", + host=HOST, + login=LOGIN, + password=PASSWORD, + extra=None, + ) + ) + + def _make_trigger( + self, + terminal_grace_polls: int = 3, + ) -> DatabricksWorkflowRepairWaitTrigger: + return DatabricksWorkflowRepairWaitTrigger( + run_id=self.PARENT_RUN_ID, + databricks_conn_id=DEFAULT_CONN_ID, + databricks_task_key=self.TASK_KEY, + original_sub_run_id=self.ORIGINAL_SUB_RUN_ID, + polling_period_seconds=POLLING_INTERVAL_SECONDS, + terminal_grace_polls=terminal_grace_polls, + run_page_url=RUN_PAGE_URL, + ) + + def _run_payload( + self, + result_state: str | None, + life_cycle_state: str = LIFE_CYCLE_STATE_TERMINATED, + tasks: list[dict] | None = None, + ) -> dict: + return { + "run_page_url": RUN_PAGE_URL, + "state": { + "life_cycle_state": life_cycle_state, + "state_message": None, + "result_state": result_state, + }, + "tasks": tasks or [], + } + + def test_serialize_round_trips_state(self): + trigger = self._make_trigger(terminal_grace_polls=5) + path, kwargs = trigger.serialize() + + assert path == "airflow.providers.databricks.triggers.databricks.DatabricksWorkflowRepairWaitTrigger" + assert kwargs == { + "run_id": self.PARENT_RUN_ID, + "databricks_conn_id": DEFAULT_CONN_ID, + "databricks_task_key": self.TASK_KEY, + "original_sub_run_id": self.ORIGINAL_SUB_RUN_ID, + "original_start_time": None, + "polling_period_seconds": POLLING_INTERVAL_SECONDS, + "terminal_grace_polls": 5, + "retry_limit": RETRY_LIMIT, + "retry_delay": RETRY_DELAY, + "retry_args": None, + "run_page_url": 
RUN_PAGE_URL, + "caller": "DatabricksWorkflowRepairWaitTrigger", + } + + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + async def test_emits_new_attempt_when_new_sub_run_appears(self, mock_get_run): + mock_get_run.return_value = self._run_payload( + result_state=None, + life_cycle_state="RUNNING", + tasks=[ + { + "run_id": self.ORIGINAL_SUB_RUN_ID, + "task_key": self.TASK_KEY, + "start_time": 1000, + }, + { + "run_id": self.NEW_SUB_RUN_ID, + "task_key": self.TASK_KEY, + "start_time": 2000, + }, + ], + ) + + trigger = self._make_trigger() + events = [event async for event in trigger.run()] + + assert len(events) == 1 + assert events[0].payload == { + "status": "new_attempt", + "parent_run_id": self.PARENT_RUN_ID, + "databricks_task_key": self.TASK_KEY, + "new_sub_run_id": self.NEW_SUB_RUN_ID, + "run_page_url": RUN_PAGE_URL, + } + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + async def test_emits_parent_failed_after_grace_polls(self, mock_get_run, mock_sleep): + terminal_payload = self._run_payload( + result_state="FAILED", + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + tasks=[ + { + "run_id": self.ORIGINAL_SUB_RUN_ID, + "task_key": self.TASK_KEY, + "start_time": 1000, + }, + ], + ) + mock_get_run.return_value = terminal_payload + + trigger = self._make_trigger(terminal_grace_polls=3) + events = [event async for event in trigger.run()] + + assert mock_get_run.call_count == 3 + assert mock_sleep.call_count == 2 + assert len(events) == 1 + payload = events[0].payload + assert payload["status"] == "parent_failed" + assert payload["parent_run_id"] == self.PARENT_RUN_ID + assert payload["databricks_task_key"] == self.TASK_KEY + assert RunState.from_json(payload["parent_run_state"]).result_state == "FAILED" diff --git a/providers/databricks/tests/unit/databricks/utils/test_databricks.py 
b/providers/databricks/tests/unit/databricks/utils/test_databricks.py index 33716879713ed..d09388b88495e 100644 --- a/providers/databricks/tests/unit/databricks/utils/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/utils/test_databricks.py @@ -25,8 +25,10 @@ from airflow.providers.common.compat.sdk import AirflowException from airflow.providers.databricks.hooks.databricks import RunState from airflow.providers.databricks.utils.databricks import ( + build_repair_run_json, extract_failed_task_errors, extract_failed_task_errors_async, + find_new_workflow_task_attempt, normalise_json_content, validate_trigger_event, ) @@ -95,6 +97,42 @@ def test_validate_trigger_event_failure(self): with pytest.raises(AirflowException): validate_trigger_event(event) + def test_find_new_workflow_task_attempt_picks_newest_matching_attempt(self): + tasks = [ + {"run_id": TASK_RUN_ID_1, "task_key": TASK_KEY_1, "start_time": 1000}, + {"run_id": 201, "task_key": TASK_KEY_1, "start_time": 1500}, + {"run_id": 202, "task_key": TASK_KEY_1, "start_time": 2500}, + {"run_id": 203, "task_key": TASK_KEY_2, "start_time": 3000}, + ] + + result = find_new_workflow_task_attempt( + tasks=tasks, + task_key=TASK_KEY_1, + original_sub_run_id=TASK_RUN_ID_1, + original_start_time=1000, + ) + + assert result == {"run_id": 202, "task_key": TASK_KEY_1, "start_time": 2500} + + def test_build_repair_run_json_includes_optional_fields_only_when_present(self): + assert build_repair_run_json(run_id=RUN_ID, latest_repair_id=None) == { + "run_id": RUN_ID, + "rerun_all_failed_tasks": True, + "rerun_dependent_tasks": True, + } + + assert build_repair_run_json( + run_id=RUN_ID, + latest_repair_id=42, + overriding_parameters={"notebook_params": {"date": "2024-01-01"}}, + ) == { + "run_id": RUN_ID, + "rerun_all_failed_tasks": True, + "rerun_dependent_tasks": True, + "latest_repair_id": 42, + "overriding_parameters": {"notebook_params": {"date": "2024-01-01"}}, + } + class TestExtractFailedTaskErrors: 
"""Test cases for the extract_failed_task_errors utility function (synchronous version)""" From 729c868026235c08bdca9182726877bf109f3a15 Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 18 May 2026 15:53:17 -0500 Subject: [PATCH 02/15] rename cordinator, update docstring --- .../operators/databricks_workflow.py | 28 +++++++++---------- .../operators/test_databricks_workflow.py | 4 +-- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py index 2e6c5a1c66b72..c53d0a9d2fdd2 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py @@ -313,24 +313,22 @@ class _DatabricksFullRunRepairCoordinatorOperator(BaseOperator): Watch a Databricks Workflow run and trigger ``rerun_all_failed_tasks`` repairs. Runs as a sibling of the downstream Databricks task monitors inside a - :class:`DatabricksWorkflowTaskGroup`. The ``launch`` task remains responsible for creating the - job and returning metadata immediately so downstream tasks can fan out; this operator owns the - long-lived defer cycle that watches the whole run and issues one repair call per failure batch, - so the original job cluster is reused. - - Downstream task monitors observe repairs independently by deferring on - :class:`~airflow.providers.databricks.triggers.databricks.DatabricksWorkflowRepairWaitTrigger` - when their sub-run hits a terminal failure; that trigger polls the Databricks API for the next - attempt of the same ``task_key`` rather than reading any inter-task XCom. The coordinator's - final return value carries ``{run_id, repair_attempts, latest_repair_id}`` for any user code - that wants a post-run summary. 
- - :param task_id: The task id of the operator (typically ``"full_run_repair_coordinator"``). + :class:`DatabricksWorkflowTaskGroup`. The ``launch`` task creates or resets the job, starts the + run, and publishes ``{conn_id, job_id, run_id}`` so monitors can fan out. This operator then + owns the parent-run repair budget, waits for terminal run states, and issues one repair call per + failed batch. + + Downstream task monitors observe repairs independently from the Databricks API when their + sub-run fails; they do not share repair state through XCom. The coordinator's final return + value carries ``{run_id, repair_attempts, latest_repair_id}`` for any user code that wants a + post-run summary. + + :param task_id: The task id of the operator (typically ``"repair_coordinator"``). :param databricks_conn_id: Connection id used by the coordinator trigger and repair calls. :param launch_task_id: The full task id of the workflow ``launch`` task whose return value carries ``{conn_id, job_id, run_id}``. :param max_full_run_repairs: Total repair attempts allowed across the run. - :param repair_polling_period_seconds: Poll interval forwarded to the coordinator trigger. + :param repair_polling_period_seconds: Poll interval used by the trigger or sync poll loop. :param databricks_retry_limit: Hook retry limit for transient API failures. :param databricks_retry_delay: Hook retry delay (seconds). :param databricks_retry_args: Optional ``tenacity.Retrying`` kwargs forwarded to the hook. 
@@ -667,7 +665,7 @@ def __exit__( repair_coordinator_task = _DatabricksFullRunRepairCoordinatorOperator( dag=self.dag, task_group=self, - task_id="full_run_repair_coordinator", + task_id="repair_coordinator", databricks_conn_id=self.databricks_conn_id, launch_task_id=create_databricks_workflow_task.task_id, max_full_run_repairs=self.max_full_run_repairs, diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py index 85d244f3edb98..5f74d12454ad0 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py @@ -585,7 +585,7 @@ def _make_operator( deferrable: bool = True, ) -> _DatabricksFullRunRepairCoordinatorOperator: return _DatabricksFullRunRepairCoordinatorOperator( - task_id="full_run_repair_coordinator", + task_id="repair_coordinator", databricks_conn_id="databricks_default", launch_task_id=self.LAUNCH_TASK_ID, max_full_run_repairs=max_full_run_repairs, @@ -709,7 +709,7 @@ def test_max_full_run_repairs_positive_injects_coordinator_with_launch_upstream( task._convert_to_databricks_workflow_task = MagicMock(return_value={}) tg.add(task) - coordinator = tg.children["wf.full_run_repair_coordinator"] + coordinator = tg.children["wf.repair_coordinator"] assert isinstance(coordinator, _DatabricksFullRunRepairCoordinatorOperator) assert coordinator.max_full_run_repairs == 2 assert coordinator.repair_polling_period_seconds == 15 From 79d9c2db7a1c98a1b76a65ce2a4c01326074fb8f Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 18 May 2026 15:58:05 -0500 Subject: [PATCH 03/15] more docstring trimming --- .../databricks/triggers/databricks.py | 53 +++++++------------ 1 file changed, 18 insertions(+), 35 deletions(-) diff --git a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py 
b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py index 1f3b6a7f24ec5..3880fcf20c2b0 100644 --- a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py @@ -137,24 +137,17 @@ class DatabricksWorkflowRepairCoordinatorTrigger(BaseTrigger): """ Coordinate whole-run polling and ``rerun_all_failed_tasks`` repair for a Databricks Workflow run. - Owned by the ``coordinator`` sibling task that + Used by the ``repair_coordinator`` sibling task that :class:`~airflow.providers.databricks.operators.databricks_workflow.DatabricksWorkflowTaskGroup` - injects when ``max_full_run_repairs > 0`` on Airflow 3+. Keeps a single Databricks job run alive across - repair attempts so the same job cluster is reused. Each defer/resume cycle of the coordinator - task corresponds to one iteration: - - 1. Poll the run until it reaches a terminal state. - 2. On terminal success, yield ``status="completed"``. - 3. On terminal failure with repair budget remaining, call - :meth:`~airflow.providers.databricks.hooks.databricks.DatabricksHook.repair_run` with - ``rerun_all_failed_tasks=True`` and yield ``status="repaired"`` along with the new - ``latest_repair_id`` and bumped ``repair_attempts``; the coordinator task then re-defers on - a fresh trigger instance with the new state. Downstream task monitors observe the new sub-run - attempt by polling the Databricks API directly (via - :class:`DatabricksWorkflowRepairWaitTrigger`), not via any inter-task XCom. - 4. On terminal failure with the budget exhausted, yield ``status="failed"``. - - The Databricks ``run_id`` is stable across repair attempts; only ``latest_repair_id`` changes. + injects when ``max_full_run_repairs > 0`` on Airflow 3+. 
The trigger mirrors the coordinator + operator's sync state machine: it watches the stable parent ``run_id``, tracks repair progress + in trigger serialization state, and emits one event per terminal observation. + + On success it yields ``status="completed"``. On failure with budget remaining, it calls + ``repair_run`` with the shared repair payload helper and yields ``status="repaired"`` with the + bumped ``repair_attempts`` and new ``latest_repair_id`` so the coordinator can defer again. On + failure after the budget is exhausted, it yields ``status="failed"``. Downstream task monitors + discover repaired sub-runs from the Databricks API rather than from coordinator XCom. :param run_id: The Databricks run id to coordinate. :param databricks_conn_id: Airflow connection id for the Databricks hook. @@ -326,25 +319,15 @@ class DatabricksWorkflowRepairWaitTrigger(BaseTrigger): """ Wait for the next attempt of a single Databricks Workflow task after its sub-run failed. - Used by Databricks task monitors inside a + Used by deferrable Databricks task monitors inside a :class:`~airflow.providers.databricks.operators.databricks_workflow.DatabricksWorkflowTaskGroup` - when ``max_full_run_repairs > 0`` on Airflow 3+. A monitor whose sub-run reaches terminal failure defers - on this trigger; the trigger polls the parent run's task list and yields when a new attempt of - the same ``task_key`` appears (issued by the sibling ``coordinator`` task via - ``rerun_all_failed_tasks``), so the monitor can then defer on a fresh - :class:`DatabricksExecutionTrigger` watching the new sub-run id. - - Each poll cycle: - - 1. If a Databricks task with our ``databricks_task_key`` exists whose ``run_id`` differs from - ``original_sub_run_id`` and whose ``start_time`` is newer, yield ``status="new_attempt"`` - with the new sub-run id. - 2. Otherwise, if the parent run is in a terminal failure state, count one "grace" observation. 
- After ``terminal_grace_polls`` consecutive terminal observations without a new attempt, - yield ``status="parent_failed"``. This avoids racing the coordinator: the parent run is - briefly terminal between sub-run failure and the coordinator issuing ``repair_run``. - 3. Otherwise (parent still running, or terminal but inside the grace window), sleep and poll - again. + when ``max_full_run_repairs > 0`` on Airflow 3+. After a sub-run fails, the trigger reads the + parent run's task list from ``get_run`` and yields ``status="new_attempt"`` when the shared + candidate-selection helper finds a newer sub-run with the same ``task_key``. + + If the parent run stays terminal-failed without a new attempt for ``terminal_grace_polls`` + consecutive polls, the trigger yields ``status="parent_failed"``. The grace window avoids + racing the coordinator while it issues ``repair_run``. :param run_id: Parent workflow run id (stable across repairs). :param databricks_conn_id: Airflow connection id for the Databricks hook. From 8f87d1845bb4a7464f7e76b2de1cd9b8ee1c5a2f Mon Sep 17 00:00:00 2001 From: Nick Date: Mon, 18 May 2026 16:32:32 -0500 Subject: [PATCH 04/15] More cleanup --- .../operators/databricks_workflow.py | 21 +++-------- .../databricks/triggers/databricks.py | 36 ++++--------------- 2 files changed, 10 insertions(+), 47 deletions(-) diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py index c53d0a9d2fdd2..3813b63ae5a15 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py @@ -310,31 +310,18 @@ def on_kill(self) -> None: class _DatabricksFullRunRepairCoordinatorOperator(BaseOperator): """ - Watch a Databricks Workflow run and trigger ``rerun_all_failed_tasks`` repairs. 
- - Runs as a sibling of the downstream Databricks task monitors inside a - :class:`DatabricksWorkflowTaskGroup`. The ``launch`` task creates or resets the job, starts the - run, and publishes ``{conn_id, job_id, run_id}`` so monitors can fan out. This operator then - owns the parent-run repair budget, waits for terminal run states, and issues one repair call per - failed batch. - - Downstream task monitors observe repairs independently from the Databricks API when their - sub-run fails; they do not share repair state through XCom. The coordinator's final return - value carries ``{run_id, repair_attempts, latest_repair_id}`` for any user code that wants a - post-run summary. + Watch a Databricks Workflow run and issue full-run repairs after terminal failures. :param task_id: The task id of the operator (typically ``"repair_coordinator"``). :param databricks_conn_id: Connection id used by the coordinator trigger and repair calls. - :param launch_task_id: The full task id of the workflow ``launch`` task whose return value - carries ``{conn_id, job_id, run_id}``. + :param launch_task_id: The workflow ``launch`` task whose XCom carries the parent run metadata. :param max_full_run_repairs: Total repair attempts allowed across the run. :param repair_polling_period_seconds: Poll interval used by the trigger or sync poll loop. :param databricks_retry_limit: Hook retry limit for transient API failures. :param databricks_retry_delay: Hook retry delay (seconds). :param databricks_retry_args: Optional ``tenacity.Retrying`` kwargs forwarded to the hook. - :param deferrable: If ``True``, watch the run by deferring on - :class:`DatabricksWorkflowRepairCoordinatorTrigger`. If ``False``, watch it via a - synchronous poll loop that runs the same state machine inline. + :param deferrable: Whether to watch the run with + :class:`DatabricksWorkflowRepairCoordinatorTrigger`. 
""" caller = "_DatabricksFullRunRepairCoordinatorOperator" diff --git a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py index 3880fcf20c2b0..22b38118dd445 100644 --- a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py @@ -135,26 +135,13 @@ async def run(self): class DatabricksWorkflowRepairCoordinatorTrigger(BaseTrigger): """ - Coordinate whole-run polling and ``rerun_all_failed_tasks`` repair for a Databricks Workflow run. - - Used by the ``repair_coordinator`` sibling task that - :class:`~airflow.providers.databricks.operators.databricks_workflow.DatabricksWorkflowTaskGroup` - injects when ``max_full_run_repairs > 0`` on Airflow 3+. The trigger mirrors the coordinator - operator's sync state machine: it watches the stable parent ``run_id``, tracks repair progress - in trigger serialization state, and emits one event per terminal observation. - - On success it yields ``status="completed"``. On failure with budget remaining, it calls - ``repair_run`` with the shared repair payload helper and yields ``status="repaired"`` with the - bumped ``repair_attempts`` and new ``latest_repair_id`` so the coordinator can defer again. On - failure after the budget is exhausted, it yields ``status="failed"``. Downstream task monitors - discover repaired sub-runs from the Databricks API rather than from coordinator XCom. + Coordinate parent-run polling and full-run repairs for a Databricks Workflow run. :param run_id: The Databricks run id to coordinate. :param databricks_conn_id: Airflow connection id for the Databricks hook. :param max_full_run_repairs: Total repair attempts allowed for this run. - :param repair_attempts: Repair attempts already performed (defaults to 0 on the first defer). 
- :param latest_repair_id: Repair id of the most recent repair attempt, or ``None`` on the first - defer. Forwarded to ``repair_run`` so Databricks knows which attempt is the latest. + :param repair_attempts: Repair attempts already performed. + :param latest_repair_id: Repair id of the most recent repair attempt. :param polling_period_seconds: How often to poll the run state. :param retry_limit: Hook retry limit for transient Databricks API failures. :param retry_delay: Hook retry delay (seconds). @@ -317,17 +304,7 @@ async def run(self): class DatabricksWorkflowRepairWaitTrigger(BaseTrigger): """ - Wait for the next attempt of a single Databricks Workflow task after its sub-run failed. - - Used by deferrable Databricks task monitors inside a - :class:`~airflow.providers.databricks.operators.databricks_workflow.DatabricksWorkflowTaskGroup` - when ``max_full_run_repairs > 0`` on Airflow 3+. After a sub-run fails, the trigger reads the - parent run's task list from ``get_run`` and yields ``status="new_attempt"`` when the shared - candidate-selection helper finds a newer sub-run with the same ``task_key``. - - If the parent run stays terminal-failed without a new attempt for ``terminal_grace_polls`` - consecutive polls, the trigger yields ``status="parent_failed"``. The grace window avoids - racing the coordinator while it issues ``repair_run``. + Wait for the next attempt of a Databricks Workflow task after its sub-run fails. :param run_id: Parent workflow run id (stable across repairs). :param databricks_conn_id: Airflow connection id for the Databricks hook. @@ -335,9 +312,8 @@ class DatabricksWorkflowRepairWaitTrigger(BaseTrigger): :param original_sub_run_id: The sub-run id of the attempt that just failed; the trigger only yields ``new_attempt`` for a sub-run id different from this one. :param polling_period_seconds: How often to poll the parent run. 
- :param terminal_grace_polls: Number of consecutive terminal-with-no-new-attempt observations - required before yielding ``status="parent_failed"``. Bounds how long we wait for the - coordinator to issue a repair after observing terminal failure. + :param terminal_grace_polls: Consecutive terminal observations before yielding + ``status="parent_failed"``. :param retry_limit: Hook retry limit for transient Databricks API failures. :param retry_delay: Hook retry delay (seconds). :param retry_args: Optional tenacity ``Retrying`` kwargs forwarded to the hook. From c6f69565247738b6c251b072f934b840958f522a Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 19 May 2026 09:35:43 -0500 Subject: [PATCH 05/15] trim LOC, reduce exceptions, hardcode grace polls --- .../providers/databricks/exceptions.py | 12 ------- .../databricks/operators/databricks.py | 31 +++++-------------- .../operators/databricks_workflow.py | 29 ++++------------- .../databricks/triggers/databricks.py | 15 ++++----- .../providers/databricks/utils/databricks.py | 14 ++------- .../operators/test_databricks_workflow.py | 22 ------------- .../databricks/triggers/test_databricks.py | 11 ++----- 7 files changed, 24 insertions(+), 110 deletions(-) diff --git a/providers/databricks/src/airflow/providers/databricks/exceptions.py b/providers/databricks/src/airflow/providers/databricks/exceptions.py index 424ee8b7e45b3..85b519e83b948 100644 --- a/providers/databricks/src/airflow/providers/databricks/exceptions.py +++ b/providers/databricks/src/airflow/providers/databricks/exceptions.py @@ -34,15 +34,3 @@ class DatabricksSqlExecutionTimeout(DatabricksSqlExecutionError): class DatabricksWorkflowRepairError(AirflowException): """Raised when Databricks Workflow repair coordination fails.""" - - -class DatabricksWorkflowRepairMetadataError(DatabricksWorkflowRepairError): - """Raised when workflow repair metadata is missing or invalid.""" - - -class DatabricksWorkflowRepairBudgetExhausted(DatabricksWorkflowRepairError): - 
"""Raised when a Databricks Workflow run fails after exhausting repair attempts.""" - - -class DatabricksWorkflowRepairTriggerError(DatabricksWorkflowRepairError): - """Raised when a Databricks Workflow repair trigger emits an invalid event.""" diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index 4bdb5702b9c00..2cf0cf7b785e4 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -27,10 +27,7 @@ from typing import TYPE_CHECKING, Any from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, BaseOperatorLink, XCom, conf -from airflow.providers.databricks.exceptions import ( - DatabricksWorkflowRepairBudgetExhausted, - DatabricksWorkflowRepairTriggerError, -) +from airflow.providers.databricks.exceptions import DatabricksWorkflowRepairError from airflow.providers.databricks.hooks.databricks import ( DatabricksHook, RunLifeCycleState, @@ -1590,16 +1587,9 @@ def execute(self, context: Context) -> None: self.monitor_databricks_job() def _resolve_workflow_run_metadata(self, context: Context | dict | None) -> WorkflowRunMetadata: - """ - Populate ``databricks_run_id`` / ``databricks_conn_id`` from ``workflow_run_metadata``. - - Resolves both the standard ``execute`` path and the deferrable-resume path: when a task - resumes via ``execute_complete``, the operator is freshly re-instantiated, so any - attributes set during ``execute`` (including ``databricks_run_id``) are lost. The - templated ``workflow_run_metadata`` field is rendered from the upstream ``.launch`` - task's XCom on every task-instance run, so it is the canonical source for the parent - workflow run id across both entry points. 
- """ + # Called from both execute() and execute_complete(): the deferrable resume path + # re-instantiates the operator, so attributes set in execute() are lost and we + # re-resolve from the templated workflow_run_metadata field. if not self.workflow_run_metadata and context is not None: launch_task_id = next((task for task in self.upstream_task_ids if task.endswith(".launch")), None) ti = context.get("ti") if isinstance(context, dict) else context["ti"] @@ -1642,14 +1632,7 @@ def _sync_wait_for_new_sub_run_attempt( original_start_time: int | None, tg: DatabricksWorkflowTaskGroup, ) -> int | None: - """ - Sync equivalent of :class:`DatabricksWorkflowRepairWaitTrigger`. - - Polls the parent run for a new attempt of ``self.databricks_task_key`` after a - sub-run reaches terminal failure, then lets the caller switch to that attempt and - continue polling. Returns the new sub-run id, or ``None`` if the parent run - terminates without producing a new attempt within the grace window. - """ + """Sync mirror of DatabricksWorkflowRepairWaitTrigger. Returns the new sub-run id or None.""" self.log.info( "Sub-run %s for task_key %s reached terminal failure; waiting for a repair " "attempt issued by the workflow coordinator.", @@ -1739,14 +1722,14 @@ def execute_complete_after_repair_wait(self, context: dict | None, event: dict) ) elif status == "parent_failed": parent_state = RunState.from_json(event["parent_run_state"]) - raise DatabricksWorkflowRepairBudgetExhausted( + raise DatabricksWorkflowRepairError( f"Databricks workflow run {event.get('parent_run_id')} reached terminal failure " f"({parent_state.result_state}) without producing a new attempt for task_key " f"{self.databricks_task_key!r}; repair budget is exhausted or the coordinator " f"did not issue a repair." 
) else: - raise DatabricksWorkflowRepairTriggerError( + raise DatabricksWorkflowRepairError( f"DatabricksWorkflowRepairWaitTrigger emitted unexpected status {status!r}: {event}" ) diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py index 3813b63ae5a15..b4d287bf14676 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py @@ -27,11 +27,7 @@ from mergedeep import merge from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, TaskGroup, conf -from airflow.providers.databricks.exceptions import ( - DatabricksWorkflowRepairBudgetExhausted, - DatabricksWorkflowRepairMetadataError, - DatabricksWorkflowRepairTriggerError, -) +from airflow.providers.databricks.exceptions import DatabricksWorkflowRepairError from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState from airflow.providers.databricks.plugins.databricks_workflow import ( WorkflowJobRepairAllFailedLink, @@ -309,20 +305,7 @@ def on_kill(self) -> None: class _DatabricksFullRunRepairCoordinatorOperator(BaseOperator): - """ - Watch a Databricks Workflow run and issue full-run repairs after terminal failures. - - :param task_id: The task id of the operator (typically ``"repair_coordinator"``). - :param databricks_conn_id: Connection id used by the coordinator trigger and repair calls. - :param launch_task_id: The workflow ``launch`` task whose XCom carries the parent run metadata. - :param max_full_run_repairs: Total repair attempts allowed across the run. - :param repair_polling_period_seconds: Poll interval used by the trigger or sync poll loop. - :param databricks_retry_limit: Hook retry limit for transient API failures. - :param databricks_retry_delay: Hook retry delay (seconds). 
- :param databricks_retry_args: Optional ``tenacity.Retrying`` kwargs forwarded to the hook. - :param deferrable: Whether to watch the run with - :class:`DatabricksWorkflowRepairCoordinatorTrigger`. - """ + """Watch a Databricks Workflow run and issue full-run repairs after terminal failures.""" caller = "_DatabricksFullRunRepairCoordinatorOperator" @@ -385,7 +368,7 @@ def _make_trigger( def execute(self, context: Context) -> Any: launch_value = context["ti"].xcom_pull(task_ids=self.launch_task_id) if not launch_value: - raise DatabricksWorkflowRepairMetadataError( + raise DatabricksWorkflowRepairError( f"Launch task {self.launch_task_id!r} did not publish workflow run metadata; " "cannot coordinate repairs." ) @@ -434,7 +417,7 @@ def _run_sync(self, run_id: int) -> dict[str, Any]: errors = extract_failed_task_errors(self._hook, run_info, run_state) if repair_attempts >= self.max_full_run_repairs: - raise DatabricksWorkflowRepairBudgetExhausted( + raise DatabricksWorkflowRepairError( f"Databricks workflow run {run_id} failed after {repair_attempts} repair " f"attempt(s); repair budget exhausted (max_full_run_repairs={self.max_full_run_repairs}). " f"Errors: {errors}" @@ -500,13 +483,13 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: if status == "failed": errors = event.get("errors", []) - raise DatabricksWorkflowRepairBudgetExhausted( + raise DatabricksWorkflowRepairError( f"Databricks workflow run {run_id} failed after {repair_attempts} repair " f"attempt(s); repair budget exhausted (max_full_run_repairs={self.max_full_run_repairs}). 
" f"Errors: {errors}" ) - raise DatabricksWorkflowRepairTriggerError( + raise DatabricksWorkflowRepairError( f"DatabricksWorkflowRepairCoordinatorTrigger emitted unexpected status {status!r}: {event}" ) diff --git a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py index 22b38118dd445..515642e84368c 100644 --- a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py @@ -312,8 +312,6 @@ class DatabricksWorkflowRepairWaitTrigger(BaseTrigger): :param original_sub_run_id: The sub-run id of the attempt that just failed; the trigger only yields ``new_attempt`` for a sub-run id different from this one. :param polling_period_seconds: How often to poll the parent run. - :param terminal_grace_polls: Consecutive terminal observations before yielding - ``status="parent_failed"``. :param retry_limit: Hook retry limit for transient Databricks API failures. :param retry_delay: Hook retry delay (seconds). :param retry_args: Optional tenacity ``Retrying`` kwargs forwarded to the hook. 
@@ -329,7 +327,6 @@ def __init__( original_sub_run_id: int, original_start_time: int | None = None, polling_period_seconds: int = 30, - terminal_grace_polls: int = 3, retry_limit: int = 3, retry_delay: int = 10, retry_args: dict[Any, Any] | None = None, @@ -337,15 +334,12 @@ def __init__( caller: str = "DatabricksWorkflowRepairWaitTrigger", ) -> None: super().__init__() - if terminal_grace_polls < 1: - raise ValueError(f"terminal_grace_polls must be >= 1, got {terminal_grace_polls}") self.run_id = run_id self.databricks_conn_id = databricks_conn_id self.databricks_task_key = databricks_task_key self.original_sub_run_id = original_sub_run_id self.original_start_time = original_start_time self.polling_period_seconds = polling_period_seconds - self.terminal_grace_polls = terminal_grace_polls self.retry_limit = retry_limit self.retry_delay = retry_delay self.retry_args = retry_args @@ -369,7 +363,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "original_sub_run_id": self.original_sub_run_id, "original_start_time": self.original_start_time, "polling_period_seconds": self.polling_period_seconds, - "terminal_grace_polls": self.terminal_grace_polls, "retry_limit": self.retry_limit, "retry_delay": self.retry_delay, "retry_args": self.retry_args, @@ -387,6 +380,10 @@ def _find_new_attempt(self, tasks: list[dict[str, Any]]) -> dict[str, Any] | Non ) async def run(self): + # Grace polls before declaring the parent terminally failed without a new attempt: + # Databricks can briefly report the parent run terminal before a repair-triggered + # sub-run shows up in the tasks list. 
+ terminal_grace_polls = 3 terminal_observations = 0 async with self.hook: while True: @@ -423,9 +420,9 @@ async def run(self): run_state.result_state, self.databricks_task_key, terminal_observations, - self.terminal_grace_polls, + terminal_grace_polls, ) - if terminal_observations >= self.terminal_grace_polls: + if terminal_observations >= terminal_grace_polls: yield TriggerEvent( { "status": "parent_failed", diff --git a/providers/databricks/src/airflow/providers/databricks/utils/databricks.py b/providers/databricks/src/airflow/providers/databricks/utils/databricks.py index 4c1f557b408be..0607ca0de3021 100644 --- a/providers/databricks/src/airflow/providers/databricks/utils/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/utils/databricks.py @@ -113,12 +113,7 @@ def find_new_workflow_task_attempt( original_sub_run_id: int, original_start_time: int | None, ) -> dict[str, Any] | None: - """ - Return the newest task entry matching ``task_key`` that is not the original sub-run. - - Used by the repair-wait trigger and its sync counterpart to detect a new attempt of - a Databricks Workflow task after the prior sub-run reached terminal failure. - """ + """Return the newest task entry matching ``task_key`` that is not the original sub-run.""" candidates = [ task for task in tasks @@ -136,12 +131,7 @@ def build_repair_run_json( latest_repair_id: int | None, overriding_parameters: Any = None, ) -> dict[str, Any]: - """ - Build the ``DatabricksHook.repair_run`` payload for ``rerun_all_failed_tasks`` repair. - - Used by the coordinator trigger and its sync counterpart to keep the repair payload - shape in lock-step. 
- """ + """Build the ``DatabricksHook.repair_run`` payload for ``rerun_all_failed_tasks`` repair.""" repair_json: dict[str, Any] = { "run_id": run_id, "rerun_all_failed_tasks": True, diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py index 5f74d12454ad0..6f661c02971eb 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py @@ -715,25 +715,3 @@ def test_max_full_run_repairs_positive_injects_coordinator_with_launch_upstream( assert coordinator.repair_polling_period_seconds == 15 assert coordinator.launch_task_id == "wf.launch" assert "wf.launch" in coordinator.upstream_task_ids - - def test_negative_max_full_run_repairs_rejected(self): - with pytest.raises(ValueError, match="max_full_run_repairs must be >= 0"): - DatabricksWorkflowTaskGroup( - group_id="wf_invalid", - databricks_conn_id="databricks_conn", - max_full_run_repairs=-1, - ) - - @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Repair behavior only applies on Airflow 3+") - def test_warns_when_user_task_has_retries_and_max_full_run_repairs_positive(self): - with DAG(dag_id="dwf_warn_retries", schedule=None, start_date=DEFAULT_DATE): - tg = DatabricksWorkflowTaskGroup( - group_id="wf", - databricks_conn_id="databricks_conn", - max_full_run_repairs=1, - ) - tg.__enter__() - task = EmptyOperator(task_id="task1", retries=2) - task._convert_to_databricks_workflow_task = MagicMock(return_value={}) - with pytest.warns(UserWarning, match=r"retries=2 while max_full_run_repairs=1"): - tg.__exit__(None, None, None) diff --git a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py index 12b0b4c60e691..73bfd49d955dd 100644 --- 
a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py @@ -598,17 +598,13 @@ def setup_connections(self, create_connection_without_db): ) ) - def _make_trigger( - self, - terminal_grace_polls: int = 3, - ) -> DatabricksWorkflowRepairWaitTrigger: + def _make_trigger(self) -> DatabricksWorkflowRepairWaitTrigger: return DatabricksWorkflowRepairWaitTrigger( run_id=self.PARENT_RUN_ID, databricks_conn_id=DEFAULT_CONN_ID, databricks_task_key=self.TASK_KEY, original_sub_run_id=self.ORIGINAL_SUB_RUN_ID, polling_period_seconds=POLLING_INTERVAL_SECONDS, - terminal_grace_polls=terminal_grace_polls, run_page_url=RUN_PAGE_URL, ) @@ -629,7 +625,7 @@ def _run_payload( } def test_serialize_round_trips_state(self): - trigger = self._make_trigger(terminal_grace_polls=5) + trigger = self._make_trigger() path, kwargs = trigger.serialize() assert path == "airflow.providers.databricks.triggers.databricks.DatabricksWorkflowRepairWaitTrigger" @@ -640,7 +636,6 @@ def test_serialize_round_trips_state(self): "original_sub_run_id": self.ORIGINAL_SUB_RUN_ID, "original_start_time": None, "polling_period_seconds": POLLING_INTERVAL_SECONDS, - "terminal_grace_polls": 5, "retry_limit": RETRY_LIMIT, "retry_delay": RETRY_DELAY, "retry_args": None, @@ -697,7 +692,7 @@ async def test_emits_parent_failed_after_grace_polls(self, mock_get_run, mock_sl ) mock_get_run.return_value = terminal_payload - trigger = self._make_trigger(terminal_grace_polls=3) + trigger = self._make_trigger() events = [event async for event in trigger.run()] assert mock_get_run.call_count == 3 From 52f164403b343bb4852ccf50412443129eb00e57 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 19 May 2026 09:58:50 -0500 Subject: [PATCH 06/15] tweak seralization test --- .../databricks/triggers/test_databricks.py | 32 ++++--------------- 1 file changed, 6 insertions(+), 26 deletions(-) diff --git 
a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py index 73bfd49d955dd..1705560e92385 100644 --- a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py @@ -471,25 +471,15 @@ def _make_trigger( def test_serialize_round_trips_state(self): trigger = self._make_trigger(max_full_run_repairs=3, repair_attempts=1, latest_repair_id=42) + path, kwargs = trigger.serialize() + restored = DatabricksWorkflowRepairCoordinatorTrigger(**kwargs) assert ( path == "airflow.providers.databricks.triggers.databricks.DatabricksWorkflowRepairCoordinatorTrigger" ) - assert kwargs == { - "run_id": RUN_ID, - "databricks_conn_id": DEFAULT_CONN_ID, - "max_full_run_repairs": 3, - "repair_attempts": 1, - "latest_repair_id": 42, - "polling_period_seconds": POLLING_INTERVAL_SECONDS, - "retry_limit": RETRY_LIMIT, - "retry_delay": RETRY_DELAY, - "retry_args": None, - "run_page_url": RUN_PAGE_URL, - "caller": "DatabricksWorkflowRepairCoordinatorTrigger", - } + assert restored.serialize() == (path, kwargs) @pytest.mark.asyncio @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") @@ -626,22 +616,12 @@ def _run_payload( def test_serialize_round_trips_state(self): trigger = self._make_trigger() + path, kwargs = trigger.serialize() + restored = DatabricksWorkflowRepairWaitTrigger(**kwargs) assert path == "airflow.providers.databricks.triggers.databricks.DatabricksWorkflowRepairWaitTrigger" - assert kwargs == { - "run_id": self.PARENT_RUN_ID, - "databricks_conn_id": DEFAULT_CONN_ID, - "databricks_task_key": self.TASK_KEY, - "original_sub_run_id": self.ORIGINAL_SUB_RUN_ID, - "original_start_time": None, - "polling_period_seconds": POLLING_INTERVAL_SECONDS, - "retry_limit": RETRY_LIMIT, - "retry_delay": RETRY_DELAY, - "retry_args": None, - "run_page_url": RUN_PAGE_URL, - "caller": 
"DatabricksWorkflowRepairWaitTrigger", - } + assert restored.serialize() == (path, kwargs) @pytest.mark.asyncio @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") From f2777fc2043b26a13925a9ba15b8f41f0c68c207 Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 20 May 2026 08:56:54 -0500 Subject: [PATCH 07/15] keep language consistent --- .../airflow/providers/databricks/operators/databricks.py | 8 ++++---- .../providers/databricks/operators/databricks_workflow.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index 2cf0cf7b785e4..0285dac1d4bf3 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -1539,7 +1539,7 @@ def monitor_databricks_job(self) -> None: if new_sub_run_id is None: break self.log.info( - "Workflow coordinator produced a new attempt for task_key %s (sub-run %s).", + "Repair coordinator produced a new attempt for task_key %s (sub-run %s).", self.databricks_task_key, new_sub_run_id, ) @@ -1635,7 +1635,7 @@ def _sync_wait_for_new_sub_run_attempt( """Sync mirror of DatabricksWorkflowRepairWaitTrigger. 
Returns the new sub-run id or None.""" self.log.info( "Sub-run %s for task_key %s reached terminal failure; waiting for a repair " - "attempt issued by the workflow coordinator.", + "attempt issued by the repair coordinator.", original_sub_run_id, self.databricks_task_key, ) @@ -1678,7 +1678,7 @@ def _defer_to_workflow_repair_wait( ) -> None: self.log.info( "Sub-run %s for task_key %s reached terminal failure; deferring to wait for a repair " - "attempt issued by the workflow coordinator.", + "attempt issued by the repair coordinator.", original_sub_run_id, self.databricks_task_key, ) @@ -1703,7 +1703,7 @@ def execute_complete_after_repair_wait(self, context: dict | None, event: dict) if status == "new_attempt": new_sub_run_id = event["new_sub_run_id"] self.log.info( - "Workflow coordinator produced a new attempt for task_key %s (sub-run %s); " + "Repair coordinator produced a new attempt for task_key %s (sub-run %s); " "deferring on a fresh DatabricksExecutionTrigger to monitor it.", self.databricks_task_key, new_sub_run_id, diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py index b4d287bf14676..2f83370d89722 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py @@ -324,7 +324,7 @@ def __init__( ): if max_full_run_repairs < 1: raise ValueError( - f"max_full_run_repairs must be >= 1 for the workflow coordinator task, got {max_full_run_repairs}" + f"max_full_run_repairs must be >= 1 for the repair coordinator task, got {max_full_run_repairs}" ) super().__init__(task_id=task_id, **kwargs) self.databricks_conn_id = databricks_conn_id From bb6b77331485f0314cbca66a8f49be99d76df2ef Mon Sep 17 00:00:00 2001 From: Nick Date: Wed, 20 May 2026 12:46:13 -0500 Subject: [PATCH 08/15] rename params for 
clarity --- .../databricks/docs/operators/workflow.rst | 20 ++++++ .../databricks/operators/databricks.py | 6 +- .../operators/databricks_workflow.py | 67 ++++++++++--------- .../databricks/triggers/databricks.py | 18 ++--- .../databricks/operators/test_databricks.py | 20 +++--- .../operators/test_databricks_workflow.py | 40 +++++------ .../databricks/triggers/test_databricks.py | 12 ++-- 7 files changed, 102 insertions(+), 81 deletions(-) diff --git a/providers/databricks/docs/operators/workflow.rst b/providers/databricks/docs/operators/workflow.rst index 42cb0b20f9666..18e0889b8ed7c 100644 --- a/providers/databricks/docs/operators/workflow.rst +++ b/providers/databricks/docs/operators/workflow.rst @@ -68,3 +68,23 @@ To minimize update conflicts, we recommend that you keep parameters in the ``not ``DatabricksWorkflowTaskGroup`` and not in the ``DatabricksNotebookOperator`` whenever possible. This is because, tasks in the ``DatabricksWorkflowTaskGroup`` are passed in on the job trigger time and do not modify the job definition. + +Automatic repair (Airflow 3+) +----------------------------- + +Set ``workflow_repair_attempts=N`` to auto-repair a failed workflow run up to N times. The task +group injects a ``repair_coordinator`` sibling that waits for the run to terminate and then calls +Databricks ``repair_run`` with ``rerun_all_failed_tasks=True``. Default is ``0`` (off). + +.. code-block:: python + + with DatabricksWorkflowTaskGroup( + group_id="example_databricks_workflow", + databricks_conn_id="databricks_default", + job_clusters=job_clusters, + workflow_repair_attempts=2, + ) as task_group: + ... + +Downstream task monitors stay in the same Airflow attempt across repairs, so size +``execution_timeout`` for the original run plus all repairs and set ``retries=0`` on workflow tasks. 
diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index 0285dac1d4bf3..65eaae865faf0 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -1622,7 +1622,7 @@ def _workflow_task_group_with_repair(self) -> DatabricksWorkflowTaskGroup | None if not AIRFLOW_V_3_0_PLUS: return None tg = self._databricks_workflow_task_group - if tg is None or getattr(tg, "max_full_run_repairs", 0) <= 0: + if tg is None or getattr(tg, "workflow_repair_attempts", 0) <= 0: return None return tg @@ -1639,7 +1639,7 @@ def _sync_wait_for_new_sub_run_attempt( original_sub_run_id, self.databricks_task_key, ) - polling_period_seconds = tg.repair_polling_period_seconds + polling_period_seconds = tg.workflow_repair_polling_period terminal_grace_polls = 3 terminal_observations = 0 while True: @@ -1689,7 +1689,7 @@ def _defer_to_workflow_repair_wait( databricks_task_key=self.databricks_task_key, original_sub_run_id=original_sub_run_id, original_start_time=original_start_time, - polling_period_seconds=tg.repair_polling_period_seconds, + polling_period_seconds=tg.workflow_repair_polling_period, retry_limit=self.databricks_retry_limit, retry_delay=self.databricks_retry_delay, retry_args=self.databricks_retry_args, diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py index 2f83370d89722..c30084b1aef91 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py @@ -304,33 +304,33 @@ def on_kill(self) -> None: ) -class _DatabricksFullRunRepairCoordinatorOperator(BaseOperator): - """Watch a Databricks 
Workflow run and issue full-run repairs after terminal failures.""" +class _DatabricksWorkflowRepairCoordinatorOperator(BaseOperator): + """Watch a Databricks Workflow run and issue repairs after terminal failures.""" - caller = "_DatabricksFullRunRepairCoordinatorOperator" + caller = "_DatabricksWorkflowRepairCoordinatorOperator" def __init__( self, task_id: str, databricks_conn_id: str, launch_task_id: str, - max_full_run_repairs: int, - repair_polling_period_seconds: int = 30, + workflow_repair_attempts: int, + workflow_repair_polling_period: int = 30, databricks_retry_limit: int = 3, databricks_retry_delay: int = 10, databricks_retry_args: dict[Any, Any] | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): - if max_full_run_repairs < 1: + if workflow_repair_attempts < 1: raise ValueError( - f"max_full_run_repairs must be >= 1 for the repair coordinator task, got {max_full_run_repairs}" + f"workflow_repair_attempts must be >= 1 for the repair coordinator task, got {workflow_repair_attempts}" ) super().__init__(task_id=task_id, **kwargs) self.databricks_conn_id = databricks_conn_id self.launch_task_id = launch_task_id - self.max_full_run_repairs = max_full_run_repairs - self.repair_polling_period_seconds = repair_polling_period_seconds + self.workflow_repair_attempts = workflow_repair_attempts + self.workflow_repair_polling_period = workflow_repair_polling_period self.databricks_retry_limit = databricks_retry_limit self.databricks_retry_delay = databricks_retry_delay self.databricks_retry_args = databricks_retry_args @@ -355,10 +355,10 @@ def _make_trigger( return DatabricksWorkflowRepairCoordinatorTrigger( run_id=run_id, databricks_conn_id=self.databricks_conn_id, - max_full_run_repairs=self.max_full_run_repairs, + workflow_repair_attempts=self.workflow_repair_attempts, repair_attempts=repair_attempts, latest_repair_id=latest_repair_id, - polling_period_seconds=self.repair_polling_period_seconds, + 
polling_period_seconds=self.workflow_repair_polling_period, retry_limit=self.databricks_retry_limit, retry_delay=self.databricks_retry_delay, retry_args=self.databricks_retry_args, @@ -383,6 +383,7 @@ def execute(self, context: Context) -> Any: ), method_name="execute_complete", ) + return None return self._run_sync(metadata.run_id) def _run_sync(self, run_id: int) -> dict[str, Any]: @@ -396,9 +397,9 @@ def _run_sync(self, run_id: int) -> dict[str, Any]: "Databricks run %s in state %s. Sleeping for %s seconds.", run_id, run_state, - self.repair_polling_period_seconds, + self.workflow_repair_polling_period, ) - time.sleep(self.repair_polling_period_seconds) + time.sleep(self.workflow_repair_polling_period) run_state = self._hook.get_run_state(run_id) if run_state.is_successful: @@ -416,10 +417,10 @@ def _run_sync(self, run_id: int) -> dict[str, Any]: run_info = self._hook.get_run(run_id) errors = extract_failed_task_errors(self._hook, run_info, run_state) - if repair_attempts >= self.max_full_run_repairs: + if repair_attempts >= self.workflow_repair_attempts: raise DatabricksWorkflowRepairError( f"Databricks workflow run {run_id} failed after {repair_attempts} repair " - f"attempt(s); repair budget exhausted (max_full_run_repairs={self.max_full_run_repairs}). " + f"attempt(s); repair budget exhausted (workflow_repair_attempts={self.workflow_repair_attempts}). " f"Errors: {errors}" ) @@ -429,7 +430,7 @@ def _run_sync(self, run_id: int) -> dict[str, Any]: run_id, run_state.result_state, repair_attempts + 1, - self.max_full_run_repairs, + self.workflow_repair_attempts, latest_repair_id, ) repair_json = build_repair_run_json( @@ -485,7 +486,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: errors = event.get("errors", []) raise DatabricksWorkflowRepairError( f"Databricks workflow run {run_id} failed after {repair_attempts} repair " - f"attempt(s); repair budget exhausted (max_full_run_repairs={self.max_full_run_repairs}). 
" + f"attempt(s); repair budget exhausted (workflow_repair_attempts={self.workflow_repair_attempts}). " f"Errors: {errors}" ) @@ -532,12 +533,12 @@ class DatabricksWorkflowTaskGroup(TaskGroup): all python tasks in the workflow. :param spark_submit_params: A list of spark submit parameters to pass to the workflow. These parameters will be passed to all spark submit tasks. - :param max_full_run_repairs: Maximum number of automatic ``rerun_all_failed_tasks`` repair attempts to + :param workflow_repair_attempts: Maximum number of automatic ``rerun_all_failed_tasks`` repair attempts to issue against the Databricks run when downstream tasks fail. Each repair reuses the original job cluster. Set to ``0`` to disable auto-repair (current behavior). Only takes effect on Airflow 3+; ignored on Airflow 2.x. Defaults to ``0``. - :param repair_polling_period_seconds: How often the repair coordinator polls the - Databricks run state. Only used when ``max_full_run_repairs > 0``. + :param workflow_repair_polling_period: How often the repair coordinator polls the + Databricks run state. Only used when ``workflow_repair_attempts > 0``. 
""" is_databricks = True @@ -555,12 +556,12 @@ def __init__( notebook_params: dict | None = None, python_params: list | None = None, spark_submit_params: list | None = None, - max_full_run_repairs: int = 0, - repair_polling_period_seconds: int = 30, + workflow_repair_attempts: int = 0, + workflow_repair_polling_period: int = 30, **kwargs, ): - if max_full_run_repairs < 0: - raise ValueError(f"max_full_run_repairs must be >= 0, got {max_full_run_repairs}") + if workflow_repair_attempts < 0: + raise ValueError(f"workflow_repair_attempts must be >= 0, got {workflow_repair_attempts}") self.databricks_conn_id = databricks_conn_id self.access_control_list = access_control_list self.existing_clusters = existing_clusters or [] @@ -572,8 +573,8 @@ def __init__( self.notebook_params = notebook_params or {} self.python_params = python_params or [] self.spark_submit_params = spark_submit_params or [] - self.max_full_run_repairs = max_full_run_repairs - self.repair_polling_period_seconds = repair_polling_period_seconds + self.workflow_repair_attempts = workflow_repair_attempts + self.workflow_repair_polling_period = workflow_repair_polling_period super().__init__(**kwargs) def __exit__( @@ -609,18 +610,18 @@ def __exit__( f"Task {task.task_id} does not support conversion to databricks workflow task." ) - if AIRFLOW_V_3_0_PLUS and self.max_full_run_repairs > 0: + if AIRFLOW_V_3_0_PLUS and self.workflow_repair_attempts > 0: task_retries = getattr(task, "retries", 0) if isinstance(task_retries, int) and task_retries > 0: warnings.warn( f"Task {task.task_id!r} in DatabricksWorkflowTaskGroup " f"{self.group_id!r} has retries={task_retries} while " - f"max_full_run_repairs={self.max_full_run_repairs}. Databricks-side repair supersedes " + f"workflow_repair_attempts={self.workflow_repair_attempts}. Databricks-side repair supersedes " "task-level retries for sub-run failures: a failed sub-run defers the " "monitor on the repair-wait trigger and resumes in the same Airflow " "attempt. 
Retries on the monitor will not trigger additional repairs " "and only add cost on non-repair-related transient failures. Consider " - "setting retries=0 on workflow tasks when max_full_run_repairs > 0.", + "setting retries=0 on workflow tasks when workflow_repair_attempts > 0.", UserWarning, stacklevel=2, ) @@ -631,15 +632,15 @@ def __exit__( for root_task in roots: root_task.set_upstream(create_databricks_workflow_task) - if AIRFLOW_V_3_0_PLUS and self.max_full_run_repairs > 0: - repair_coordinator_task = _DatabricksFullRunRepairCoordinatorOperator( + if AIRFLOW_V_3_0_PLUS and self.workflow_repair_attempts > 0: + repair_coordinator_task = _DatabricksWorkflowRepairCoordinatorOperator( dag=self.dag, task_group=self, task_id="repair_coordinator", databricks_conn_id=self.databricks_conn_id, launch_task_id=create_databricks_workflow_task.task_id, - max_full_run_repairs=self.max_full_run_repairs, - repair_polling_period_seconds=self.repair_polling_period_seconds, + workflow_repair_attempts=self.workflow_repair_attempts, + workflow_repair_polling_period=self.workflow_repair_polling_period, # Retrying the coordinator would re-enter execute() with repair_attempts=0 # and start the budget over. retries=0, diff --git a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py index 515642e84368c..14fa6f14c53b6 100644 --- a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py @@ -135,11 +135,11 @@ async def run(self): class DatabricksWorkflowRepairCoordinatorTrigger(BaseTrigger): """ - Coordinate parent-run polling and full-run repairs for a Databricks Workflow run. + Coordinate parent-run polling and repairs for a Databricks Workflow run. :param run_id: The Databricks run id to coordinate. :param databricks_conn_id: Airflow connection id for the Databricks hook. 
- :param max_full_run_repairs: Total repair attempts allowed for this run. + :param workflow_repair_attempts: Total repair attempts allowed for this run. :param repair_attempts: Repair attempts already performed. :param latest_repair_id: Repair id of the most recent repair attempt. :param polling_period_seconds: How often to poll the run state. @@ -154,7 +154,7 @@ def __init__( self, run_id: int, databricks_conn_id: str, - max_full_run_repairs: int, + workflow_repair_attempts: int, repair_attempts: int = 0, latest_repair_id: int | None = None, polling_period_seconds: int = 30, @@ -167,7 +167,7 @@ def __init__( super().__init__() self.run_id = run_id self.databricks_conn_id = databricks_conn_id - self.max_full_run_repairs = max_full_run_repairs + self.workflow_repair_attempts = workflow_repair_attempts self.repair_attempts = repair_attempts self.latest_repair_id = latest_repair_id self.polling_period_seconds = polling_period_seconds @@ -190,7 +190,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: { "run_id": self.run_id, "databricks_conn_id": self.databricks_conn_id, - "max_full_run_repairs": self.max_full_run_repairs, + "workflow_repair_attempts": self.workflow_repair_attempts, "repair_attempts": self.repair_attempts, "latest_repair_id": self.latest_repair_id, "polling_period_seconds": self.polling_period_seconds, @@ -244,13 +244,13 @@ async def run(self): ) return - if self.repair_attempts >= self.max_full_run_repairs: + if self.repair_attempts >= self.workflow_repair_attempts: self.log.info( "Databricks run %s reached terminal failure state %s and repair budget " - "is exhausted (max_full_run_repairs=%s).", + "is exhausted (workflow_repair_attempts=%s).", self.run_id, run_state.result_state, - self.max_full_run_repairs, + self.workflow_repair_attempts, ) yield TriggerEvent( { @@ -271,7 +271,7 @@ async def run(self): self.run_id, run_state.result_state, self.repair_attempts + 1, - self.max_full_run_repairs, + self.workflow_repair_attempts, 
self.latest_repair_id, ) diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks.py b/providers/databricks/tests/unit/databricks/operators/test_databricks.py index 43786558647d6..d1fddfcf2f325 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks.py @@ -3012,7 +3012,7 @@ def _terminal_failure_event(sub_run_id: int) -> dict[str, Any]: "errors": [{"task_key": "tk", "run_id": sub_run_id, "error": "boom"}], } - def _operator_with_workflow_tg(self, max_full_run_repairs: int) -> DatabricksNotebookOperator: + def _operator_with_workflow_tg(self, workflow_repair_attempts: int) -> DatabricksNotebookOperator: # Pass workflow_run_metadata (templated field) rather than setting # databricks_run_id directly. Airflow re-instantiates the operator at deferrable # resume time, so execute_complete must rederive the parent run id from the @@ -3033,14 +3033,14 @@ def _operator_with_workflow_tg(self, max_full_run_repairs: int) -> DatabricksNot # Hand it a single-hop chain so it finds our mocked group directly. 
tg = MagicMock() tg.is_databricks = True - tg.max_full_run_repairs = max_full_run_repairs - tg.repair_polling_period_seconds = 15 + tg.workflow_repair_attempts = workflow_repair_attempts + tg.workflow_repair_polling_period = 15 tg.task_group = None operator.task_group = tg return operator - def test_execute_complete_failure_defers_on_wait_trigger_when_max_full_run_repairs_set(self): - operator = self._operator_with_workflow_tg(max_full_run_repairs=2) + def test_execute_complete_failure_defers_on_wait_trigger_when_workflow_repair_attempts_set(self): + operator = self._operator_with_workflow_tg(workflow_repair_attempts=2) with pytest.raises(TaskDeferred) as exc: operator.execute_complete( @@ -3068,8 +3068,8 @@ def test_execute_complete_failure_pulls_workflow_metadata_from_xcom(self): ) tg = MagicMock() tg.is_databricks = True - tg.max_full_run_repairs = 2 - tg.repair_polling_period_seconds = 15 + tg.workflow_repair_attempts = 2 + tg.workflow_repair_polling_period = 15 tg.task_group = None operator.task_group = tg operator.upstream_task_ids = {"workflow.launch"} @@ -3092,7 +3092,7 @@ def test_execute_complete_failure_pulls_workflow_metadata_from_xcom(self): assert trigger.run_id == self.PARENT_RUN_ID def test_execute_complete_after_repair_wait_new_attempt_defers_on_execution_trigger(self): - operator = self._operator_with_workflow_tg(max_full_run_repairs=2) + operator = self._operator_with_workflow_tg(workflow_repair_attempts=2) with pytest.raises(TaskDeferred) as exc: operator.execute_complete_after_repair_wait( @@ -3112,7 +3112,7 @@ def test_execute_complete_after_repair_wait_new_attempt_defers_on_execution_trig assert exc.value.method_name == "execute_complete" def test_execute_complete_after_repair_wait_parent_failed_raises(self): - operator = self._operator_with_workflow_tg(max_full_run_repairs=2) + operator = self._operator_with_workflow_tg(workflow_repair_attempts=2) with pytest.raises(AirflowException, match="repair budget is exhausted"): 
operator.execute_complete_after_repair_wait( @@ -3132,7 +3132,7 @@ def test_execute_complete_after_repair_wait_parent_failed_raises(self): @mock.patch("airflow.providers.databricks.operators.databricks.time.sleep") def test_sync_wait_for_new_sub_run_attempt_returns_new_attempt(self, mock_sleep): - operator = self._operator_with_workflow_tg(max_full_run_repairs=2) + operator = self._operator_with_workflow_tg(workflow_repair_attempts=2) hook = MagicMock() operator.__dict__["_hook"] = hook operator.databricks_run_id = self.PARENT_RUN_ID diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py index 6f661c02971eb..3a535211ac263 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py @@ -36,7 +36,7 @@ DatabricksWorkflowTaskGroup, WorkflowRunMetadata, _CreateDatabricksWorkflowOperator, - _DatabricksFullRunRepairCoordinatorOperator, + _DatabricksWorkflowRepairCoordinatorOperator, _flatten_node, ) from airflow.providers.databricks.triggers.databricks import DatabricksWorkflowRepairCoordinatorTrigger @@ -575,21 +575,21 @@ def test_reset_job_payload_carries_parent_depends_on(self, mock_databricks_hook) self._assert_parent_depends_on(job_spec) -class TestDatabricksFullRunRepairCoordinatorOperator: +class TestDatabricksWorkflowRepairCoordinatorOperator: LAUNCH_TASK_ID = "wf.launch" LAUNCH_RETURN = {"conn_id": "databricks_default", "job_id": 42, "run_id": 100} def _make_operator( self, - max_full_run_repairs: int = 2, + workflow_repair_attempts: int = 2, deferrable: bool = True, - ) -> _DatabricksFullRunRepairCoordinatorOperator: - return _DatabricksFullRunRepairCoordinatorOperator( + ) -> _DatabricksWorkflowRepairCoordinatorOperator: + return _DatabricksWorkflowRepairCoordinatorOperator( task_id="repair_coordinator", 
databricks_conn_id="databricks_default", launch_task_id=self.LAUNCH_TASK_ID, - max_full_run_repairs=max_full_run_repairs, - repair_polling_period_seconds=10, + workflow_repair_attempts=workflow_repair_attempts, + workflow_repair_polling_period=10, deferrable=deferrable, ) @@ -602,7 +602,7 @@ def test_execute_raises_when_launch_xcom_missing(self): operator.execute(ctx) def test_execute_defers_on_coordinator_trigger(self): - operator = self._make_operator(max_full_run_repairs=3) + operator = self._make_operator(workflow_repair_attempts=3) ctx = {"ti": MagicMock()} ctx["ti"].xcom_pull.return_value = self.LAUNCH_RETURN @@ -614,13 +614,13 @@ def test_execute_defers_on_coordinator_trigger(self): trigger = exc.value.trigger assert isinstance(trigger, DatabricksWorkflowRepairCoordinatorTrigger) assert trigger.run_id == self.LAUNCH_RETURN["run_id"] - assert trigger.max_full_run_repairs == 3 + assert trigger.workflow_repair_attempts == 3 assert trigger.repair_attempts == 0 assert trigger.latest_repair_id is None assert trigger.polling_period_seconds == 10 def test_execute_complete_repaired_redefers_without_xcom_push(self): - operator = self._make_operator(max_full_run_repairs=3) + operator = self._make_operator(workflow_repair_attempts=3) ctx = {"ti": MagicMock()} with pytest.raises(TaskDeferred) as exc: @@ -640,10 +640,10 @@ def test_execute_complete_repaired_redefers_without_xcom_push(self): assert trigger.run_id == 100 assert trigger.repair_attempts == 1 assert trigger.latest_repair_id == 555 - assert trigger.max_full_run_repairs == 3 + assert trigger.workflow_repair_attempts == 3 def test_execute_complete_failed_raises_with_errors_in_message(self): - operator = self._make_operator(max_full_run_repairs=2) + operator = self._make_operator(workflow_repair_attempts=2) ctx = {"ti": MagicMock()} errors = [{"task_key": "t1", "run_id": 11, "error": "boom"}] @@ -661,13 +661,13 @@ def test_execute_complete_failed_raises_with_errors_in_message(self): message = str(exc.value) assert 
"100" in message - assert "max_full_run_repairs=2" in message + assert "workflow_repair_attempts=2" in message assert "boom" in message ctx["ti"].xcom_push.assert_not_called() @patch("airflow.providers.databricks.operators.databricks_workflow.time.sleep") def test_sync_run_repairs_failed_run_and_returns_success(self, mock_sleep): - operator = self._make_operator(max_full_run_repairs=2, deferrable=False) + operator = self._make_operator(workflow_repair_attempts=2, deferrable=False) hook = MagicMock() operator.__dict__["_hook"] = hook hook.get_run_state.side_effect = [ @@ -697,21 +697,21 @@ def test_sync_run_repairs_failed_run_and_returns_success(self, mock_sleep): class TestDatabricksWorkflowTaskGroupCoordinatorInjection: @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Coordinator task is only injected on Airflow 3+") - def test_max_full_run_repairs_positive_injects_coordinator_with_launch_upstream(self): + def test_workflow_repair_attempts_positive_injects_coordinator_with_launch_upstream(self): with DAG(dag_id="dwf_with_coord", schedule=None, start_date=DEFAULT_DATE): with DatabricksWorkflowTaskGroup( group_id="wf", databricks_conn_id="databricks_conn", - max_full_run_repairs=2, - repair_polling_period_seconds=15, + workflow_repair_attempts=2, + workflow_repair_polling_period=15, ) as tg: task = MagicMock(task_id="task1") task._convert_to_databricks_workflow_task = MagicMock(return_value={}) tg.add(task) coordinator = tg.children["wf.repair_coordinator"] - assert isinstance(coordinator, _DatabricksFullRunRepairCoordinatorOperator) - assert coordinator.max_full_run_repairs == 2 - assert coordinator.repair_polling_period_seconds == 15 + assert isinstance(coordinator, _DatabricksWorkflowRepairCoordinatorOperator) + assert coordinator.workflow_repair_attempts == 2 + assert coordinator.workflow_repair_polling_period == 15 assert coordinator.launch_task_id == "wf.launch" assert "wf.launch" in coordinator.upstream_task_ids diff --git 
a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py index 1705560e92385..707f74f759600 100644 --- a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py @@ -455,14 +455,14 @@ def setup_connections(self, create_connection_without_db): def _make_trigger( self, - max_full_run_repairs: int = 2, + workflow_repair_attempts: int = 2, repair_attempts: int = 0, latest_repair_id: int | None = None, ) -> DatabricksWorkflowRepairCoordinatorTrigger: return DatabricksWorkflowRepairCoordinatorTrigger( run_id=RUN_ID, databricks_conn_id=DEFAULT_CONN_ID, - max_full_run_repairs=max_full_run_repairs, + workflow_repair_attempts=workflow_repair_attempts, repair_attempts=repair_attempts, latest_repair_id=latest_repair_id, polling_period_seconds=POLLING_INTERVAL_SECONDS, @@ -470,7 +470,7 @@ def _make_trigger( ) def test_serialize_round_trips_state(self): - trigger = self._make_trigger(max_full_run_repairs=3, repair_attempts=1, latest_repair_id=42) + trigger = self._make_trigger(workflow_repair_attempts=3, repair_attempts=1, latest_repair_id=42) path, kwargs = trigger.serialize() restored = DatabricksWorkflowRepairCoordinatorTrigger(**kwargs) @@ -492,7 +492,7 @@ async def test_emits_completed_when_run_succeeds(self, mock_get_run_state, mock_ ) mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED - trigger = self._make_trigger(max_full_run_repairs=2, repair_attempts=0, latest_repair_id=None) + trigger = self._make_trigger(workflow_repair_attempts=2, repair_attempts=0, latest_repair_id=None) events = [event async for event in trigger.run()] assert len(events) == 1 @@ -519,7 +519,7 @@ async def test_first_failure_within_budget_calls_repair_and_emits_repaired( mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE mock_repair_run.return_value = 101 - trigger = 
self._make_trigger(max_full_run_repairs=2, repair_attempts=0, latest_repair_id=None) + trigger = self._make_trigger(workflow_repair_attempts=2, repair_attempts=0, latest_repair_id=None) events = [event async for event in trigger.run()] assert len(events) == 1 @@ -555,7 +555,7 @@ async def test_emits_failed_when_budget_exhausted( mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED_WITH_FAILED mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE - trigger = self._make_trigger(max_full_run_repairs=2, repair_attempts=2, latest_repair_id=202) + trigger = self._make_trigger(workflow_repair_attempts=2, repair_attempts=2, latest_repair_id=202) events = [event async for event in trigger.run()] assert len(events) == 1 From da2336329939d78f2c398cb1fa28ee7cb9c2a627 Mon Sep 17 00:00:00 2001 From: Nick Date: Fri, 22 May 2026 08:39:26 -0500 Subject: [PATCH 09/15] Fix race conditions and consolidate grace poll constant --- .../databricks/operators/databricks.py | 23 ++++-- .../operators/databricks_workflow.py | 22 +++++- .../databricks/triggers/databricks.py | 52 ++++++++++++-- .../databricks/operators/test_databricks.py | 50 +++++++++++++ .../operators/test_databricks_workflow.py | 7 +- .../databricks/triggers/test_databricks.py | 71 +++++++++++++++++-- 6 files changed, 206 insertions(+), 19 deletions(-) diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index 65eaae865faf0..84972690cf215 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -43,6 +43,7 @@ store_databricks_job_run_link, ) from airflow.providers.databricks.triggers.databricks import ( + WORKFLOW_REPAIR_GRACE_POLLS, DatabricksExecutionTrigger, DatabricksWorkflowRepairWaitTrigger, ) @@ -1640,12 +1641,15 @@ def _sync_wait_for_new_sub_run_attempt( 
self.databricks_task_key, ) polling_period_seconds = tg.workflow_repair_polling_period - terminal_grace_polls = 3 terminal_observations = 0 + last_repair_history_count: int | None = None while True: run_info = self._hook.get_run(self.databricks_run_id) # type: ignore[arg-type] parent_run_state = RunState(**run_info["state"]) tasks = run_info.get("tasks", []) + repair_history_count = len(run_info.get("repair_history", [])) + if last_repair_history_count is None: + last_repair_history_count = repair_history_count new_attempt = find_new_workflow_task_attempt( tasks=tasks, task_key=self.databricks_task_key, @@ -1654,16 +1658,27 @@ def _sync_wait_for_new_sub_run_attempt( ) if new_attempt is not None: return new_attempt["run_id"] - if parent_run_state.is_terminal: + if repair_history_count > last_repair_history_count: + self.log.info( + "Parent run %s repair_history grew (was %s, now %s); resetting grace counter " + "while waiting for a new attempt for task_key %s.", + self.databricks_run_id, + last_repair_history_count, + repair_history_count, + self.databricks_task_key, + ) + last_repair_history_count = repair_history_count + terminal_observations = 0 + elif parent_run_state.is_terminal: terminal_observations += 1 - if terminal_observations >= terminal_grace_polls: + if terminal_observations >= WORKFLOW_REPAIR_GRACE_POLLS: self.log.info( "Parent run %s reached terminal state %s without a new attempt for " "task_key %s after %s grace polls.", self.databricks_run_id, parent_run_state.result_state, self.databricks_task_key, - terminal_grace_polls, + WORKFLOW_REPAIR_GRACE_POLLS, ) return None else: diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py index c30084b1aef91..536957c6282c3 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py +++ 
b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py @@ -34,7 +34,10 @@ WorkflowJobRunLink, store_databricks_job_run_link, ) -from airflow.providers.databricks.triggers.databricks import DatabricksWorkflowRepairCoordinatorTrigger +from airflow.providers.databricks.triggers.databricks import ( + WORKFLOW_REPAIR_GRACE_POLLS, + DatabricksWorkflowRepairCoordinatorTrigger, +) from airflow.providers.databricks.utils.databricks import build_repair_run_json, extract_failed_task_errors from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS @@ -446,6 +449,23 @@ def _run_sync(self, run_id: int) -> dict[str, Any]: latest_repair_id, ) + # Wait for Databricks to reflect the repair (leave terminal state) before + # looping. Without this, the next get_run_state can return stale terminal + # state and trigger a second repair_run. + for _ in range(WORKFLOW_REPAIR_GRACE_POLLS): + time.sleep(2) + post_repair_state = self._hook.get_run_state(run_id) + if not post_repair_state.is_terminal: + break + else: + self.log.warning( + "Databricks run %s still reports terminal state after %s grace polls " + "following repair_id=%s; proceeding anyway.", + run_id, + WORKFLOW_REPAIR_GRACE_POLLS, + latest_repair_id, + ) + def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: status = event.get("status") run_id = event["run_id"] diff --git a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py index 14fa6f14c53b6..1b0a462616ba1 100644 --- a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py @@ -30,6 +30,16 @@ from airflow.providers.databricks.utils.retry import validate_deferrable_databricks_retry_args from airflow.triggers.base import BaseTrigger, TriggerEvent +# Tolerate this many consecutive polls of stale 
Databricks state during the workflow-repair +# flow before drawing a conclusion. Covers two near-identical eventual-consistency windows: +# (1) after ``repair_run`` is accepted, ``runs/get`` can briefly continue to report the parent +# as TERMINATED before the repair transitions it to a non-terminal state — the coordinator +# polls up to this many times before declaring the repair issued; (2) when a sub-run reports +# terminal failure, a repair-triggered new sub-run can take a moment to appear in the parent's +# tasks list — the waiter polls up to this many times before declaring the parent terminally +# failed without a new attempt. +WORKFLOW_REPAIR_GRACE_POLLS = 3 + class DatabricksExecutionTrigger(BaseTrigger): """ @@ -288,6 +298,23 @@ async def run(self): new_repair_id, ) + # Wait for Databricks to reflect the repair (leave terminal state) before + # yielding. Without this, the next trigger cycle can observe stale terminal + # state and issue a second repair_run. + for _ in range(WORKFLOW_REPAIR_GRACE_POLLS): + await asyncio.sleep(2) + post_repair_state = await self.hook.a_get_run_state(self.run_id) + if not post_repair_state.is_terminal: + break + else: + self.log.warning( + "Databricks run %s still reports terminal state after %s grace polls " + "following repair_id=%s; proceeding anyway.", + self.run_id, + WORKFLOW_REPAIR_GRACE_POLLS, + new_repair_id, + ) + yield TriggerEvent( { "status": "repaired", @@ -380,16 +407,16 @@ def _find_new_attempt(self, tasks: list[dict[str, Any]]) -> dict[str, Any] | Non ) async def run(self): - # Grace polls before declaring the parent terminally failed without a new attempt: - # Databricks can briefly report the parent run terminal before a repair-triggered - # sub-run shows up in the tasks list. 
- terminal_grace_polls = 3 terminal_observations = 0 + last_repair_history_count: int | None = None async with self.hook: while True: run_info = await self.hook.a_get_run(self.run_id) run_state = RunState(**run_info["state"]) tasks = run_info.get("tasks", []) + repair_history_count = len(run_info.get("repair_history", [])) + if last_repair_history_count is None: + last_repair_history_count = repair_history_count new_attempt = self._find_new_attempt(tasks) if new_attempt is not None: @@ -411,7 +438,18 @@ async def run(self): ) return - if run_state.is_terminal and not run_state.is_successful: + if repair_history_count > last_repair_history_count: + self.log.info( + "Databricks workflow run %s repair_history grew (was %s, now %s); " + "resetting grace counter while waiting for a new attempt for task_key %s.", + self.run_id, + last_repair_history_count, + repair_history_count, + self.databricks_task_key, + ) + last_repair_history_count = repair_history_count + terminal_observations = 0 + elif run_state.is_terminal and not run_state.is_successful: terminal_observations += 1 self.log.info( "Databricks workflow run %s is in terminal failure state %s with no new " @@ -420,9 +458,9 @@ async def run(self): run_state.result_state, self.databricks_task_key, terminal_observations, - terminal_grace_polls, + WORKFLOW_REPAIR_GRACE_POLLS, ) - if terminal_observations >= terminal_grace_polls: + if terminal_observations >= WORKFLOW_REPAIR_GRACE_POLLS: yield TriggerEvent( { "status": "parent_failed", diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks.py b/providers/databricks/tests/unit/databricks/operators/test_databricks.py index d1fddfcf2f325..35b892a1e4a95 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks.py @@ -3172,3 +3172,53 @@ def test_sync_wait_for_new_sub_run_attempt_returns_new_attempt(self, mock_sleep) assert result == 
self.NEW_SUB_RUN_ID mock_sleep.assert_called_once_with(15) + + @mock.patch("airflow.providers.databricks.operators.databricks.time.sleep") + def test_sync_wait_for_new_sub_run_attempt_repair_history_growth_resets_grace(self, mock_sleep): + operator = self._operator_with_workflow_tg(workflow_repair_attempts=2) + hook = MagicMock() + operator.__dict__["_hook"] = hook + operator.databricks_run_id = self.PARENT_RUN_ID + original_task = { + "run_id": self.ORIGINAL_SUB_RUN_ID, + "task_key": operator.databricks_task_key, + "start_time": 1000, + } + terminal_state = {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": None} + # Two terminal polls, then a repair lands (history grows) — counter resets; + # then 3 more terminal polls with no further repair → returns None. + hook.get_run.side_effect = [ + {"state": terminal_state, "tasks": [original_task], "repair_history": []}, + {"state": terminal_state, "tasks": [original_task], "repair_history": []}, + { + "state": terminal_state, + "tasks": [original_task], + "repair_history": [{"id": 1, "type": "REPAIR"}], + }, + { + "state": terminal_state, + "tasks": [original_task], + "repair_history": [{"id": 1, "type": "REPAIR"}], + }, + { + "state": terminal_state, + "tasks": [original_task], + "repair_history": [{"id": 1, "type": "REPAIR"}], + }, + { + "state": terminal_state, + "tasks": [original_task], + "repair_history": [{"id": 1, "type": "REPAIR"}], + }, + ] + + result = operator._sync_wait_for_new_sub_run_attempt( + original_sub_run_id=self.ORIGINAL_SUB_RUN_ID, + original_start_time=1000, + tg=operator.task_group, + ) + + # Without the reset, this would return after 3 polls. The reset on poll 3 delays the + # decision until 3 more terminal polls without further repair_history growth (poll 6). 
+ assert result is None + assert hook.get_run.call_count == 6 diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py index 3a535211ac263..5cd20e53c0f8a 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py @@ -670,8 +670,12 @@ def test_sync_run_repairs_failed_run_and_returns_success(self, mock_sleep): operator = self._make_operator(workflow_repair_attempts=2, deferrable=False) hook = MagicMock() operator.__dict__["_hook"] = hook + # 1. top of outer loop: terminal+failed → repair_run + # 2. post-repair grace poll: non-terminal → break grace loop + # 3. top of outer loop: terminal+success → return hook.get_run_state.side_effect = [ RunState("TERMINATED", "FAILED", ""), + RunState("RUNNING", "", ""), RunState("TERMINATED", "SUCCESS", ""), ] hook.get_run.return_value = { @@ -692,7 +696,8 @@ def test_sync_run_repairs_failed_run_and_returns_success(self, mock_sleep): } ) assert result == {"run_id": 100, "repair_attempts": 1, "latest_repair_id": 555} - mock_sleep.assert_not_called() + # Grace loop slept once before observing non-terminal state. 
+ assert mock_sleep.call_count == 1 class TestDatabricksWorkflowTaskGroupCoordinatorInjection: diff --git a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py index 707f74f759600..2ef83a0f0b845 100644 --- a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py @@ -503,18 +503,28 @@ async def test_emits_completed_when_run_succeeds(self, mock_get_run_state, mock_ assert events[0].payload["errors"] == [] @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.triggers.databricks.asyncio.sleep") @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.repair_run") @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output") @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state") async def test_first_failure_within_budget_calls_repair_and_emits_repaired( - self, mock_get_run_state, mock_get_run, mock_get_run_output, mock_repair_run + self, mock_get_run_state, mock_get_run, mock_get_run_output, mock_repair_run, mock_sleep ): - mock_get_run_state.return_value = RunState( - life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, - state_message="", - result_state="FAILED", - ) + # First call: terminal+failed → trigger issues repair. + # Second call: post-repair grace poll → non-terminal → grace loop breaks. 
+ mock_get_run_state.side_effect = [ + RunState( + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + state_message="", + result_state="FAILED", + ), + RunState( + life_cycle_state="RUNNING", + state_message="", + result_state="", + ), + ] mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED_WITH_FAILED mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE mock_repair_run.return_value = 101 @@ -538,6 +548,8 @@ async def test_first_failure_within_budget_calls_repair_and_emits_repaired( assert repair_json["rerun_all_failed_tasks"] is True # First repair: latest_repair_id was None, so the field must be omitted from the payload assert "latest_repair_id" not in repair_json + # Grace loop slept once before observing non-terminal state. + assert mock_sleep.call_count == 1 @pytest.mark.asyncio @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.repair_run") @@ -603,6 +615,7 @@ def _run_payload( result_state: str | None, life_cycle_state: str = LIFE_CYCLE_STATE_TERMINATED, tasks: list[dict] | None = None, + repair_history: list[dict] | None = None, ) -> dict: return { "run_page_url": RUN_PAGE_URL, @@ -612,6 +625,7 @@ def _run_payload( "result_state": result_state, }, "tasks": tasks or [], + "repair_history": repair_history or [], } def test_serialize_round_trips_state(self): @@ -683,3 +697,48 @@ async def test_emits_parent_failed_after_grace_polls(self, mock_get_run, mock_sl assert payload["parent_run_id"] == self.PARENT_RUN_ID assert payload["databricks_task_key"] == self.TASK_KEY assert RunState.from_json(payload["parent_run_state"]).result_state == "FAILED" + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + async def test_repair_history_growth_resets_grace_counter(self, mock_get_run, mock_sleep): + original_task = { + "run_id": self.ORIGINAL_SUB_RUN_ID, + "task_key": self.TASK_KEY, + "start_time": 1000, + } + # Sequence: terminal twice, then a repair 
lands (history grows) — counter resets; + # then 3 more terminal polls with no further repair → parent_failed. + mock_get_run.side_effect = [ + self._run_payload(result_state="FAILED", tasks=[original_task], repair_history=[]), + self._run_payload(result_state="FAILED", tasks=[original_task], repair_history=[]), + self._run_payload( + result_state="FAILED", + tasks=[original_task], + repair_history=[{"id": 1, "type": "REPAIR"}], + ), + self._run_payload( + result_state="FAILED", + tasks=[original_task], + repair_history=[{"id": 1, "type": "REPAIR"}], + ), + self._run_payload( + result_state="FAILED", + tasks=[original_task], + repair_history=[{"id": 1, "type": "REPAIR"}], + ), + self._run_payload( + result_state="FAILED", + tasks=[original_task], + repair_history=[{"id": 1, "type": "REPAIR"}], + ), + ] + + trigger = self._make_trigger() + events = [event async for event in trigger.run()] + + # Without the reset, parent_failed would fire on the 3rd poll. The reset on poll 3 + # delays it until 3 more terminal polls have accumulated (poll 6). + assert mock_get_run.call_count == 6 + assert len(events) == 1 + assert events[0].payload["status"] == "parent_failed" From 1995a1804113239c6b4cacba5c7704ee912e1e5f Mon Sep 17 00:00:00 2001 From: Nick Date: Fri, 22 May 2026 08:57:55 -0500 Subject: [PATCH 10/15] Update rst --- providers/databricks/docs/operators/workflow.rst | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/providers/databricks/docs/operators/workflow.rst b/providers/databricks/docs/operators/workflow.rst index 18e0889b8ed7c..69f52288bafef 100644 --- a/providers/databricks/docs/operators/workflow.rst +++ b/providers/databricks/docs/operators/workflow.rst @@ -78,13 +78,12 @@ Databricks ``repair_run`` with ``rerun_all_failed_tasks=True``. Default is ``0`` .. 
code-block:: python - with DatabricksWorkflowTaskGroup( - group_id="example_databricks_workflow", + task_group = DatabricksWorkflowTaskGroup( + group_id="Example Workflow", databricks_conn_id="databricks_default", - job_clusters=job_clusters, workflow_repair_attempts=2, - ) as task_group: - ... + workflow_repair_polling_period=15, + ) Downstream task monitors stay in the same Airflow attempt across repairs, so size ``execution_timeout`` for the original run plus all repairs and set ``retries=0`` on workflow tasks. From 40550d9fce4f6a41819d99469063912a74fe2643 Mon Sep 17 00:00:00 2001 From: Nick Date: Fri, 22 May 2026 11:24:45 -0500 Subject: [PATCH 11/15] rework coordinator polling for repair --- .../databricks/docs/operators/workflow.rst | 7 +++ .../operators/databricks_workflow.py | 43 ++++++++++---- .../databricks/triggers/databricks.py | 50 ++++++++++------- .../operators/test_databricks_workflow.py | 37 +++++++++++- .../databricks/triggers/test_databricks.py | 56 ++++++++++++++++++- 5 files changed, 160 insertions(+), 33 deletions(-) diff --git a/providers/databricks/docs/operators/workflow.rst b/providers/databricks/docs/operators/workflow.rst index 69f52288bafef..60c387c9ee117 100644 --- a/providers/databricks/docs/operators/workflow.rst +++ b/providers/databricks/docs/operators/workflow.rst @@ -83,7 +83,14 @@ Databricks ``repair_run`` with ``rerun_all_failed_tasks=True``. Default is ``0`` databricks_conn_id="databricks_default", workflow_repair_attempts=2, workflow_repair_polling_period=15, + workflow_repair_reflection_timeout=300, ) +After ``repair_run`` is accepted, Databricks needs a moment to drop the parent run out of its +terminal state. The coordinator polls every ``workflow_repair_polling_period`` seconds and gives +Databricks up to ``workflow_repair_reflection_timeout`` seconds (default 300s / 5 minutes) to +reflect the repair before it fails the coordinator. Raise the timeout if your workspace is slow +to surface repaired runs.
+ Downstream task monitors stay in the same Airflow attempt across repairs, so size ``execution_timeout`` for the original run plus all repairs and set ``retries=0`` on workflow tasks. diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py index 536957c6282c3..d28500518e1f1 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py @@ -35,7 +35,6 @@ store_databricks_job_run_link, ) from airflow.providers.databricks.triggers.databricks import ( - WORKFLOW_REPAIR_GRACE_POLLS, DatabricksWorkflowRepairCoordinatorTrigger, ) from airflow.providers.databricks.utils.databricks import build_repair_run_json, extract_failed_task_errors @@ -319,6 +318,7 @@ def __init__( launch_task_id: str, workflow_repair_attempts: int, workflow_repair_polling_period: int = 30, + workflow_repair_reflection_timeout: int = 300, databricks_retry_limit: int = 3, databricks_retry_delay: int = 10, databricks_retry_args: dict[Any, Any] | None = None, @@ -334,6 +334,7 @@ def __init__( self.launch_task_id = launch_task_id self.workflow_repair_attempts = workflow_repair_attempts self.workflow_repair_polling_period = workflow_repair_polling_period + self.workflow_repair_reflection_timeout = workflow_repair_reflection_timeout self.databricks_retry_limit = databricks_retry_limit self.databricks_retry_delay = databricks_retry_delay self.databricks_retry_args = databricks_retry_args @@ -362,6 +363,7 @@ def _make_trigger( repair_attempts=repair_attempts, latest_repair_id=latest_repair_id, polling_period_seconds=self.workflow_repair_polling_period, + workflow_repair_reflection_timeout=self.workflow_repair_reflection_timeout, retry_limit=self.databricks_retry_limit, retry_delay=self.databricks_retry_delay, retry_args=self.databricks_retry_args, @@ 
-451,20 +453,21 @@ def _run_sync(self, run_id: int) -> dict[str, Any]: # Wait for Databricks to reflect the repair (leave terminal state) before # looping. Without this, the next get_run_state can return stale terminal - # state and trigger a second repair_run. - for _ in range(WORKFLOW_REPAIR_GRACE_POLLS): - time.sleep(2) + # state and trigger a second repair_run. Bound the wait so a stuck DBX + # doesn't pin a worker forever. + deadline = time.monotonic() + self.workflow_repair_reflection_timeout + while True: + time.sleep(self.workflow_repair_polling_period) post_repair_state = self._hook.get_run_state(run_id) if not post_repair_state.is_terminal: break - else: - self.log.warning( - "Databricks run %s still reports terminal state after %s grace polls " - "following repair_id=%s; proceeding anyway.", - run_id, - WORKFLOW_REPAIR_GRACE_POLLS, - latest_repair_id, - ) + if time.monotonic() >= deadline: + raise DatabricksWorkflowRepairError( + f"Databricks did not reflect repair_id={latest_repair_id} for run {run_id} " + f"within {self.workflow_repair_reflection_timeout}s " + f"(workflow_repair_reflection_timeout); aborting to avoid issuing a " + f"duplicate repair_run against stale terminal state." + ) def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: status = event.get("status") @@ -510,6 +513,14 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: f"Errors: {errors}" ) + if status == "repair_not_reflected": + raise DatabricksWorkflowRepairError( + f"Databricks did not reflect repair_id={latest_repair_id} for run {run_id} " + f"within {self.workflow_repair_reflection_timeout}s " + f"(workflow_repair_reflection_timeout); aborting to avoid issuing a " + f"duplicate repair_run against stale terminal state." 
+ ) + raise DatabricksWorkflowRepairError( f"DatabricksWorkflowRepairCoordinatorTrigger emitted unexpected status {status!r}: {event}" ) @@ -559,6 +570,11 @@ class DatabricksWorkflowTaskGroup(TaskGroup): effect on Airflow 3+; ignored on Airflow 2.x. Defaults to ``0``. :param workflow_repair_polling_period: How often the repair coordinator polls the Databricks run state. Only used when ``workflow_repair_attempts > 0``. + :param workflow_repair_reflection_timeout: Seconds the coordinator waits after a + ``repair_run`` is accepted for the parent run to leave its terminal state before + giving up and failing. Covers Databricks-side eventual consistency on a slow + cluster. Defaults to 300 seconds (5 minutes). Only used when + ``workflow_repair_attempts > 0``. """ is_databricks = True @@ -578,6 +594,7 @@ def __init__( spark_submit_params: list | None = None, workflow_repair_attempts: int = 0, workflow_repair_polling_period: int = 30, + workflow_repair_reflection_timeout: int = 300, **kwargs, ): if workflow_repair_attempts < 0: @@ -595,6 +612,7 @@ def __init__( self.spark_submit_params = spark_submit_params or [] self.workflow_repair_attempts = workflow_repair_attempts self.workflow_repair_polling_period = workflow_repair_polling_period + self.workflow_repair_reflection_timeout = workflow_repair_reflection_timeout super().__init__(**kwargs) def __exit__( @@ -661,6 +679,7 @@ def __exit__( launch_task_id=create_databricks_workflow_task.task_id, workflow_repair_attempts=self.workflow_repair_attempts, workflow_repair_polling_period=self.workflow_repair_polling_period, + workflow_repair_reflection_timeout=self.workflow_repair_reflection_timeout, # Retrying the coordinator would re-enter execute() with repair_attempts=0 # and start the budget over. 
retries=0, diff --git a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py index 1b0a462616ba1..06cd664f1a4d9 100644 --- a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py @@ -30,14 +30,13 @@ from airflow.providers.databricks.utils.retry import validate_deferrable_databricks_retry_args from airflow.triggers.base import BaseTrigger, TriggerEvent -# Tolerate this many consecutive polls of stale Databricks state during the workflow-repair -# flow before drawing a conclusion. Covers two near-identical eventual-consistency windows: -# (1) after ``repair_run`` is accepted, ``runs/get`` can briefly continue to report the parent -# as TERMINATED before the repair transitions it to a non-terminal state — the coordinator -# polls up to this many times before declaring the repair issued; (2) when a sub-run reports -# terminal failure, a repair-triggered new sub-run can take a moment to appear in the parent's -# tasks list — the waiter polls up to this many times before declaring the parent terminally -# failed without a new attempt. +# Tolerate this many consecutive polls of stale Databricks state in +# ``DatabricksWorkflowRepairWaitTrigger``: when a sub-run reports terminal failure, a +# repair-triggered new sub-run can take a moment to appear in the parent's tasks list — the +# waiter polls up to this many times before declaring the parent terminally failed without a +# new attempt. The coordinator uses a configurable wall-clock timeout instead (see +# ``workflow_repair_reflection_timeout``), since the post-``repair_run`` eventual-consistency +# window can stretch into minutes when Databricks is slow. 
WORKFLOW_REPAIR_GRACE_POLLS = 3 @@ -153,6 +152,9 @@ class DatabricksWorkflowRepairCoordinatorTrigger(BaseTrigger): :param repair_attempts: Repair attempts already performed. :param latest_repair_id: Repair id of the most recent repair attempt. :param polling_period_seconds: How often to poll the run state. + :param workflow_repair_reflection_timeout: Seconds to wait after ``repair_run`` is accepted + for the parent run to leave its terminal state before giving up and failing the + coordinator. Defaults to 5 minutes. :param retry_limit: Hook retry limit for transient Databricks API failures. :param retry_delay: Hook retry delay (seconds). :param retry_args: Optional tenacity ``Retrying`` kwargs forwarded to the hook. @@ -168,6 +170,7 @@ def __init__( repair_attempts: int = 0, latest_repair_id: int | None = None, polling_period_seconds: int = 30, + workflow_repair_reflection_timeout: int = 300, retry_limit: int = 3, retry_delay: int = 10, retry_args: dict[Any, Any] | None = None, @@ -181,6 +184,7 @@ def __init__( self.repair_attempts = repair_attempts self.latest_repair_id = latest_repair_id self.polling_period_seconds = polling_period_seconds + self.workflow_repair_reflection_timeout = workflow_repair_reflection_timeout self.retry_limit = retry_limit self.retry_delay = retry_delay self.retry_args = retry_args @@ -204,6 +208,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "repair_attempts": self.repair_attempts, "latest_repair_id": self.latest_repair_id, "polling_period_seconds": self.polling_period_seconds, + "workflow_repair_reflection_timeout": self.workflow_repair_reflection_timeout, "retry_limit": self.retry_limit, "retry_delay": self.retry_delay, "retry_args": self.retry_args, @@ -300,20 +305,27 @@ async def run(self): # Wait for Databricks to reflect the repair (leave terminal state) before # yielding. Without this, the next trigger cycle can observe stale terminal - # state and issue a second repair_run. 
- for _ in range(WORKFLOW_REPAIR_GRACE_POLLS): - await asyncio.sleep(2) + # state and issue a second repair_run. Bound the wait so a stuck DBX doesn't + # pin the trigger forever. + deadline = time.monotonic() + self.workflow_repair_reflection_timeout + while True: + await asyncio.sleep(self.polling_period_seconds) post_repair_state = await self.hook.a_get_run_state(self.run_id) if not post_repair_state.is_terminal: break - else: - self.log.warning( - "Databricks run %s still reports terminal state after %s grace polls " - "following repair_id=%s; proceeding anyway.", - self.run_id, - WORKFLOW_REPAIR_GRACE_POLLS, - new_repair_id, - ) + if time.monotonic() >= deadline: + yield TriggerEvent( + { + "status": "repair_not_reflected", + "run_id": self.run_id, + "run_page_url": self.run_page_url, + "run_state": run_state.to_json(), + "repair_attempts": self.repair_attempts + 1, + "latest_repair_id": new_repair_id, + "errors": errors, + } + ) + return yield TriggerEvent( { diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py index 5cd20e53c0f8a..061a4ead6745f 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py @@ -671,7 +671,7 @@ def test_sync_run_repairs_failed_run_and_returns_success(self, mock_sleep): hook = MagicMock() operator.__dict__["_hook"] = hook # 1. top of outer loop: terminal+failed → repair_run - # 2. post-repair grace poll: non-terminal → break grace loop + # 2. reflection poll: non-terminal → break reflection loop # 3. top of outer loop: terminal+success → return hook.get_run_state.side_effect = [ RunState("TERMINATED", "FAILED", ""), @@ -699,6 +699,39 @@ def test_sync_run_repairs_failed_run_and_returns_success(self, mock_sleep): # Grace loop slept once before observing non-terminal state. 
assert mock_sleep.call_count == 1 + @patch("airflow.providers.databricks.operators.databricks_workflow.time.sleep") + def test_sync_run_raises_when_repair_not_reflected_within_timeout(self, mock_sleep): + operator = self._make_operator(workflow_repair_attempts=2, deferrable=False) + operator.workflow_repair_reflection_timeout = 0 + hook = MagicMock() + operator.__dict__["_hook"] = hook + # 1. outer loop: terminal+failed → repair_run + # 2. reflection poll: still terminal → wall-clock deadline trips → raise (no second repair_run). + # workflow_repair_reflection_timeout=0 means the first elapsed ``time.monotonic()`` call + # after the no-op sleep is past the deadline, so the loop bails out. + hook.get_run_state.side_effect = [ + RunState("TERMINATED", "FAILED", ""), + RunState("TERMINATED", "FAILED", ""), + ] + hook.get_run.return_value = { + "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": ""}, + "tasks": [], + "overriding_parameters": {}, + } + hook.repair_run.return_value = 555 + + with pytest.raises(AirflowException) as exc: + operator._run_sync(run_id=100) + + message = str(exc.value) + assert "did not reflect repair_id=555" in message + assert "run 100" in message + assert "workflow_repair_reflection_timeout" in message + # Only the original repair_run — the raise must prevent a duplicate. + hook.repair_run.assert_called_once() + # One reflection-loop sleep fired before the deadline check tripped. 
+ assert mock_sleep.call_count == 1 + class TestDatabricksWorkflowTaskGroupCoordinatorInjection: @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Coordinator task is only injected on Airflow 3+") @@ -709,6 +742,7 @@ def test_workflow_repair_attempts_positive_injects_coordinator_with_launch_upstr databricks_conn_id="databricks_conn", workflow_repair_attempts=2, workflow_repair_polling_period=15, + workflow_repair_reflection_timeout=120, ) as tg: task = MagicMock(task_id="task1") task._convert_to_databricks_workflow_task = MagicMock(return_value={}) @@ -718,5 +752,6 @@ def test_workflow_repair_attempts_positive_injects_coordinator_with_launch_upstr assert isinstance(coordinator, _DatabricksWorkflowRepairCoordinatorOperator) assert coordinator.workflow_repair_attempts == 2 assert coordinator.workflow_repair_polling_period == 15 + assert coordinator.workflow_repair_reflection_timeout == 120 assert coordinator.launch_task_id == "wf.launch" assert "wf.launch" in coordinator.upstream_task_ids diff --git a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py index 2ef83a0f0b845..989c75318f236 100644 --- a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py @@ -458,6 +458,7 @@ def _make_trigger( workflow_repair_attempts: int = 2, repair_attempts: int = 0, latest_repair_id: int | None = None, + workflow_repair_reflection_timeout: int = 300, ) -> DatabricksWorkflowRepairCoordinatorTrigger: return DatabricksWorkflowRepairCoordinatorTrigger( run_id=RUN_ID, @@ -466,6 +467,7 @@ def _make_trigger( repair_attempts=repair_attempts, latest_repair_id=latest_repair_id, polling_period_seconds=POLLING_INTERVAL_SECONDS, + workflow_repair_reflection_timeout=workflow_repair_reflection_timeout, run_page_url=RUN_PAGE_URL, ) @@ -512,7 +514,7 @@ async def 
test_first_failure_within_budget_calls_repair_and_emits_repaired( self, mock_get_run_state, mock_get_run, mock_get_run_output, mock_repair_run, mock_sleep ): # First call: terminal+failed → trigger issues repair. - # Second call: post-repair grace poll → non-terminal → grace loop breaks. + # Second call: reflection poll → non-terminal → reflection loop breaks. mock_get_run_state.side_effect = [ RunState( life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, @@ -551,6 +553,58 @@ async def test_first_failure_within_budget_calls_repair_and_emits_repaired( # Grace loop slept once before observing non-terminal state. assert mock_sleep.call_count == 1 + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.triggers.databricks.asyncio.sleep") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.repair_run") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state") + async def test_emits_repair_not_reflected_when_reflection_timeout_elapses( + self, + mock_get_run_state, + mock_get_run, + mock_get_run_output, + mock_repair_run, + mock_sleep, + ): + # First call: terminal+failed → trigger issues repair. + # Reflection poll: still terminal → wall-clock deadline trips → yield repair_not_reflected. + # workflow_repair_reflection_timeout=0 means the deadline is "now"; the first elapsed + # ``time.monotonic()`` call after the no-op sleep is past it, so the loop bails out. 
+ mock_get_run_state.side_effect = [ + RunState( + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + state_message="", + result_state="FAILED", + ), + RunState( + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + state_message="", + result_state="FAILED", + ), + ] + mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED_WITH_FAILED + mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE + mock_repair_run.return_value = 101 + + trigger = self._make_trigger( + workflow_repair_attempts=2, + repair_attempts=0, + latest_repair_id=None, + workflow_repair_reflection_timeout=0, + ) + events = [event async for event in trigger.run()] + + assert len(events) == 1 + assert events[0].payload["status"] == "repair_not_reflected" + assert events[0].payload["run_id"] == RUN_ID + assert events[0].payload["repair_attempts"] == 1 + assert events[0].payload["latest_repair_id"] == 101 + # Only the original repair_run — yielding must prevent a duplicate next cycle. + mock_repair_run.assert_called_once() + # One reflection-loop sleep fired before the deadline check tripped. 
+ assert mock_sleep.call_count == 1 + @pytest.mark.asyncio @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.repair_run") @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output") From 8d6d13f7ebd76ce46ef0d901fecaf6f3a887dd98 Mon Sep 17 00:00:00 2001 From: Nick Date: Fri, 22 May 2026 11:32:40 -0500 Subject: [PATCH 12/15] tweak new parameter name for simplicity --- .../databricks/docs/operators/workflow.rst | 4 ++-- .../operators/databricks_workflow.py | 24 +++++++++---------- .../databricks/triggers/databricks.py | 12 +++++----- .../operators/test_databricks_workflow.py | 10 ++++---- .../databricks/triggers/test_databricks.py | 8 +++---- 5 files changed, 29 insertions(+), 29 deletions(-) diff --git a/providers/databricks/docs/operators/workflow.rst b/providers/databricks/docs/operators/workflow.rst index 60c387c9ee117..3dec52d2740b0 100644 --- a/providers/databricks/docs/operators/workflow.rst +++ b/providers/databricks/docs/operators/workflow.rst @@ -83,12 +83,12 @@ Databricks ``repair_run`` with ``rerun_all_failed_tasks=True``. Default is ``0`` databricks_conn_id="databricks_default", workflow_repair_attempts=2, workflow_repair_polling_period=15, - workflow_repair_reflection_timeout=300, + workflow_repair_timeout=300, ) After ``repair_run`` is accepted, Databricks needs a moment to drop the parent run out of its terminal state. The coordinator polls every ``workflow_repair_polling_period`` seconds and gives -Databricks up to ``workflow_repair_reflection_timeout`` seconds (default 300s / 5 minutes) to +Databricks up to ``workflow_repair_timeout`` seconds (default 300s / 5 minutes) to reflect the repair before it fails the coordinator. Raise the timeout if your workspace is slow to surface repaired runs. 
diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py index d28500518e1f1..9f6fb332c6982 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py @@ -318,7 +318,7 @@ def __init__( launch_task_id: str, workflow_repair_attempts: int, workflow_repair_polling_period: int = 30, - workflow_repair_reflection_timeout: int = 300, + workflow_repair_timeout: int = 300, databricks_retry_limit: int = 3, databricks_retry_delay: int = 10, databricks_retry_args: dict[Any, Any] | None = None, @@ -334,7 +334,7 @@ def __init__( self.launch_task_id = launch_task_id self.workflow_repair_attempts = workflow_repair_attempts self.workflow_repair_polling_period = workflow_repair_polling_period - self.workflow_repair_reflection_timeout = workflow_repair_reflection_timeout + self.workflow_repair_timeout = workflow_repair_timeout self.databricks_retry_limit = databricks_retry_limit self.databricks_retry_delay = databricks_retry_delay self.databricks_retry_args = databricks_retry_args @@ -363,7 +363,7 @@ def _make_trigger( repair_attempts=repair_attempts, latest_repair_id=latest_repair_id, polling_period_seconds=self.workflow_repair_polling_period, - workflow_repair_reflection_timeout=self.workflow_repair_reflection_timeout, + workflow_repair_timeout=self.workflow_repair_timeout, retry_limit=self.databricks_retry_limit, retry_delay=self.databricks_retry_delay, retry_args=self.databricks_retry_args, @@ -455,7 +455,7 @@ def _run_sync(self, run_id: int) -> dict[str, Any]: # looping. Without this, the next get_run_state can return stale terminal # state and trigger a second repair_run. Bound the wait so a stuck DBX # doesn't pin a worker forever. 
- deadline = time.monotonic() + self.workflow_repair_reflection_timeout + deadline = time.monotonic() + self.workflow_repair_timeout while True: time.sleep(self.workflow_repair_polling_period) post_repair_state = self._hook.get_run_state(run_id) @@ -464,8 +464,8 @@ def _run_sync(self, run_id: int) -> dict[str, Any]: if time.monotonic() >= deadline: raise DatabricksWorkflowRepairError( f"Databricks did not reflect repair_id={latest_repair_id} for run {run_id} " - f"within {self.workflow_repair_reflection_timeout}s " - f"(workflow_repair_reflection_timeout); aborting to avoid issuing a " + f"within {self.workflow_repair_timeout}s " + f"(workflow_repair_timeout); aborting to avoid issuing a " f"duplicate repair_run against stale terminal state." ) @@ -516,8 +516,8 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: if status == "repair_not_reflected": raise DatabricksWorkflowRepairError( f"Databricks did not reflect repair_id={latest_repair_id} for run {run_id} " - f"within {self.workflow_repair_reflection_timeout}s " - f"(workflow_repair_reflection_timeout); aborting to avoid issuing a " + f"within {self.workflow_repair_timeout}s " + f"(workflow_repair_timeout); aborting to avoid issuing a " f"duplicate repair_run against stale terminal state." ) @@ -570,7 +570,7 @@ class DatabricksWorkflowTaskGroup(TaskGroup): effect on Airflow 3+; ignored on Airflow 2.x. Defaults to ``0``. :param workflow_repair_polling_period: How often the repair coordinator polls the Databricks run state. Only used when ``workflow_repair_attempts > 0``. - :param workflow_repair_reflection_timeout: Seconds the coordinator waits after a + :param workflow_repair_timeout: Seconds the coordinator waits after a ``repair_run`` is accepted for the parent run to leave its terminal state before giving up and failing. Covers Databricks-side eventual consistency on a slow cluster. Defaults to 300 seconds (5 minutes). 
Only used when @@ -594,7 +594,7 @@ def __init__( spark_submit_params: list | None = None, workflow_repair_attempts: int = 0, workflow_repair_polling_period: int = 30, - workflow_repair_reflection_timeout: int = 300, + workflow_repair_timeout: int = 300, **kwargs, ): if workflow_repair_attempts < 0: @@ -612,7 +612,7 @@ def __init__( self.spark_submit_params = spark_submit_params or [] self.workflow_repair_attempts = workflow_repair_attempts self.workflow_repair_polling_period = workflow_repair_polling_period - self.workflow_repair_reflection_timeout = workflow_repair_reflection_timeout + self.workflow_repair_timeout = workflow_repair_timeout super().__init__(**kwargs) def __exit__( @@ -679,7 +679,7 @@ def __exit__( launch_task_id=create_databricks_workflow_task.task_id, workflow_repair_attempts=self.workflow_repair_attempts, workflow_repair_polling_period=self.workflow_repair_polling_period, - workflow_repair_reflection_timeout=self.workflow_repair_reflection_timeout, + workflow_repair_timeout=self.workflow_repair_timeout, # Retrying the coordinator would re-enter execute() with repair_attempts=0 # and start the budget over. retries=0, diff --git a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py index 06cd664f1a4d9..fa7d5911fe70a 100644 --- a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py @@ -35,7 +35,7 @@ # repair-triggered new sub-run can take a moment to appear in the parent's tasks list — the # waiter polls up to this many times before declaring the parent terminally failed without a # new attempt. 
The coordinator uses a configurable wall-clock timeout instead (see -# ``workflow_repair_reflection_timeout``), since the post-``repair_run`` eventual-consistency +# ``workflow_repair_timeout``), since the post-``repair_run`` eventual-consistency # window can stretch into minutes when Databricks is slow. WORKFLOW_REPAIR_GRACE_POLLS = 3 @@ -152,7 +152,7 @@ class DatabricksWorkflowRepairCoordinatorTrigger(BaseTrigger): :param repair_attempts: Repair attempts already performed. :param latest_repair_id: Repair id of the most recent repair attempt. :param polling_period_seconds: How often to poll the run state. - :param workflow_repair_reflection_timeout: Seconds to wait after ``repair_run`` is accepted + :param workflow_repair_timeout: Seconds to wait after ``repair_run`` is accepted for the parent run to leave its terminal state before giving up and failing the coordinator. Defaults to 5 minutes. :param retry_limit: Hook retry limit for transient Databricks API failures. @@ -170,7 +170,7 @@ def __init__( repair_attempts: int = 0, latest_repair_id: int | None = None, polling_period_seconds: int = 30, - workflow_repair_reflection_timeout: int = 300, + workflow_repair_timeout: int = 300, retry_limit: int = 3, retry_delay: int = 10, retry_args: dict[Any, Any] | None = None, @@ -184,7 +184,7 @@ def __init__( self.repair_attempts = repair_attempts self.latest_repair_id = latest_repair_id self.polling_period_seconds = polling_period_seconds - self.workflow_repair_reflection_timeout = workflow_repair_reflection_timeout + self.workflow_repair_timeout = workflow_repair_timeout self.retry_limit = retry_limit self.retry_delay = retry_delay self.retry_args = retry_args @@ -208,7 +208,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "repair_attempts": self.repair_attempts, "latest_repair_id": self.latest_repair_id, "polling_period_seconds": self.polling_period_seconds, - "workflow_repair_reflection_timeout": self.workflow_repair_reflection_timeout, + 
"workflow_repair_timeout": self.workflow_repair_timeout, "retry_limit": self.retry_limit, "retry_delay": self.retry_delay, "retry_args": self.retry_args, @@ -307,7 +307,7 @@ async def run(self): # yielding. Without this, the next trigger cycle can observe stale terminal # state and issue a second repair_run. Bound the wait so a stuck DBX doesn't # pin the trigger forever. - deadline = time.monotonic() + self.workflow_repair_reflection_timeout + deadline = time.monotonic() + self.workflow_repair_timeout while True: await asyncio.sleep(self.polling_period_seconds) post_repair_state = await self.hook.a_get_run_state(self.run_id) diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py index 061a4ead6745f..e0ab5b339bfe9 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py @@ -702,12 +702,12 @@ def test_sync_run_repairs_failed_run_and_returns_success(self, mock_sleep): @patch("airflow.providers.databricks.operators.databricks_workflow.time.sleep") def test_sync_run_raises_when_repair_not_reflected_within_timeout(self, mock_sleep): operator = self._make_operator(workflow_repair_attempts=2, deferrable=False) - operator.workflow_repair_reflection_timeout = 0 + operator.workflow_repair_timeout = 0 hook = MagicMock() operator.__dict__["_hook"] = hook # 1. outer loop: terminal+failed → repair_run # 2. reflection poll: still terminal → wall-clock deadline trips → raise (no second repair_run). - # workflow_repair_reflection_timeout=0 means the first elapsed ``time.monotonic()`` call + # workflow_repair_timeout=0 means the first elapsed ``time.monotonic()`` call # after the no-op sleep is past the deadline, so the loop bails out. 
hook.get_run_state.side_effect = [ RunState("TERMINATED", "FAILED", ""), @@ -726,7 +726,7 @@ def test_sync_run_raises_when_repair_not_reflected_within_timeout(self, mock_sle message = str(exc.value) assert "did not reflect repair_id=555" in message assert "run 100" in message - assert "workflow_repair_reflection_timeout" in message + assert "workflow_repair_timeout" in message # Only the original repair_run — the raise must prevent a duplicate. hook.repair_run.assert_called_once() # One reflection-loop sleep fired before the deadline check tripped. @@ -742,7 +742,7 @@ def test_workflow_repair_attempts_positive_injects_coordinator_with_launch_upstr databricks_conn_id="databricks_conn", workflow_repair_attempts=2, workflow_repair_polling_period=15, - workflow_repair_reflection_timeout=120, + workflow_repair_timeout=120, ) as tg: task = MagicMock(task_id="task1") task._convert_to_databricks_workflow_task = MagicMock(return_value={}) @@ -752,6 +752,6 @@ def test_workflow_repair_attempts_positive_injects_coordinator_with_launch_upstr assert isinstance(coordinator, _DatabricksWorkflowRepairCoordinatorOperator) assert coordinator.workflow_repair_attempts == 2 assert coordinator.workflow_repair_polling_period == 15 - assert coordinator.workflow_repair_reflection_timeout == 120 + assert coordinator.workflow_repair_timeout == 120 assert coordinator.launch_task_id == "wf.launch" assert "wf.launch" in coordinator.upstream_task_ids diff --git a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py index 989c75318f236..9f258518ba6fd 100644 --- a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py @@ -458,7 +458,7 @@ def _make_trigger( workflow_repair_attempts: int = 2, repair_attempts: int = 0, latest_repair_id: int | None = None, - workflow_repair_reflection_timeout: int = 300, + 
workflow_repair_timeout: int = 300, ) -> DatabricksWorkflowRepairCoordinatorTrigger: return DatabricksWorkflowRepairCoordinatorTrigger( run_id=RUN_ID, @@ -467,7 +467,7 @@ def _make_trigger( repair_attempts=repair_attempts, latest_repair_id=latest_repair_id, polling_period_seconds=POLLING_INTERVAL_SECONDS, - workflow_repair_reflection_timeout=workflow_repair_reflection_timeout, + workflow_repair_timeout=workflow_repair_timeout, run_page_url=RUN_PAGE_URL, ) @@ -569,7 +569,7 @@ async def test_emits_repair_not_reflected_when_reflection_timeout_elapses( ): # First call: terminal+failed → trigger issues repair. # Reflection poll: still terminal → wall-clock deadline trips → yield repair_not_reflected. - # workflow_repair_reflection_timeout=0 means the deadline is "now"; the first elapsed + # workflow_repair_timeout=0 means the deadline is "now"; the first elapsed # ``time.monotonic()`` call after the no-op sleep is past it, so the loop bails out. mock_get_run_state.side_effect = [ RunState( @@ -591,7 +591,7 @@ async def test_emits_repair_not_reflected_when_reflection_timeout_elapses( workflow_repair_attempts=2, repair_attempts=0, latest_repair_id=None, - workflow_repair_reflection_timeout=0, + workflow_repair_timeout=0, ) events = [event async for event in trigger.run()] From d0dc257dc925bb05a567a84301405af5e50c1786 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 16 Jun 2026 08:53:30 -0500 Subject: [PATCH 13/15] Unify repair poll deadline to single clock --- .../providers/databricks/hooks/databricks.py | 14 +- .../databricks/operators/databricks.py | 37 ++--- .../operators/databricks_workflow.py | 44 +++-- .../databricks/triggers/databricks.py | 89 +++++----- .../providers/databricks/utils/databricks.py | 20 +++ .../databricks/operators/test_databricks.py | 98 +++++++---- .../operators/test_databricks_workflow.py | 28 ++-- .../databricks/triggers/test_databricks.py | 154 ++++++++++-------- .../unit/databricks/utils/test_databricks.py | 19 +++ 9 files changed, 297 
insertions(+), 206 deletions(-) diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py index 7bba763455235..e83f05f8f681a 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py @@ -560,25 +560,31 @@ def get_run_tasks(self, run_id: int) -> list[dict[str, Any]]: return all_tasks - def get_run(self, run_id: int) -> dict[str, Any]: + def get_run(self, run_id: int, include_history: bool = False) -> dict[str, Any]: """ Retrieve run information. :param run_id: id of the run + :param include_history: whether to include the run's ``repair_history`` in the response. :return: state of the run """ - json = {"run_id": run_id} + json: dict[str, Any] = {"run_id": run_id} + if include_history: + json["include_history"] = "true" response = self._do_api_call(GET_RUN_ENDPOINT, json) return response - async def a_get_run(self, run_id: int) -> dict[str, Any]: + async def a_get_run(self, run_id: int, include_history: bool = False) -> dict[str, Any]: """ Async version of `get_run`. :param run_id: id of the run + :param include_history: whether to include the run's ``repair_history`` in the response. 
:return: state of the run """ - json = {"run_id": run_id} + json: dict[str, Any] = {"run_id": run_id} + if include_history: + json["include_history"] = "true" response = await self._a_do_api_call(GET_RUN_ENDPOINT, json) return response diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index 84972690cf215..670677fd86fc7 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -43,11 +43,11 @@ store_databricks_job_run_link, ) from airflow.providers.databricks.triggers.databricks import ( - WORKFLOW_REPAIR_GRACE_POLLS, DatabricksExecutionTrigger, DatabricksWorkflowRepairWaitTrigger, ) from airflow.providers.databricks.utils.databricks import ( + compute_repair_deadline, extract_failed_task_errors, find_new_workflow_task_attempt, normalise_json_content, @@ -1641,15 +1641,14 @@ def _sync_wait_for_new_sub_run_attempt( self.databricks_task_key, ) polling_period_seconds = tg.workflow_repair_polling_period - terminal_observations = 0 - last_repair_history_count: int | None = None + workflow_repair_timeout = tg.workflow_repair_timeout + # Give-up deadline (epoch seconds), anchored to the run's end_time to match the + # coordinator. Reset when the run leaves terminal failure so a later failure restarts it. 
+ repair_deadline: float | None = None while True: run_info = self._hook.get_run(self.databricks_run_id) # type: ignore[arg-type] parent_run_state = RunState(**run_info["state"]) tasks = run_info.get("tasks", []) - repair_history_count = len(run_info.get("repair_history", [])) - if last_repair_history_count is None: - last_repair_history_count = repair_history_count new_attempt = find_new_workflow_task_attempt( tasks=tasks, task_key=self.databricks_task_key, @@ -1658,31 +1657,22 @@ def _sync_wait_for_new_sub_run_attempt( ) if new_attempt is not None: return new_attempt["run_id"] - if repair_history_count > last_repair_history_count: - self.log.info( - "Parent run %s repair_history grew (was %s, now %s); resetting grace counter " - "while waiting for a new attempt for task_key %s.", - self.databricks_run_id, - last_repair_history_count, - repair_history_count, - self.databricks_task_key, - ) - last_repair_history_count = repair_history_count - terminal_observations = 0 - elif parent_run_state.is_terminal: - terminal_observations += 1 - if terminal_observations >= WORKFLOW_REPAIR_GRACE_POLLS: + if parent_run_state.is_terminal and not parent_run_state.is_successful: + if repair_deadline is None: + repair_deadline = compute_repair_deadline(run_info, workflow_repair_timeout) + if time.time() >= repair_deadline: self.log.info( "Parent run %s reached terminal state %s without a new attempt for " - "task_key %s after %s grace polls.", + "task_key %s within %ss (anchored to the run's terminal end_time).", self.databricks_run_id, parent_run_state.result_state, self.databricks_task_key, - WORKFLOW_REPAIR_GRACE_POLLS, + workflow_repair_timeout, ) return None else: - terminal_observations = 0 + # Not in terminal failure — clear the deadline so a later failure restarts it. 
+ repair_deadline = None time.sleep(polling_period_seconds) def _defer_to_workflow_repair_wait( @@ -1705,6 +1695,7 @@ def _defer_to_workflow_repair_wait( original_sub_run_id=original_sub_run_id, original_start_time=original_start_time, polling_period_seconds=tg.workflow_repair_polling_period, + workflow_repair_timeout=tg.workflow_repair_timeout, retry_limit=self.databricks_retry_limit, retry_delay=self.databricks_retry_delay, retry_args=self.databricks_retry_args, diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py index 9f6fb332c6982..345351cfd5ca4 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py @@ -28,7 +28,7 @@ from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, TaskGroup, conf from airflow.providers.databricks.exceptions import DatabricksWorkflowRepairError -from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState +from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState, RunState from airflow.providers.databricks.plugins.databricks_workflow import ( WorkflowJobRepairAllFailedLink, WorkflowJobRunLink, @@ -37,7 +37,12 @@ from airflow.providers.databricks.triggers.databricks import ( DatabricksWorkflowRepairCoordinatorTrigger, ) -from airflow.providers.databricks.utils.databricks import build_repair_run_json, extract_failed_task_errors +from airflow.providers.databricks.utils.databricks import ( + build_repair_run_json, + compute_repair_deadline, + extract_failed_task_errors, + is_repair_reflected, +) from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: @@ -318,7 +323,7 @@ def __init__( launch_task_id: str, workflow_repair_attempts: int, 
workflow_repair_polling_period: int = 30, - workflow_repair_timeout: int = 300, + workflow_repair_timeout: int = 180, databricks_retry_limit: int = 3, databricks_retry_delay: int = 10, databricks_retry_args: dict[Any, Any] | None = None, @@ -451,17 +456,22 @@ def _run_sync(self, run_id: int) -> dict[str, Any]: latest_repair_id, ) - # Wait for Databricks to reflect the repair (leave terminal state) before - # looping. Without this, the next get_run_state can return stale terminal - # state and trigger a second repair_run. Bound the wait so a stuck DBX - # doesn't pin a worker forever. - deadline = time.monotonic() + self.workflow_repair_timeout + # Wait for the repair to be reflected before looping, else a stale terminal state + # triggers a second repair_run. Reflection is confirmed by the repair_id appearing in + # repair_history (monotonic, so a fast repair can't slip past polling) or by the run + # leaving its terminal state. The deadline shares the run's end_time anchor so + # coordinator and waiters give up together. + deadline = compute_repair_deadline(run_info, self.workflow_repair_timeout) while True: time.sleep(self.workflow_repair_polling_period) - post_repair_state = self._hook.get_run_state(run_id) - if not post_repair_state.is_terminal: + post_repair_info = self._hook.get_run(run_id, include_history=True) + post_repair_state = RunState(**post_repair_info["state"]) + if ( + is_repair_reflected(post_repair_info, latest_repair_id) + or not post_repair_state.is_terminal + ): break - if time.monotonic() >= deadline: + if time.time() >= deadline: raise DatabricksWorkflowRepairError( f"Databricks did not reflect repair_id={latest_repair_id} for run {run_id} " f"within {self.workflow_repair_timeout}s " @@ -570,11 +580,11 @@ class DatabricksWorkflowTaskGroup(TaskGroup): effect on Airflow 3+; ignored on Airflow 2.x. Defaults to ``0``. :param workflow_repair_polling_period: How often the repair coordinator polls the Databricks run state. 
Only used when ``workflow_repair_attempts > 0``. - :param workflow_repair_timeout: Seconds the coordinator waits after a - ``repair_run`` is accepted for the parent run to leave its terminal state before - giving up and failing. Covers Databricks-side eventual consistency on a slow - cluster. Defaults to 300 seconds (5 minutes). Only used when - ``workflow_repair_attempts > 0``. + :param workflow_repair_timeout: How long Databricks may take to reflect a repair, as a + wall-clock window anchored to the parent run's terminal ``end_time``. The coordinator and + the downstream waiters share this value and anchor, so they give up at the same instant and + a downstream task is never failed while a repair could still land. Defaults to 180 seconds. + Only used when ``workflow_repair_attempts > 0``. """ is_databricks = True @@ -594,7 +604,7 @@ def __init__( spark_submit_params: list | None = None, workflow_repair_attempts: int = 0, workflow_repair_polling_period: int = 30, - workflow_repair_timeout: int = 300, + workflow_repair_timeout: int = 180, **kwargs, ): if workflow_repair_attempts < 0: diff --git a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py index fa7d5911fe70a..91855a8fadaf2 100644 --- a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py @@ -19,26 +19,20 @@ import asyncio import time +from datetime import datetime, timezone from typing import Any from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState from airflow.providers.databricks.utils.databricks import ( build_repair_run_json, + compute_repair_deadline, extract_failed_task_errors_async, find_new_workflow_task_attempt, + is_repair_reflected, ) from airflow.providers.databricks.utils.retry import validate_deferrable_databricks_retry_args from airflow.triggers.base 
import BaseTrigger, TriggerEvent -# Tolerate this many consecutive polls of stale Databricks state in -# ``DatabricksWorkflowRepairWaitTrigger``: when a sub-run reports terminal failure, a -# repair-triggered new sub-run can take a moment to appear in the parent's tasks list — the -# waiter polls up to this many times before declaring the parent terminally failed without a -# new attempt. The coordinator uses a configurable wall-clock timeout instead (see -# ``workflow_repair_timeout``), since the post-``repair_run`` eventual-consistency -# window can stretch into minutes when Databricks is slow. -WORKFLOW_REPAIR_GRACE_POLLS = 3 - class DatabricksExecutionTrigger(BaseTrigger): """ @@ -152,9 +146,9 @@ class DatabricksWorkflowRepairCoordinatorTrigger(BaseTrigger): :param repair_attempts: Repair attempts already performed. :param latest_repair_id: Repair id of the most recent repair attempt. :param polling_period_seconds: How often to poll the run state. - :param workflow_repair_timeout: Seconds to wait after ``repair_run`` is accepted - for the parent run to leave its terminal state before giving up and failing the - coordinator. Defaults to 5 minutes. + :param workflow_repair_timeout: How long Databricks may take to reflect a repair, as a + wall-clock window anchored to the parent run's terminal ``end_time``. The downstream + waiters share the same anchor and value so both sides give up together. Defaults to 180s. :param retry_limit: Hook retry limit for transient Databricks API failures. :param retry_delay: Hook retry delay (seconds). :param retry_args: Optional tenacity ``Retrying`` kwargs forwarded to the hook. 
@@ -170,7 +164,7 @@ def __init__( repair_attempts: int = 0, latest_repair_id: int | None = None, polling_period_seconds: int = 30, - workflow_repair_timeout: int = 300, + workflow_repair_timeout: int = 180, retry_limit: int = 3, retry_delay: int = 10, retry_args: dict[Any, Any] | None = None, @@ -303,17 +297,27 @@ async def run(self): new_repair_id, ) - # Wait for Databricks to reflect the repair (leave terminal state) before - # yielding. Without this, the next trigger cycle can observe stale terminal - # state and issue a second repair_run. Bound the wait so a stuck DBX doesn't - # pin the trigger forever. - deadline = time.monotonic() + self.workflow_repair_timeout + # Wait for repair to be reflected via repair_id in history or a run leaving terminal state. + # Deadline anchored to run's end_time so coordinator and waiters give up together. + deadline = compute_repair_deadline(run_info, self.workflow_repair_timeout) + self.log.info( + "Waiting up to %ss for run %s to reflect repair_id=%s; " + "giving up at %s (deadline anchored to the run's terminal end_time).", + self.workflow_repair_timeout, + self.run_id, + new_repair_id, + datetime.fromtimestamp(deadline, tz=timezone.utc).isoformat(), + ) while True: await asyncio.sleep(self.polling_period_seconds) - post_repair_state = await self.hook.a_get_run_state(self.run_id) - if not post_repair_state.is_terminal: + post_repair_info = await self.hook.a_get_run(self.run_id, include_history=True) + post_repair_state = RunState(**post_repair_info["state"]) + if ( + is_repair_reflected(post_repair_info, new_repair_id) + or not post_repair_state.is_terminal + ): break - if time.monotonic() >= deadline: + if time.time() >= deadline: yield TriggerEvent( { "status": "repair_not_reflected", @@ -351,6 +355,9 @@ class DatabricksWorkflowRepairWaitTrigger(BaseTrigger): :param original_sub_run_id: The sub-run id of the attempt that just failed; the trigger only yields ``new_attempt`` for a sub-run id different from this one. 
:param polling_period_seconds: How often to poll the parent run. + :param workflow_repair_timeout: How long Databricks may take to reflect a repair, as a + wall-clock window anchored to the parent run's terminal ``end_time``. Must match the + coordinator's value; the waiter never declares ``parent_failed`` before that shared deadline. :param retry_limit: Hook retry limit for transient Databricks API failures. :param retry_delay: Hook retry delay (seconds). :param retry_args: Optional tenacity ``Retrying`` kwargs forwarded to the hook. @@ -366,6 +373,7 @@ def __init__( original_sub_run_id: int, original_start_time: int | None = None, polling_period_seconds: int = 30, + workflow_repair_timeout: int = 180, retry_limit: int = 3, retry_delay: int = 10, retry_args: dict[Any, Any] | None = None, @@ -379,6 +387,7 @@ def __init__( self.original_sub_run_id = original_sub_run_id self.original_start_time = original_start_time self.polling_period_seconds = polling_period_seconds + self.workflow_repair_timeout = workflow_repair_timeout self.retry_limit = retry_limit self.retry_delay = retry_delay self.retry_args = retry_args @@ -402,6 +411,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "original_sub_run_id": self.original_sub_run_id, "original_start_time": self.original_start_time, "polling_period_seconds": self.polling_period_seconds, + "workflow_repair_timeout": self.workflow_repair_timeout, "retry_limit": self.retry_limit, "retry_delay": self.retry_delay, "retry_args": self.retry_args, @@ -419,16 +429,14 @@ def _find_new_attempt(self, tasks: list[dict[str, Any]]) -> dict[str, Any] | Non ) async def run(self): - terminal_observations = 0 - last_repair_history_count: int | None = None + # Give-up deadline (epoch seconds), anchored to the run's end_time to match the + # coordinator. Reset when the run leaves terminal failure so a later failure restarts it. 
+ repair_deadline: float | None = None async with self.hook: while True: run_info = await self.hook.a_get_run(self.run_id) run_state = RunState(**run_info["state"]) tasks = run_info.get("tasks", []) - repair_history_count = len(run_info.get("repair_history", [])) - if last_repair_history_count is None: - last_repair_history_count = repair_history_count new_attempt = self._find_new_attempt(tasks) if new_attempt is not None: @@ -450,29 +458,21 @@ async def run(self): ) return - if repair_history_count > last_repair_history_count: - self.log.info( - "Databricks workflow run %s repair_history grew (was %s, now %s); " - "resetting grace counter while waiting for a new attempt for task_key %s.", - self.run_id, - last_repair_history_count, - repair_history_count, - self.databricks_task_key, - ) - last_repair_history_count = repair_history_count - terminal_observations = 0 - elif run_state.is_terminal and not run_state.is_successful: - terminal_observations += 1 + if run_state.is_terminal and not run_state.is_successful: + if repair_deadline is None: + repair_deadline = compute_repair_deadline(run_info, self.workflow_repair_timeout) self.log.info( "Databricks workflow run %s is in terminal failure state %s with no new " - "attempt for task_key %s (grace %s of %s).", + "attempt for task_key %s; waiting up to %ss for the coordinator's repair to " + "be reflected, giving up at %s (deadline anchored to the run's terminal " + "end_time).", self.run_id, run_state.result_state, self.databricks_task_key, - terminal_observations, - WORKFLOW_REPAIR_GRACE_POLLS, + self.workflow_repair_timeout, + datetime.fromtimestamp(repair_deadline, tz=timezone.utc).isoformat(), ) - if terminal_observations >= WORKFLOW_REPAIR_GRACE_POLLS: + if time.time() >= repair_deadline: yield TriggerEvent( { "status": "parent_failed", @@ -484,7 +484,8 @@ async def run(self): ) return else: - terminal_observations = 0 + # Not in terminal failure — clear the deadline so a later failure restarts it. 
+ repair_deadline = None await asyncio.sleep(self.polling_period_seconds) diff --git a/providers/databricks/src/airflow/providers/databricks/utils/databricks.py b/providers/databricks/src/airflow/providers/databricks/utils/databricks.py index 0607ca0de3021..92426136439fb 100644 --- a/providers/databricks/src/airflow/providers/databricks/utils/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/utils/databricks.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import time from typing import Any from airflow.providers.common.compat.sdk import AirflowException, XComArg @@ -126,6 +127,25 @@ def find_new_workflow_task_attempt( return max(candidates, key=lambda task: task.get("start_time", 0)) +def compute_repair_deadline(run_info: dict[str, Any], workflow_repair_timeout: int) -> float: + """ + Return the wall-clock deadline (epoch seconds) for a repair to be reflected. + + Anchored to the run's terminal ``end_time`` (epoch ms) so all waiters converge on the same + give-up instant. Falls back to now if ``end_time`` is not populated. 
+ """ + end_time_ms = run_info.get("end_time") or 0 + anchor = end_time_ms / 1000 if end_time_ms else time.time() + return anchor + workflow_repair_timeout + + +def is_repair_reflected(run_info: dict[str, Any], repair_id: int | None) -> bool: + """Return ``True`` once ``repair_id`` appears in the run's ``repair_history``.""" + if repair_id is None: + return False + return any(entry.get("id") == repair_id for entry in run_info.get("repair_history", [])) + + def build_repair_run_json( run_id: int, latest_repair_id: int | None, diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks.py b/providers/databricks/tests/unit/databricks/operators/test_databricks.py index 35b892a1e4a95..4b53dd15855e7 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks.py @@ -18,12 +18,14 @@ from __future__ import annotations import hashlib +import time from datetime import datetime, timedelta from typing import Any from unittest import mock from unittest.mock import MagicMock, call import pytest +import time_machine from tenacity import stop_after_attempt, wait_incrementing # Do not run the tests when FAB / Flask is not installed @@ -2998,6 +3000,9 @@ class TestExecuteCompleteWorkflowRepair: PARENT_RUN_ID = 100 ORIGINAL_SUB_RUN_ID = 500 NEW_SUB_RUN_ID = 700 + # Parent run terminal end_time (epoch ms) the sync waiter anchors its give-up deadline to. 
+ PARENT_END_TIME_MS = 1_700_000_000_000 + ANCHOR_S = PARENT_END_TIME_MS / 1000 @staticmethod def _terminal_failure_event(sub_run_id: int) -> dict[str, Any]: @@ -3035,6 +3040,7 @@ def _operator_with_workflow_tg(self, workflow_repair_attempts: int) -> Databrick tg.is_databricks = True tg.workflow_repair_attempts = workflow_repair_attempts tg.workflow_repair_polling_period = 15 + tg.workflow_repair_timeout = 120 tg.task_group = None operator.task_group = tg return operator @@ -3054,6 +3060,7 @@ def test_execute_complete_failure_defers_on_wait_trigger_when_workflow_repair_at assert trigger.databricks_task_key == operator.databricks_task_key assert trigger.original_sub_run_id == self.ORIGINAL_SUB_RUN_ID assert trigger.polling_period_seconds == 15 + assert trigger.workflow_repair_timeout == 120 assert exc.value.method_name == "execute_complete_after_repair_wait" def test_execute_complete_failure_pulls_workflow_metadata_from_xcom(self): @@ -3070,6 +3077,7 @@ def test_execute_complete_failure_pulls_workflow_metadata_from_xcom(self): tg.is_databricks = True tg.workflow_repair_attempts = 2 tg.workflow_repair_polling_period = 15 + tg.workflow_repair_timeout = 120 tg.task_group = None operator.task_group = tg operator.upstream_task_ids = {"workflow.launch"} @@ -3174,7 +3182,8 @@ def test_sync_wait_for_new_sub_run_attempt_returns_new_attempt(self, mock_sleep) mock_sleep.assert_called_once_with(15) @mock.patch("airflow.providers.databricks.operators.databricks.time.sleep") - def test_sync_wait_for_new_sub_run_attempt_repair_history_growth_resets_grace(self, mock_sleep): + def test_sync_wait_for_new_sub_run_attempt_returns_none_after_repair_timeout(self, mock_sleep): + # Gives up only once the deadline (end_time + timeout) passes. 120s timeout, 15s polling = 8 sleeps. 
operator = self._operator_with_workflow_tg(workflow_repair_attempts=2) hook = MagicMock() operator.__dict__["_hook"] = hook @@ -3185,40 +3194,61 @@ def test_sync_wait_for_new_sub_run_attempt_repair_history_growth_resets_grace(se "start_time": 1000, } terminal_state = {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": None} - # Two terminal polls, then a repair lands (history grows) — counter resets; - # then 3 more terminal polls with no further repair → returns None. + hook.get_run.return_value = { + "state": terminal_state, + "tasks": [original_task], + "end_time": self.PARENT_END_TIME_MS, + } + + with time_machine.travel(self.ANCHOR_S, tick=False) as traveller: + mock_sleep.side_effect = lambda _seconds: traveller.move_to(time.time() + 15) + result = operator._sync_wait_for_new_sub_run_attempt( + original_sub_run_id=self.ORIGINAL_SUB_RUN_ID, + original_start_time=1000, + tg=operator.task_group, + ) + + assert result is None + # 8 sleeps advance the clock to the deadline; the 9th poll observes it has passed. + assert hook.get_run.call_count == 9 + assert mock_sleep.call_count == 8 + + @mock.patch("airflow.providers.databricks.operators.databricks.time.sleep") + def test_sync_wait_deadline_resets_when_run_leaves_terminal(self, mock_sleep): + # When the repair takes effect the run leaves terminal, resetting the deadline, and the + # waiter picks up the repaired attempt instead of returning None. 
+ operator = self._operator_with_workflow_tg(workflow_repair_attempts=2) + hook = MagicMock() + operator.__dict__["_hook"] = hook + operator.databricks_run_id = self.PARENT_RUN_ID + original_task = { + "run_id": self.ORIGINAL_SUB_RUN_ID, + "task_key": operator.databricks_task_key, + "start_time": 1000, + } + new_task = { + "run_id": self.NEW_SUB_RUN_ID, + "task_key": operator.databricks_task_key, + "start_time": 2000, + } + terminal_state = {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": None} + running_state = {"life_cycle_state": "RUNNING", "result_state": None, "state_message": None} hook.get_run.side_effect = [ - {"state": terminal_state, "tasks": [original_task], "repair_history": []}, - {"state": terminal_state, "tasks": [original_task], "repair_history": []}, - { - "state": terminal_state, - "tasks": [original_task], - "repair_history": [{"id": 1, "type": "REPAIR"}], - }, - { - "state": terminal_state, - "tasks": [original_task], - "repair_history": [{"id": 1, "type": "REPAIR"}], - }, - { - "state": terminal_state, - "tasks": [original_task], - "repair_history": [{"id": 1, "type": "REPAIR"}], - }, - { - "state": terminal_state, - "tasks": [original_task], - "repair_history": [{"id": 1, "type": "REPAIR"}], - }, + {"state": terminal_state, "tasks": [original_task], "end_time": self.PARENT_END_TIME_MS}, + {"state": terminal_state, "tasks": [original_task], "end_time": self.PARENT_END_TIME_MS}, + # Repair lands: the run is RUNNING again → deadline resets. + {"state": running_state, "tasks": [original_task]}, + # The repaired attempt for the watched task_key appears. 
+ {"state": running_state, "tasks": [original_task, new_task]}, ] - result = operator._sync_wait_for_new_sub_run_attempt( - original_sub_run_id=self.ORIGINAL_SUB_RUN_ID, - original_start_time=1000, - tg=operator.task_group, - ) + with time_machine.travel(self.ANCHOR_S, tick=False) as traveller: + mock_sleep.side_effect = lambda _seconds: traveller.move_to(time.time() + 15) + result = operator._sync_wait_for_new_sub_run_attempt( + original_sub_run_id=self.ORIGINAL_SUB_RUN_ID, + original_start_time=1000, + tg=operator.task_group, + ) - # Without the reset, this would return after 3 polls. The reset on poll 3 delays the - # decision until 3 more terminal polls without further repair_history growth (poll 6). - assert result is None - assert hook.get_run.call_count == 6 + assert result == self.NEW_SUB_RUN_ID + assert hook.get_run.call_count == 4 diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py index e0ab5b339bfe9..484dd42ffe7ac 100644 --- a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py +++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py @@ -671,18 +671,24 @@ def test_sync_run_repairs_failed_run_and_returns_success(self, mock_sleep): hook = MagicMock() operator.__dict__["_hook"] = hook # 1. top of outer loop: terminal+failed → repair_run - # 2. reflection poll: non-terminal → break reflection loop + # 2. reflection poll: run still terminal, but repair_id is in repair_history → reflected → + # break reflection loop (the fast-repair case a terminal-state-only check would miss) # 3. 
top of outer loop: terminal+success → return hook.get_run_state.side_effect = [ RunState("TERMINATED", "FAILED", ""), - RunState("RUNNING", "", ""), RunState("TERMINATED", "SUCCESS", ""), ] - hook.get_run.return_value = { - "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": ""}, - "tasks": [], - "overriding_parameters": {"notebook_params": {"date": "2024-01-01"}}, - } + hook.get_run.side_effect = [ + { + "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": ""}, + "tasks": [], + "overriding_parameters": {"notebook_params": {"date": "2024-01-01"}}, + }, + { + "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": ""}, + "repair_history": [{"id": 555}], + }, + ] hook.repair_run.return_value = 555 result = operator._run_sync(run_id=100) @@ -705,13 +711,11 @@ def test_sync_run_raises_when_repair_not_reflected_within_timeout(self, mock_sle operator.workflow_repair_timeout = 0 hook = MagicMock() operator.__dict__["_hook"] = hook - # 1. outer loop: terminal+failed → repair_run - # 2. reflection poll: still terminal → wall-clock deadline trips → raise (no second repair_run). - # workflow_repair_timeout=0 means the first elapsed ``time.monotonic()`` call - # after the no-op sleep is past the deadline, so the loop bails out. + # terminal+failed → repair_run; reflection poll still terminal with no repair_history entry + # → deadline trips → raise. + # timeout=0 makes the deadline the run's end_time, already past, so the loop bails out. 
hook.get_run_state.side_effect = [ RunState("TERMINATED", "FAILED", ""), - RunState("TERMINATED", "FAILED", ""), ] hook.get_run.return_value = { "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": ""}, diff --git a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py index 9f258518ba6fd..2f97b69c64ccb 100644 --- a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py @@ -17,9 +17,11 @@ # under the License. from __future__ import annotations +import time from unittest import mock import pytest +import time_machine from tenacity import stop_after_attempt, wait_incrementing from airflow.models import Connection @@ -513,21 +515,21 @@ async def test_emits_completed_when_run_succeeds(self, mock_get_run_state, mock_ async def test_first_failure_within_budget_calls_repair_and_emits_repaired( self, mock_get_run_state, mock_get_run, mock_get_run_output, mock_repair_run, mock_sleep ): - # First call: terminal+failed → trigger issues repair. - # Second call: reflection poll → non-terminal → reflection loop breaks. + # Outer loop: terminal+failed → trigger issues repair. + # Reflection poll: run is still terminal, but repair_id is in repair_history → reflected → + # reflection loop breaks. This is the fast-repair case that a terminal-state-only check + # would miss. 
mock_get_run_state.side_effect = [ RunState( life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="FAILED", ), - RunState( - life_cycle_state="RUNNING", - state_message="", - result_state="", - ), ] - mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED_WITH_FAILED + mock_get_run.side_effect = [ + GET_RUN_RESPONSE_TERMINATED_WITH_FAILED, + {**GET_RUN_RESPONSE_TERMINATED_WITH_FAILED, "repair_history": [{"id": 101}]}, + ] mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE mock_repair_run.return_value = 101 @@ -567,21 +569,15 @@ async def test_emits_repair_not_reflected_when_reflection_timeout_elapses( mock_repair_run, mock_sleep, ): - # First call: terminal+failed → trigger issues repair. - # Reflection poll: still terminal → wall-clock deadline trips → yield repair_not_reflected. - # workflow_repair_timeout=0 means the deadline is "now"; the first elapsed - # ``time.monotonic()`` call after the no-op sleep is past it, so the loop bails out. + # terminal+failed → repair; reflection poll still terminal with no repair_history entry → + # deadline trips → repair_not_reflected. + # timeout=0 makes the deadline the run's end_time, already past, so the loop bails out. mock_get_run_state.side_effect = [ RunState( life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="FAILED", ), - RunState( - life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, - state_message="", - result_state="FAILED", - ), ] mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED_WITH_FAILED mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE @@ -640,6 +636,9 @@ class TestDatabricksWorkflowRepairWaitTrigger: TASK_KEY = "monitored_task" ORIGINAL_SUB_RUN_ID = 500 NEW_SUB_RUN_ID = 700 + # Parent run terminal end_time (epoch ms) the waiter anchors its give-up deadline to. 
+ PARENT_END_TIME_MS = 1_700_000_000_000 + ANCHOR_S = PARENT_END_TIME_MS / 1000 @pytest.fixture(autouse=True) def setup_connections(self, create_connection_without_db): @@ -654,13 +653,14 @@ def setup_connections(self, create_connection_without_db): ) ) - def _make_trigger(self) -> DatabricksWorkflowRepairWaitTrigger: + def _make_trigger(self, workflow_repair_timeout: int = 180) -> DatabricksWorkflowRepairWaitTrigger: return DatabricksWorkflowRepairWaitTrigger( run_id=self.PARENT_RUN_ID, databricks_conn_id=DEFAULT_CONN_ID, databricks_task_key=self.TASK_KEY, original_sub_run_id=self.ORIGINAL_SUB_RUN_ID, polling_period_seconds=POLLING_INTERVAL_SECONDS, + workflow_repair_timeout=workflow_repair_timeout, run_page_url=RUN_PAGE_URL, ) @@ -669,9 +669,9 @@ def _run_payload( result_state: str | None, life_cycle_state: str = LIFE_CYCLE_STATE_TERMINATED, tasks: list[dict] | None = None, - repair_history: list[dict] | None = None, + end_time: int | None = None, ) -> dict: - return { + payload = { "run_page_url": RUN_PAGE_URL, "state": { "life_cycle_state": life_cycle_state, @@ -679,8 +679,10 @@ def _run_payload( "result_state": result_state, }, "tasks": tasks or [], - "repair_history": repair_history or [], } + if end_time is not None: + payload["end_time"] = end_time + return payload def test_serialize_round_trips_state(self): trigger = self._make_trigger() @@ -724,27 +726,29 @@ async def test_emits_new_attempt_when_new_sub_run_appears(self, mock_get_run): } @pytest.mark.asyncio - @mock.patch("asyncio.sleep") @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") - async def test_emits_parent_failed_after_grace_polls(self, mock_get_run, mock_sleep): - terminal_payload = self._run_payload( + async def test_emits_parent_failed_after_repair_timeout(self, mock_get_run): + # Gives up only once the deadline (end_time + timeout) passes, not after a fixed poll + # count. 180s timeout, 30s polling = 6 sleeps. 
+ mock_get_run.return_value = self._run_payload( result_state="FAILED", life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, - tasks=[ - { - "run_id": self.ORIGINAL_SUB_RUN_ID, - "task_key": self.TASK_KEY, - "start_time": 1000, - }, - ], + tasks=[{"run_id": self.ORIGINAL_SUB_RUN_ID, "task_key": self.TASK_KEY, "start_time": 1000}], + end_time=self.PARENT_END_TIME_MS, ) - mock_get_run.return_value = terminal_payload - trigger = self._make_trigger() - events = [event async for event in trigger.run()] + with time_machine.travel(self.ANCHOR_S, tick=False) as traveller: + + def advance(_seconds): + traveller.move_to(time.time() + POLLING_INTERVAL_SECONDS) - assert mock_get_run.call_count == 3 - assert mock_sleep.call_count == 2 + with mock.patch("asyncio.sleep", side_effect=advance) as mock_sleep: + trigger = self._make_trigger(workflow_repair_timeout=180) + events = [event async for event in trigger.run()] + + # 6 sleeps advance the clock to the deadline; parent_failed fires on the 7th poll. + assert mock_get_run.call_count == 7 + assert mock_sleep.call_count == 6 assert len(events) == 1 payload = events[0].payload assert payload["status"] == "parent_failed" @@ -755,44 +759,50 @@ async def test_emits_parent_failed_after_grace_polls(self, mock_get_run, mock_sl @pytest.mark.asyncio @mock.patch("asyncio.sleep") @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") - async def test_repair_history_growth_resets_grace_counter(self, mock_get_run, mock_sleep): - original_task = { - "run_id": self.ORIGINAL_SUB_RUN_ID, - "task_key": self.TASK_KEY, - "start_time": 1000, - } - # Sequence: terminal twice, then a repair lands (history grows) — counter resets; - # then 3 more terminal polls with no further repair → parent_failed. 
+ async def test_parent_failed_fires_immediately_when_deadline_already_past(self, mock_get_run, mock_sleep): + # Deadline is anchored to end_time, not to when polling started: a run that went terminal + # long ago is already past its deadline, so the first poll fails without waiting. + mock_get_run.return_value = self._run_payload( + result_state="FAILED", + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + tasks=[{"run_id": self.ORIGINAL_SUB_RUN_ID, "task_key": self.TASK_KEY, "start_time": 1000}], + end_time=self.PARENT_END_TIME_MS, + ) + + with time_machine.travel(self.ANCHOR_S + 1000, tick=False): + trigger = self._make_trigger(workflow_repair_timeout=180) + events = [event async for event in trigger.run()] + + assert mock_get_run.call_count == 1 + assert mock_sleep.call_count == 0 + assert events[0].payload["status"] == "parent_failed" + + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + async def test_deadline_resets_when_run_leaves_terminal(self, mock_get_run): + # When the repair takes effect the run leaves terminal, resetting the deadline, and the + # waiter picks up the repaired attempt instead of failing. 
+ original_task = {"run_id": self.ORIGINAL_SUB_RUN_ID, "task_key": self.TASK_KEY, "start_time": 1000} + new_task = {"run_id": self.NEW_SUB_RUN_ID, "task_key": self.TASK_KEY, "start_time": 2000} mock_get_run.side_effect = [ - self._run_payload(result_state="FAILED", tasks=[original_task], repair_history=[]), - self._run_payload(result_state="FAILED", tasks=[original_task], repair_history=[]), - self._run_payload( - result_state="FAILED", - tasks=[original_task], - repair_history=[{"id": 1, "type": "REPAIR"}], - ), - self._run_payload( - result_state="FAILED", - tasks=[original_task], - repair_history=[{"id": 1, "type": "REPAIR"}], - ), - self._run_payload( - result_state="FAILED", - tasks=[original_task], - repair_history=[{"id": 1, "type": "REPAIR"}], - ), - self._run_payload( - result_state="FAILED", - tasks=[original_task], - repair_history=[{"id": 1, "type": "REPAIR"}], - ), + self._run_payload(result_state="FAILED", tasks=[original_task], end_time=self.PARENT_END_TIME_MS), + self._run_payload(result_state="FAILED", tasks=[original_task], end_time=self.PARENT_END_TIME_MS), + # Repair lands: the run is RUNNING again → deadline resets. + self._run_payload(result_state=None, life_cycle_state="RUNNING", tasks=[original_task]), + # The repaired attempt for the watched task_key appears. + self._run_payload(result_state=None, life_cycle_state="RUNNING", tasks=[original_task, new_task]), ] - trigger = self._make_trigger() - events = [event async for event in trigger.run()] + with time_machine.travel(self.ANCHOR_S, tick=False) as traveller: - # Without the reset, parent_failed would fire on the 3rd poll. The reset on poll 3 - # delays it until 3 more terminal polls have accumulated (poll 6). 
- assert mock_get_run.call_count == 6 + def advance(_seconds): + traveller.move_to(time.time() + POLLING_INTERVAL_SECONDS) + + with mock.patch("asyncio.sleep", side_effect=advance): + trigger = self._make_trigger(workflow_repair_timeout=180) + events = [event async for event in trigger.run()] + + assert mock_get_run.call_count == 4 assert len(events) == 1 - assert events[0].payload["status"] == "parent_failed" + assert events[0].payload["status"] == "new_attempt" + assert events[0].payload["new_sub_run_id"] == self.NEW_SUB_RUN_ID diff --git a/providers/databricks/tests/unit/databricks/utils/test_databricks.py b/providers/databricks/tests/unit/databricks/utils/test_databricks.py index d09388b88495e..5de845b287844 100644 --- a/providers/databricks/tests/unit/databricks/utils/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/utils/test_databricks.py @@ -26,9 +26,11 @@ from airflow.providers.databricks.hooks.databricks import RunState from airflow.providers.databricks.utils.databricks import ( build_repair_run_json, + compute_repair_deadline, extract_failed_task_errors, extract_failed_task_errors_async, find_new_workflow_task_attempt, + is_repair_reflected, normalise_json_content, validate_trigger_event, ) @@ -133,6 +135,23 @@ def test_build_repair_run_json_includes_optional_fields_only_when_present(self): "overriding_parameters": {"notebook_params": {"date": "2024-01-01"}}, } + def test_compute_repair_deadline_anchors_to_terminal_end_time(self): + # end_time (epoch ms) in seconds plus the timeout, independent of when polling started. + assert compute_repair_deadline({"end_time": 1_700_000_000_000}, 180) == 1_700_000_000.0 + 180 + + def test_is_repair_reflected_detects_repair_id_in_history(self): + run_info = {"repair_history": [{"id": 111}, {"id": 222}]} + # The repair is reflected as soon as its repair_id is in repair_history, even while the run + # is still terminal — a fast repair never has to be caught mid-flight. 
+ assert is_repair_reflected(run_info, 222) is True + assert is_repair_reflected(run_info, 111) is True + + def test_is_repair_reflected_false_when_not_present_or_unknown(self): + # repair_id absent from history, history missing entirely, or no repair_id yet. + assert is_repair_reflected({"repair_history": [{"id": 111}]}, 999) is False + assert is_repair_reflected({}, 111) is False + assert is_repair_reflected({"repair_history": [{"id": 111}]}, None) is False + class TestExtractFailedTaskErrors: """Test cases for the extract_failed_task_errors utility function (synchronous version)""" From 45d91da5ca7a7b9e37a37d2ac0c0a7ac438dfea4 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 16 Jun 2026 09:33:00 -0500 Subject: [PATCH 14/15] Update find_new_workflow_task_attempt to require a start_time --- .../providers/databricks/utils/databricks.py | 13 ++++++++++--- .../unit/databricks/utils/test_databricks.py | 19 +++++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/providers/databricks/src/airflow/providers/databricks/utils/databricks.py b/providers/databricks/src/airflow/providers/databricks/utils/databricks.py index 92426136439fb..07fe1a9175511 100644 --- a/providers/databricks/src/airflow/providers/databricks/utils/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/utils/databricks.py @@ -114,17 +114,24 @@ def find_new_workflow_task_attempt( original_sub_run_id: int, original_start_time: int | None, ) -> dict[str, Any] | None: - """Return the newest task entry matching ``task_key`` that is not the original sub-run.""" + """ + Return the newest started attempt for ``task_key`` that is not the original sub-run. + + Requires a populated ``start_time`` so a not-yet-started attempt yields ``None`` (keep polling); + ``run_id`` exclusion is the hard guard, ``start_time`` ranks attempts and time-filters when + ``original_start_time`` is known. 
+ """ candidates = [ task for task in tasks if task.get("task_key") == task_key and task.get("run_id") != original_sub_run_id - and (original_start_time is None or task.get("start_time", 0) > original_start_time) + and task.get("start_time") + and (original_start_time is None or task["start_time"] > original_start_time) ] if not candidates: return None - return max(candidates, key=lambda task: task.get("start_time", 0)) + return max(candidates, key=lambda task: task["start_time"]) def compute_repair_deadline(run_info: dict[str, Any], workflow_repair_timeout: int) -> float: diff --git a/providers/databricks/tests/unit/databricks/utils/test_databricks.py b/providers/databricks/tests/unit/databricks/utils/test_databricks.py index 5de845b287844..5cdb9c8aa988e 100644 --- a/providers/databricks/tests/unit/databricks/utils/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/utils/test_databricks.py @@ -116,6 +116,25 @@ def test_find_new_workflow_task_attempt_picks_newest_matching_attempt(self): assert result == {"run_id": 202, "task_key": TASK_KEY_1, "start_time": 2500} + def test_find_new_workflow_task_attempt_skips_not_yet_started_attempt(self): + # A repaired attempt that has not started yet (missing/0/None start_time) is not selected; + # the caller keeps polling until it starts. 
+ tasks = [ + {"run_id": TASK_RUN_ID_1, "task_key": TASK_KEY_1, "start_time": 1000}, + {"run_id": 201, "task_key": TASK_KEY_1, "start_time": 0}, + {"run_id": 202, "task_key": TASK_KEY_1, "start_time": None}, + {"run_id": 203, "task_key": TASK_KEY_1}, + ] + + result = find_new_workflow_task_attempt( + tasks=tasks, + task_key=TASK_KEY_1, + original_sub_run_id=TASK_RUN_ID_1, + original_start_time=None, + ) + + assert result is None + def test_build_repair_run_json_includes_optional_fields_only_when_present(self): assert build_repair_run_json(run_id=RUN_ID, latest_repair_id=None) == { "run_id": RUN_ID, From b01e19cc630a7c161487425e1792fa6829b18349 Mon Sep 17 00:00:00 2001 From: Nick Date: Tue, 16 Jun 2026 10:45:44 -0500 Subject: [PATCH 15/15] fix test import error --- .../tests/unit/databricks/triggers/test_databricks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py index 2f97b69c64ccb..2dacb611c0f80 100644 --- a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py @@ -18,6 +18,7 @@ from __future__ import annotations import time +from typing import Any from unittest import mock import pytest @@ -671,7 +672,7 @@ def _run_payload( tasks: list[dict] | None = None, end_time: int | None = None, ) -> dict: - payload = { + payload: dict[str, Any] = { "run_page_url": RUN_PAGE_URL, "state": { "life_cycle_state": life_cycle_state,