diff --git a/providers/databricks/docs/operators/workflow.rst b/providers/databricks/docs/operators/workflow.rst index 42cb0b20f9666..3dec52d2740b0 100644 --- a/providers/databricks/docs/operators/workflow.rst +++ b/providers/databricks/docs/operators/workflow.rst @@ -68,3 +68,29 @@ To minimize update conflicts, we recommend that you keep parameters in the ``not ``DatabricksWorkflowTaskGroup`` and not in the ``DatabricksNotebookOperator`` whenever possible. This is because, tasks in the ``DatabricksWorkflowTaskGroup`` are passed in on the job trigger time and do not modify the job definition. + +Automatic repair (Airflow 3+) +----------------------------- + +Set ``workflow_repair_attempts=N`` to auto-repair a failed workflow run up to N times. The task +group injects a ``repair_coordinator`` sibling that waits for the run to terminate and then calls +Databricks ``repair_run`` with ``rerun_all_failed_tasks=True``. Default is ``0`` (off). + +.. code-block:: python + + task_group = DatabricksWorkflowTaskGroup( + group_id="Example Workflow", + databricks_conn_id="databricks_default", + workflow_repair_attempts=2, + workflow_repair_polling_period=15, + workflow_repair_timeout=300, + ) + +After ``repair_run`` is accepted, Databricks needs a moment to drop the parent run out of its +terminal state. The coordinator polls every ``workflow_repair_polling_period`` seconds and gives +Databricks up to ``workflow_repair_timeout`` seconds (default 180s / 3 minutes) to +reflect the repair before it fails the coordinator. Raise the timeout if your workspace is slow +to surface repaired runs. + +Downstream task monitors stay in the same Airflow attempt across repairs, so size +``execution_timeout`` for the original run plus all repairs and set ``retries=0`` on workflow tasks. 
diff --git a/providers/databricks/src/airflow/providers/databricks/exceptions.py b/providers/databricks/src/airflow/providers/databricks/exceptions.py index f384552a34a6e..85b519e83b948 100644 --- a/providers/databricks/src/airflow/providers/databricks/exceptions.py +++ b/providers/databricks/src/airflow/providers/databricks/exceptions.py @@ -30,3 +30,7 @@ class DatabricksSqlExecutionError(AirflowException): class DatabricksSqlExecutionTimeout(DatabricksSqlExecutionError): """Raised when a sql execution times out.""" + + +class DatabricksWorkflowRepairError(AirflowException): + """Raised when Databricks Workflow repair coordination fails.""" diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py index 7bba763455235..e83f05f8f681a 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py @@ -560,25 +560,31 @@ def get_run_tasks(self, run_id: int) -> list[dict[str, Any]]: return all_tasks - def get_run(self, run_id: int) -> dict[str, Any]: + def get_run(self, run_id: int, include_history: bool = False) -> dict[str, Any]: """ Retrieve run information. :param run_id: id of the run + :param include_history: whether to include the run's ``repair_history`` in the response. :return: state of the run """ - json = {"run_id": run_id} + json: dict[str, Any] = {"run_id": run_id} + if include_history: + json["include_history"] = "true" response = self._do_api_call(GET_RUN_ENDPOINT, json) return response - async def a_get_run(self, run_id: int) -> dict[str, Any]: + async def a_get_run(self, run_id: int, include_history: bool = False) -> dict[str, Any]: """ Async version of `get_run`. :param run_id: id of the run + :param include_history: whether to include the run's ``repair_history`` in the response. 
:return: state of the run """ - json = {"run_id": run_id} + json: dict[str, Any] = {"run_id": run_id} + if include_history: + json["include_history"] = "true" response = await self._a_do_api_call(GET_RUN_ENDPOINT, json) return response diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index 9898993d4147e..670677fd86fc7 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, BaseOperatorLink, XCom, conf +from airflow.providers.databricks.exceptions import DatabricksWorkflowRepairError from airflow.providers.databricks.hooks.databricks import ( DatabricksHook, RunLifeCycleState, @@ -43,9 +44,12 @@ ) from airflow.providers.databricks.triggers.databricks import ( DatabricksExecutionTrigger, + DatabricksWorkflowRepairWaitTrigger, ) from airflow.providers.databricks.utils.databricks import ( + compute_repair_deadline, extract_failed_task_errors, + find_new_workflow_task_attempt, normalise_json_content, validate_trigger_event, ) @@ -1491,19 +1495,64 @@ def monitor_databricks_job(self) -> None: self.databricks_task_key, run_state.life_cycle_state, ) - if self.deferrable and not run_state.is_terminal: - self.defer( - trigger=DatabricksExecutionTrigger( - run_id=current_task_run_id, - databricks_conn_id=self.databricks_conn_id, - polling_period_seconds=self.polling_period_seconds, - retry_limit=self.databricks_retry_limit, - retry_delay=self.databricks_retry_delay, - retry_args=self.databricks_retry_args, - caller=self.caller, - ), - method_name=DEFER_METHOD_NAME, - ) + if self.deferrable: + if not run_state.is_terminal: + self.defer( + trigger=DatabricksExecutionTrigger( + run_id=current_task_run_id, + 
databricks_conn_id=self.databricks_conn_id, + polling_period_seconds=self.polling_period_seconds, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + caller=self.caller, + ), + method_name=DEFER_METHOD_NAME, + ) + elif not run_state.is_successful: + tg = self._workflow_task_group_with_repair() + if tg is not None: + self._defer_to_workflow_repair_wait( + original_sub_run_id=current_task_run_id, + original_start_time=run.get("start_time"), + tg=tg, + ) + else: + tg = self._workflow_task_group_with_repair() + if tg is not None: + while True: + while not run_state.is_terminal: + time.sleep(self.polling_period_seconds) + run = self._hook.get_run(current_task_run_id) + run_state = RunState(**run["state"]) + self.log.info( + "Current state of the databricks task %s is %s", + self.databricks_task_key, + run_state.life_cycle_state, + ) + if run_state.is_successful: + break + new_sub_run_id = self._sync_wait_for_new_sub_run_attempt( + original_sub_run_id=current_task_run_id, + original_start_time=run.get("start_time"), + tg=tg, + ) + if new_sub_run_id is None: + break + self.log.info( + "Repair coordinator produced a new attempt for task_key %s (sub-run %s).", + self.databricks_task_key, + new_sub_run_id, + ) + current_task_run_id = new_sub_run_id + run = self._hook.get_run(current_task_run_id) + run_state = RunState(**run["state"]) + self.log.info( + "Current state of the databricks task %s is %s", + self.databricks_task_key, + run_state.life_cycle_state, + ) + while not run_state.is_terminal: time.sleep(self.polling_period_seconds) run = self._hook.get_run(current_task_run_id) @@ -1523,13 +1572,7 @@ def monitor_databricks_job(self) -> None: def execute(self, context: Context) -> None: """Execute the operator. 
Launch the job and monitor it if wait_for_termination is set to True.""" if self._databricks_workflow_task_group: - # If we are in a DatabricksWorkflowTaskGroup, we should have an upstream task launched. - if not self.workflow_run_metadata: - launch_task_id = next(task for task in self.upstream_task_ids if task.endswith(".launch")) - self.workflow_run_metadata = context["ti"].xcom_pull(task_ids=launch_task_id) - workflow_run_metadata = WorkflowRunMetadata(**self.workflow_run_metadata) - self.databricks_run_id = workflow_run_metadata.run_id - self.databricks_conn_id = workflow_run_metadata.conn_id + workflow_run_metadata = self._resolve_workflow_run_metadata(context) # Store operator links in XCom for Airflow 3 compatibility if AIRFLOW_V_3_0_PLUS: @@ -1544,11 +1587,158 @@ def execute(self, context: Context) -> None: if self.wait_for_termination: self.monitor_databricks_job() + def _resolve_workflow_run_metadata(self, context: Context | dict | None) -> WorkflowRunMetadata: + # Called from both execute() and execute_complete(): the deferrable resume path + # re-instantiates the operator, so attributes set in execute() are lost and we + # re-resolve from the templated workflow_run_metadata field. 
+ if not self.workflow_run_metadata and context is not None: + launch_task_id = next((task for task in self.upstream_task_ids if task.endswith(".launch")), None) + ti = context.get("ti") if isinstance(context, dict) else context["ti"] + if launch_task_id is not None and ti is not None: + self.workflow_run_metadata = ti.xcom_pull(task_ids=launch_task_id) + if not self.workflow_run_metadata: + raise ValueError("workflow_run_metadata is required to resolve the parent workflow run") + workflow_run_metadata = WorkflowRunMetadata(**self.workflow_run_metadata) + self.databricks_run_id = workflow_run_metadata.run_id + self.databricks_conn_id = workflow_run_metadata.conn_id + return workflow_run_metadata + def execute_complete(self, context: dict | None, event: dict) -> None: run_state = RunState.from_json(event["run_state"]) errors = event.get("errors", []) + + if not run_state.is_successful: + tg = self._workflow_task_group_with_repair() + if tg is not None: + self._resolve_workflow_run_metadata(context) + self._defer_to_workflow_repair_wait( + original_sub_run_id=event["run_id"], + original_start_time=event.get("run_start_time"), + tg=tg, + ) + self._handle_terminal_run_state(run_state, errors) + def _workflow_task_group_with_repair(self) -> DatabricksWorkflowTaskGroup | None: + if not AIRFLOW_V_3_0_PLUS: + return None + tg = self._databricks_workflow_task_group + if tg is None or getattr(tg, "workflow_repair_attempts", 0) <= 0: + return None + return tg + + def _sync_wait_for_new_sub_run_attempt( + self, + original_sub_run_id: int, + original_start_time: int | None, + tg: DatabricksWorkflowTaskGroup, + ) -> int | None: + """Sync mirror of DatabricksWorkflowRepairWaitTrigger. 
Returns the new sub-run id or None.""" + self.log.info( + "Sub-run %s for task_key %s reached terminal failure; waiting for a repair " + "attempt issued by the repair coordinator.", + original_sub_run_id, + self.databricks_task_key, + ) + polling_period_seconds = tg.workflow_repair_polling_period + workflow_repair_timeout = tg.workflow_repair_timeout + # Give-up deadline (epoch seconds), anchored to the run's end_time to match the + # coordinator. Reset when the run leaves terminal failure so a later failure restarts it. + repair_deadline: float | None = None + while True: + run_info = self._hook.get_run(self.databricks_run_id) # type: ignore[arg-type] + parent_run_state = RunState(**run_info["state"]) + tasks = run_info.get("tasks", []) + new_attempt = find_new_workflow_task_attempt( + tasks=tasks, + task_key=self.databricks_task_key, + original_sub_run_id=original_sub_run_id, + original_start_time=original_start_time, + ) + if new_attempt is not None: + return new_attempt["run_id"] + if parent_run_state.is_terminal and not parent_run_state.is_successful: + if repair_deadline is None: + repair_deadline = compute_repair_deadline(run_info, workflow_repair_timeout) + if time.time() >= repair_deadline: + self.log.info( + "Parent run %s reached terminal state %s without a new attempt for " + "task_key %s within %ss (anchored to the run's terminal end_time).", + self.databricks_run_id, + parent_run_state.result_state, + self.databricks_task_key, + workflow_repair_timeout, + ) + return None + else: + # Not in terminal failure — clear the deadline so a later failure restarts it. 
+ repair_deadline = None + time.sleep(polling_period_seconds) + + def _defer_to_workflow_repair_wait( + self, + original_sub_run_id: int, + original_start_time: int | None, + tg: DatabricksWorkflowTaskGroup, + ) -> None: + self.log.info( + "Sub-run %s for task_key %s reached terminal failure; deferring to wait for a repair " + "attempt issued by the repair coordinator.", + original_sub_run_id, + self.databricks_task_key, + ) + self.defer( + trigger=DatabricksWorkflowRepairWaitTrigger( + run_id=self.databricks_run_id, # type: ignore[arg-type] + databricks_conn_id=self.databricks_conn_id, + databricks_task_key=self.databricks_task_key, + original_sub_run_id=original_sub_run_id, + original_start_time=original_start_time, + polling_period_seconds=tg.workflow_repair_polling_period, + workflow_repair_timeout=tg.workflow_repair_timeout, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + caller=self.caller, + ), + method_name="execute_complete_after_repair_wait", + ) + + def execute_complete_after_repair_wait(self, context: dict | None, event: dict) -> None: + status = event.get("status") + if status == "new_attempt": + new_sub_run_id = event["new_sub_run_id"] + self.log.info( + "Repair coordinator produced a new attempt for task_key %s (sub-run %s); " + "deferring on a fresh DatabricksExecutionTrigger to monitor it.", + self.databricks_task_key, + new_sub_run_id, + ) + self.defer( + trigger=DatabricksExecutionTrigger( + run_id=new_sub_run_id, + databricks_conn_id=self.databricks_conn_id, + polling_period_seconds=self.polling_period_seconds, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + caller=self.caller, + ), + method_name=DEFER_METHOD_NAME, + ) + elif status == "parent_failed": + parent_state = RunState.from_json(event["parent_run_state"]) + raise DatabricksWorkflowRepairError( + f"Databricks workflow run 
{event.get('parent_run_id')} reached terminal failure " + f"({parent_state.result_state}) without producing a new attempt for task_key " + f"{self.databricks_task_key!r}; repair budget is exhausted or the coordinator " + f"did not issue a repair." + ) + else: + raise DatabricksWorkflowRepairError( + f"DatabricksWorkflowRepairWaitTrigger emitted unexpected status {status!r}: {event}" + ) + class DatabricksNotebookOperator(DatabricksTaskBaseOperator): """ diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py index 779c2fc9f1528..345351cfd5ca4 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py @@ -19,19 +19,30 @@ import json import time +import warnings from dataclasses import dataclass from functools import cached_property from typing import TYPE_CHECKING, Any from mergedeep import merge -from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, TaskGroup -from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState +from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, TaskGroup, conf +from airflow.providers.databricks.exceptions import DatabricksWorkflowRepairError +from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState, RunState from airflow.providers.databricks.plugins.databricks_workflow import ( WorkflowJobRepairAllFailedLink, WorkflowJobRunLink, store_databricks_job_run_link, ) +from airflow.providers.databricks.triggers.databricks import ( + DatabricksWorkflowRepairCoordinatorTrigger, +) +from airflow.providers.databricks.utils.databricks import ( + build_repair_run_json, + compute_repair_deadline, + extract_failed_task_errors, + is_repair_reflected, +) from 
airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: @@ -300,6 +311,231 @@ def on_kill(self) -> None: ) +class _DatabricksWorkflowRepairCoordinatorOperator(BaseOperator): + """Watch a Databricks Workflow run and issue repairs after terminal failures.""" + + caller = "_DatabricksWorkflowRepairCoordinatorOperator" + + def __init__( + self, + task_id: str, + databricks_conn_id: str, + launch_task_id: str, + workflow_repair_attempts: int, + workflow_repair_polling_period: int = 30, + workflow_repair_timeout: int = 180, + databricks_retry_limit: int = 3, + databricks_retry_delay: int = 10, + databricks_retry_args: dict[Any, Any] | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + if workflow_repair_attempts < 1: + raise ValueError( + f"workflow_repair_attempts must be >= 1 for the repair coordinator task, got {workflow_repair_attempts}" + ) + super().__init__(task_id=task_id, **kwargs) + self.databricks_conn_id = databricks_conn_id + self.launch_task_id = launch_task_id + self.workflow_repair_attempts = workflow_repair_attempts + self.workflow_repair_polling_period = workflow_repair_polling_period + self.workflow_repair_timeout = workflow_repair_timeout + self.databricks_retry_limit = databricks_retry_limit + self.databricks_retry_delay = databricks_retry_delay + self.databricks_retry_args = databricks_retry_args + self.deferrable = deferrable + + @cached_property + def _hook(self) -> DatabricksHook: + return DatabricksHook( + self.databricks_conn_id, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + caller=self.caller, + ) + + def _make_trigger( + self, + run_id: int, + repair_attempts: int, + latest_repair_id: int | None, + ) -> DatabricksWorkflowRepairCoordinatorTrigger: + return DatabricksWorkflowRepairCoordinatorTrigger( + run_id=run_id, + databricks_conn_id=self.databricks_conn_id, 
+ workflow_repair_attempts=self.workflow_repair_attempts, + repair_attempts=repair_attempts, + latest_repair_id=latest_repair_id, + polling_period_seconds=self.workflow_repair_polling_period, + workflow_repair_timeout=self.workflow_repair_timeout, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + caller=self.caller, + ) + + def execute(self, context: Context) -> Any: + launch_value = context["ti"].xcom_pull(task_ids=self.launch_task_id) + if not launch_value: + raise DatabricksWorkflowRepairError( + f"Launch task {self.launch_task_id!r} did not publish workflow run metadata; " + "cannot coordinate repairs." + ) + metadata = WorkflowRunMetadata(**launch_value) + + if self.deferrable: + self.defer( + trigger=self._make_trigger( + run_id=metadata.run_id, + repair_attempts=0, + latest_repair_id=None, + ), + method_name="execute_complete", + ) + return None + return self._run_sync(metadata.run_id) + + def _run_sync(self, run_id: int) -> dict[str, Any]: + """Sync equivalent of :class:`DatabricksWorkflowRepairCoordinatorTrigger`'s state machine.""" + repair_attempts = 0 + latest_repair_id: int | None = None + while True: + run_state = self._hook.get_run_state(run_id) + while not run_state.is_terminal: + self.log.info( + "Databricks run %s in state %s. 
Sleeping for %s seconds.", + run_id, + run_state, + self.workflow_repair_polling_period, + ) + time.sleep(self.workflow_repair_polling_period) + run_state = self._hook.get_run_state(run_id) + + if run_state.is_successful: + self.log.info( + "Databricks workflow run %s completed (repair_attempts=%s).", + run_id, + repair_attempts, + ) + return { + "run_id": run_id, + "repair_attempts": repair_attempts, + "latest_repair_id": latest_repair_id, + } + + run_info = self._hook.get_run(run_id) + errors = extract_failed_task_errors(self._hook, run_info, run_state) + + if repair_attempts >= self.workflow_repair_attempts: + raise DatabricksWorkflowRepairError( + f"Databricks workflow run {run_id} failed after {repair_attempts} repair " + f"attempt(s); repair budget exhausted (workflow_repair_attempts={self.workflow_repair_attempts}). " + f"Errors: {errors}" + ) + + self.log.info( + "Databricks run %s reached terminal failure state %s. Repairing all failed " + "tasks (attempt %s of %s, latest_repair_id=%s).", + run_id, + run_state.result_state, + repair_attempts + 1, + self.workflow_repair_attempts, + latest_repair_id, + ) + repair_json = build_repair_run_json( + run_id=run_id, + latest_repair_id=latest_repair_id, + overriding_parameters=run_info.get("overriding_parameters"), + ) + latest_repair_id = self._hook.repair_run(repair_json) + repair_attempts += 1 + self.log.info( + "Databricks repair_run accepted for run %s; new repair_id=%s.", + run_id, + latest_repair_id, + ) + + # Wait for the repair to be reflected before looping, else a stale terminal state + # triggers a second repair_run. Reflection is confirmed by the repair_id appearing in + # repair_history (monotonic, so a fast repair can't slip past polling) or by the run + # leaving its terminal state. The deadline shares the run's end_time anchor so + # coordinator and waiters give up together. 
+ deadline = compute_repair_deadline(run_info, self.workflow_repair_timeout) + while True: + time.sleep(self.workflow_repair_polling_period) + post_repair_info = self._hook.get_run(run_id, include_history=True) + post_repair_state = RunState(**post_repair_info["state"]) + if ( + is_repair_reflected(post_repair_info, latest_repair_id) + or not post_repair_state.is_terminal + ): + break + if time.time() >= deadline: + raise DatabricksWorkflowRepairError( + f"Databricks did not reflect repair_id={latest_repair_id} for run {run_id} " + f"within {self.workflow_repair_timeout}s " + f"(workflow_repair_timeout); aborting to avoid issuing a " + f"duplicate repair_run against stale terminal state." + ) + + def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: + status = event.get("status") + run_id = event["run_id"] + repair_attempts = event["repair_attempts"] + latest_repair_id = event.get("latest_repair_id") + + if status == "completed": + self.log.info( + "Databricks workflow run %s completed (repair_attempts=%s).", + run_id, + repair_attempts, + ) + return { + "run_id": run_id, + "repair_attempts": repair_attempts, + "latest_repair_id": latest_repair_id, + } + + if status == "repaired": + self.log.info( + "Databricks workflow run %s repaired (repair_attempts=%s, latest_repair_id=%s); " + "re-deferring to monitor the repaired run.", + run_id, + repair_attempts, + latest_repair_id, + ) + self.defer( + trigger=self._make_trigger( + run_id=run_id, + repair_attempts=repair_attempts, + latest_repair_id=latest_repair_id, + ), + method_name="execute_complete", + ) + return None + + if status == "failed": + errors = event.get("errors", []) + raise DatabricksWorkflowRepairError( + f"Databricks workflow run {run_id} failed after {repair_attempts} repair " + f"attempt(s); repair budget exhausted (workflow_repair_attempts={self.workflow_repair_attempts}). 
" + f"Errors: {errors}" + ) + + if status == "repair_not_reflected": + raise DatabricksWorkflowRepairError( + f"Databricks did not reflect repair_id={latest_repair_id} for run {run_id} " + f"within {self.workflow_repair_timeout}s " + f"(workflow_repair_timeout); aborting to avoid issuing a " + f"duplicate repair_run against stale terminal state." + ) + + raise DatabricksWorkflowRepairError( + f"DatabricksWorkflowRepairCoordinatorTrigger emitted unexpected status {status!r}: {event}" + ) + + class DatabricksWorkflowTaskGroup(TaskGroup): """ A task group that takes a list of tasks and creates a databricks workflow. @@ -338,6 +574,17 @@ class DatabricksWorkflowTaskGroup(TaskGroup): all python tasks in the workflow. :param spark_submit_params: A list of spark submit parameters to pass to the workflow. These parameters will be passed to all spark submit tasks. + :param workflow_repair_attempts: Maximum number of automatic ``rerun_all_failed_tasks`` repair attempts to + issue against the Databricks run when downstream tasks fail. Each repair reuses the + original job cluster. Set to ``0`` to disable auto-repair (current behavior). Only takes + effect on Airflow 3+; ignored on Airflow 2.x. Defaults to ``0``. + :param workflow_repair_polling_period: How often the repair coordinator polls the + Databricks run state. Only used when ``workflow_repair_attempts > 0``. + :param workflow_repair_timeout: How long Databricks may take to reflect a repair, as a + wall-clock window anchored to the parent run's terminal ``end_time``. The coordinator and + the downstream waiters share this value and anchor, so they give up at the same instant and + a downstream task is never failed while a repair could still land. Defaults to 180 seconds. + Only used when ``workflow_repair_attempts > 0``. 
""" is_databricks = True @@ -355,8 +602,13 @@ def __init__( notebook_params: dict | None = None, python_params: list | None = None, spark_submit_params: list | None = None, + workflow_repair_attempts: int = 0, + workflow_repair_polling_period: int = 30, + workflow_repair_timeout: int = 180, **kwargs, ): + if workflow_repair_attempts < 0: + raise ValueError(f"workflow_repair_attempts must be >= 0, got {workflow_repair_attempts}") self.databricks_conn_id = databricks_conn_id self.access_control_list = access_control_list self.existing_clusters = existing_clusters or [] @@ -368,6 +620,9 @@ def __init__( self.notebook_params = notebook_params or {} self.python_params = python_params or [] self.spark_submit_params = spark_submit_params or [] + self.workflow_repair_attempts = workflow_repair_attempts + self.workflow_repair_polling_period = workflow_repair_polling_period + self.workflow_repair_timeout = workflow_repair_timeout super().__init__(**kwargs) def __exit__( @@ -403,11 +658,42 @@ def __exit__( f"Task {task.task_id} does not support conversion to databricks workflow task." ) + if AIRFLOW_V_3_0_PLUS and self.workflow_repair_attempts > 0: + task_retries = getattr(task, "retries", 0) + if isinstance(task_retries, int) and task_retries > 0: + warnings.warn( + f"Task {task.task_id!r} in DatabricksWorkflowTaskGroup " + f"{self.group_id!r} has retries={task_retries} while " + f"workflow_repair_attempts={self.workflow_repair_attempts}. Databricks-side repair supersedes " + "task-level retries for sub-run failures: a failed sub-run defers the " + "monitor on the repair-wait trigger and resumes in the same Airflow " + "attempt. Retries on the monitor will not trigger additional repairs " + "and only add cost on non-repair-related transient failures. 
Consider " + "setting retries=0 on workflow tasks when workflow_repair_attempts > 0.", + UserWarning, + stacklevel=2, + ) task.workflow_run_metadata = create_databricks_workflow_task.output create_databricks_workflow_task.relevant_upstreams.append(task.task_id) create_databricks_workflow_task.add_task(task.task_id, task) for root_task in roots: root_task.set_upstream(create_databricks_workflow_task) + + if AIRFLOW_V_3_0_PLUS and self.workflow_repair_attempts > 0: + repair_coordinator_task = _DatabricksWorkflowRepairCoordinatorOperator( + dag=self.dag, + task_group=self, + task_id="repair_coordinator", + databricks_conn_id=self.databricks_conn_id, + launch_task_id=create_databricks_workflow_task.task_id, + workflow_repair_attempts=self.workflow_repair_attempts, + workflow_repair_polling_period=self.workflow_repair_polling_period, + workflow_repair_timeout=self.workflow_repair_timeout, + # Retrying the coordinator would re-enter execute() with repair_attempts=0 + # and start the budget over. 
+ retries=0, + ) + repair_coordinator_task.set_upstream(create_databricks_workflow_task) finally: super().__exit__(_type, _value, _tb) diff --git a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py index 2bb626b911418..91855a8fadaf2 100644 --- a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py @@ -19,10 +19,17 @@ import asyncio import time +from datetime import datetime, timezone from typing import Any -from airflow.providers.databricks.hooks.databricks import DatabricksHook -from airflow.providers.databricks.utils.databricks import extract_failed_task_errors_async +from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState +from airflow.providers.databricks.utils.databricks import ( + build_repair_run_json, + compute_repair_deadline, + extract_failed_task_errors_async, + find_new_workflow_task_attempt, + is_repair_reflected, +) from airflow.providers.databricks.utils.retry import validate_deferrable_databricks_retry_args from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -121,6 +128,7 @@ async def run(self): "run_id": self.run_id, "run_page_url": self.run_page_url, "run_state": run_state.to_json(), + "run_start_time": run_info.get("start_time"), "repair_run": self.repair_run, "errors": failed_tasks, } @@ -128,6 +136,360 @@ async def run(self): return +class DatabricksWorkflowRepairCoordinatorTrigger(BaseTrigger): + """ + Coordinate parent-run polling and repairs for a Databricks Workflow run. + + :param run_id: The Databricks run id to coordinate. + :param databricks_conn_id: Airflow connection id for the Databricks hook. + :param workflow_repair_attempts: Total repair attempts allowed for this run. + :param repair_attempts: Repair attempts already performed. 
    def __init__(
        self,
        run_id: int,
        databricks_conn_id: str,
        workflow_repair_attempts: int,
        repair_attempts: int = 0,
        latest_repair_id: int | None = None,
        polling_period_seconds: int = 30,
        workflow_repair_timeout: int = 180,
        retry_limit: int = 3,
        retry_delay: int = 10,
        retry_args: dict[Any, Any] | None = None,
        run_page_url: str | None = None,
        caller: str = "DatabricksWorkflowRepairCoordinatorTrigger",
    ) -> None:
        """
        Store the coordinator configuration and build the Databricks hook.

        Every argument is kept as an attribute of the same name; ``serialize`` emits exactly
        these attributes so the trigger can be rebuilt from its serialized form.
        """
        super().__init__()
        self.run_id = run_id
        self.databricks_conn_id = databricks_conn_id
        self.workflow_repair_attempts = workflow_repair_attempts
        self.repair_attempts = repair_attempts
        self.latest_repair_id = latest_repair_id
        self.polling_period_seconds = polling_period_seconds
        self.workflow_repair_timeout = workflow_repair_timeout
        self.retry_limit = retry_limit
        self.retry_delay = retry_delay
        self.retry_args = retry_args
        self.run_page_url = run_page_url
        self.caller = caller
        self.hook = DatabricksHook(
            databricks_conn_id,
            retry_limit=self.retry_limit,
            retry_delay=self.retry_delay,
            retry_args=retry_args,
            caller=caller,
        )

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Return the trigger's import path and the keyword arguments needed to rebuild it."""
        return (
            "airflow.providers.databricks.triggers.databricks.DatabricksWorkflowRepairCoordinatorTrigger",
            {
                "run_id": self.run_id,
                "databricks_conn_id": self.databricks_conn_id,
                "workflow_repair_attempts": self.workflow_repair_attempts,
                "repair_attempts": self.repair_attempts,
                "latest_repair_id": self.latest_repair_id,
                "polling_period_seconds": self.polling_period_seconds,
                "workflow_repair_timeout": self.workflow_repair_timeout,
                "retry_limit": self.retry_limit,
                "retry_delay": self.retry_delay,
                "retry_args": self.retry_args,
                "run_page_url": self.run_page_url,
                "caller": self.caller,
            },
        )

    async def on_kill(self) -> None:
        """Cancel the Databricks run when the trigger is cancelled by a user action."""
        if self.run_id:
            # Imported lazily; ``cancel_run`` is a synchronous hook call, so it is wrapped
            # with ``sync_to_async`` before being awaited.
            from asgiref.sync import sync_to_async

            self.log.info("Cancelling Databricks run %s.", self.run_id)
            await sync_to_async(self.hook.cancel_run)(self.run_id)

    async def run(self):
        """
        Poll the run until it is terminal, then report the outcome or issue one repair.

        Exactly one ``TriggerEvent`` is yielded per invocation, with ``status`` set to:

        * ``"completed"`` - the run finished successfully.
        * ``"failed"`` - the run failed and ``repair_attempts`` has reached
          ``workflow_repair_attempts``.
        * ``"repaired"`` - ``repair_run`` was called and the repair became visible (its id
          appeared in the repair history, or the run left its terminal state).
        * ``"repair_not_reflected"`` - ``repair_run`` was called but the repair was not
          visible by the deadline.

        The two post-repair events carry ``repair_attempts + 1`` and the new repair id.
        """
        from asgiref.sync import sync_to_async

        async with self.hook:
            while True:
                run_state = await self.hook.a_get_run_state(self.run_id)
                if not run_state.is_terminal:
                    self.log.info(
                        "run-id %s in run state %s. sleeping for %s seconds",
                        self.run_id,
                        run_state,
                        self.polling_period_seconds,
                    )
                    await asyncio.sleep(self.polling_period_seconds)
                    continue

                # Terminal from here on: fetch the full run once; it feeds the error
                # extraction, the repair payload and the repair deadline below.
                run_info = await self.hook.a_get_run(self.run_id)
                errors = await extract_failed_task_errors_async(self.hook, run_info, run_state)

                if run_state.is_successful:
                    self.log.info("Databricks run %s completed successfully.", self.run_id)
                    yield TriggerEvent(
                        {
                            "status": "completed",
                            "run_id": self.run_id,
                            "run_page_url": self.run_page_url,
                            "run_state": run_state.to_json(),
                            "repair_attempts": self.repair_attempts,
                            "latest_repair_id": self.latest_repair_id,
                            "errors": errors,
                        }
                    )
                    return

                # Terminal and not successful: stop once the repair budget is used up.
                if self.repair_attempts >= self.workflow_repair_attempts:
                    self.log.info(
                        "Databricks run %s reached terminal failure state %s and repair budget "
                        "is exhausted (workflow_repair_attempts=%s).",
                        self.run_id,
                        run_state.result_state,
                        self.workflow_repair_attempts,
                    )
                    yield TriggerEvent(
                        {
                            "status": "failed",
                            "run_id": self.run_id,
                            "run_page_url": self.run_page_url,
                            "run_state": run_state.to_json(),
                            "repair_attempts": self.repair_attempts,
                            "latest_repair_id": self.latest_repair_id,
                            "errors": errors,
                        }
                    )
                    return

                self.log.info(
                    "Databricks run %s reached terminal failure state %s. Repairing all failed "
                    "tasks (attempt %s of %s, latest_repair_id=%s).",
                    self.run_id,
                    run_state.result_state,
                    self.repair_attempts + 1,
                    self.workflow_repair_attempts,
                    self.latest_repair_id,
                )

                repair_json = build_repair_run_json(
                    run_id=self.run_id,
                    latest_repair_id=self.latest_repair_id,
                    overriding_parameters=run_info.get("overriding_parameters"),
                )

                # ``repair_run`` is a synchronous hook method, hence the ``sync_to_async`` wrapper.
                new_repair_id = await sync_to_async(self.hook.repair_run)(repair_json)
                self.log.info(
                    "Databricks repair_run accepted for run %s; new repair_id=%s.",
                    self.run_id,
                    new_repair_id,
                )

                # Wait for repair to be reflected via repair_id in history or a run leaving terminal state.
                # Deadline anchored to run's end_time so coordinator and waiters give up together.
                deadline = compute_repair_deadline(run_info, self.workflow_repair_timeout)
                self.log.info(
                    "Waiting up to %ss for run %s to reflect repair_id=%s; "
                    "giving up at %s (deadline anchored to the run's terminal end_time).",
                    self.workflow_repair_timeout,
                    self.run_id,
                    new_repair_id,
                    datetime.fromtimestamp(deadline, tz=timezone.utc).isoformat(),
                )
                while True:
                    # Sleep first, then poll: at least one poll happens after the repair call
                    # even when the deadline is already in the past.
                    await asyncio.sleep(self.polling_period_seconds)
                    post_repair_info = await self.hook.a_get_run(self.run_id, include_history=True)
                    post_repair_state = RunState(**post_repair_info["state"])
                    if (
                        is_repair_reflected(post_repair_info, new_repair_id)
                        or not post_repair_state.is_terminal
                    ):
                        break
                    if time.time() >= deadline:
                        # ``run_state`` and ``errors`` below are the values captured before
                        # the repair was issued; they are not refreshed by this loop.
                        yield TriggerEvent(
                            {
                                "status": "repair_not_reflected",
                                "run_id": self.run_id,
                                "run_page_url": self.run_page_url,
                                "run_state": run_state.to_json(),
                                "repair_attempts": self.repair_attempts + 1,
                                "latest_repair_id": new_repair_id,
                                "errors": errors,
                            }
                        )
                        return

                # The repair is visible. ``run_state`` and ``errors`` are still the
                # pre-repair values.
                yield TriggerEvent(
                    {
                        "status": "repaired",
                        "run_id": self.run_id,
                        "run_page_url": self.run_page_url,
                        "run_state": run_state.to_json(),
                        "repair_attempts": self.repair_attempts + 1,
                        "latest_repair_id": new_repair_id,
                        "errors": errors,
                    }
                )
                return
class DatabricksWorkflowRepairWaitTrigger(BaseTrigger):
    """
    Wait for the next attempt of a Databricks Workflow task after its sub-run fails.

    :param run_id: Parent workflow run id (stable across repairs).
    :param databricks_conn_id: Airflow connection id for the Databricks hook.
    :param databricks_task_key: The ``task_key`` of the Databricks task to watch for a new attempt.
    :param original_sub_run_id: The sub-run id of the attempt that just failed; the trigger only
        yields ``new_attempt`` for a sub-run id different from this one.
    :param original_start_time: ``start_time`` of the failed attempt, if known. It is forwarded
        to ``find_new_workflow_task_attempt`` together with the task list.
    :param polling_period_seconds: How often to poll the parent run.
    :param workflow_repair_timeout: How long Databricks may take to reflect a repair, as a
        wall-clock window anchored to the parent run's terminal ``end_time``. Must match the
        coordinator's value; the waiter never declares ``parent_failed`` before that shared deadline.
    :param retry_limit: Hook retry limit for transient Databricks API failures.
    :param retry_delay: Hook retry delay (seconds).
    :param retry_args: Optional tenacity ``Retrying`` kwargs forwarded to the hook.
    :param run_page_url: The Databricks UI URL for the parent run, surfaced in events for logging.
    :param caller: Caller label forwarded to the hook for diagnostics.
    """

    def __init__(
        self,
        run_id: int,
        databricks_conn_id: str,
        databricks_task_key: str,
        original_sub_run_id: int,
        original_start_time: int | None = None,
        polling_period_seconds: int = 30,
        workflow_repair_timeout: int = 180,
        retry_limit: int = 3,
        retry_delay: int = 10,
        retry_args: dict[Any, Any] | None = None,
        run_page_url: str | None = None,
        caller: str = "DatabricksWorkflowRepairWaitTrigger",
    ) -> None:
        super().__init__()
        # Each argument is stored under its own name; ``serialize`` emits these attributes.
        self.run_id = run_id
        self.databricks_conn_id = databricks_conn_id
        self.databricks_task_key = databricks_task_key
        self.original_sub_run_id = original_sub_run_id
        self.original_start_time = original_start_time
        self.polling_period_seconds = polling_period_seconds
        self.workflow_repair_timeout = workflow_repair_timeout
        self.retry_limit = retry_limit
        self.retry_delay = retry_delay
        self.retry_args = retry_args
        self.run_page_url = run_page_url
        self.caller = caller
        self.hook = DatabricksHook(
            databricks_conn_id,
            retry_limit=self.retry_limit,
            retry_delay=self.retry_delay,
            retry_args=retry_args,
            caller=caller,
        )

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Return the trigger's import path and the keyword arguments needed to rebuild it."""
        return (
            "airflow.providers.databricks.triggers.databricks.DatabricksWorkflowRepairWaitTrigger",
            {
                "run_id": self.run_id,
                "databricks_conn_id": self.databricks_conn_id,
                "databricks_task_key": self.databricks_task_key,
                "original_sub_run_id": self.original_sub_run_id,
                "original_start_time": self.original_start_time,
                "polling_period_seconds": self.polling_period_seconds,
                "workflow_repair_timeout": self.workflow_repair_timeout,
                "retry_limit": self.retry_limit,
                "retry_delay": self.retry_delay,
                "retry_args": self.retry_args,
                "run_page_url": self.run_page_url,
                "caller": self.caller,
            },
        )

    def _find_new_attempt(self, tasks: list[dict[str, Any]]) -> dict[str, Any] | None:
        """Return the new attempt for the watched task key found in ``tasks``, or ``None``."""
        return find_new_workflow_task_attempt(
            tasks=tasks,
            task_key=self.databricks_task_key,
            original_sub_run_id=self.original_sub_run_id,
            original_start_time=self.original_start_time,
        )

    async def run(self):
        """
        Poll the parent run until a new attempt of the watched task appears or the wait expires.

        Yields one ``TriggerEvent`` with ``status`` ``"new_attempt"`` (carrying
        ``new_sub_run_id``) or ``"parent_failed"`` (carrying ``parent_run_state``).
        """
        # Give-up deadline (epoch seconds), anchored to the run's end_time to match the
        # coordinator. Reset when the run leaves terminal failure so a later failure restarts it.
        repair_deadline: float | None = None
        async with self.hook:
            while True:
                run_info = await self.hook.a_get_run(self.run_id)
                run_state = RunState(**run_info["state"])
                tasks = run_info.get("tasks", [])

                # A new attempt is checked first, so it takes precedence over the parent's state.
                new_attempt = self._find_new_attempt(tasks)
                if new_attempt is not None:
                    self.log.info(
                        "Databricks workflow run %s produced a new attempt for task_key %s "
                        "(new sub-run id %s).",
                        self.run_id,
                        self.databricks_task_key,
                        new_attempt["run_id"],
                    )
                    yield TriggerEvent(
                        {
                            "status": "new_attempt",
                            "parent_run_id": self.run_id,
                            "databricks_task_key": self.databricks_task_key,
                            "new_sub_run_id": new_attempt["run_id"],
                            "run_page_url": self.run_page_url,
                        }
                    )
                    return

                if run_state.is_terminal and not run_state.is_successful:
                    # NOTE(review): the source was whitespace-collapsed; the log call is assumed
                    # to sit inside this ``if`` (logged once per deadline) - confirm upstream.
                    if repair_deadline is None:
                        repair_deadline = compute_repair_deadline(run_info, self.workflow_repair_timeout)
                        self.log.info(
                            "Databricks workflow run %s is in terminal failure state %s with no new "
                            "attempt for task_key %s; waiting up to %ss for the coordinator's repair to "
                            "be reflected, giving up at %s (deadline anchored to the run's terminal "
                            "end_time).",
                            self.run_id,
                            run_state.result_state,
                            self.databricks_task_key,
                            self.workflow_repair_timeout,
                            datetime.fromtimestamp(repair_deadline, tz=timezone.utc).isoformat(),
                        )
                    if time.time() >= repair_deadline:
                        yield TriggerEvent(
                            {
                                "status": "parent_failed",
                                "parent_run_id": self.run_id,
                                "databricks_task_key": self.databricks_task_key,
                                "parent_run_state": run_state.to_json(),
                                "run_page_url": self.run_page_url,
                            }
                        )
                        return
                else:
                    # Not in terminal failure — clear the deadline so a later failure restarts it.
                    repair_deadline = None

                await asyncio.sleep(self.polling_period_seconds)
up at %s (deadline anchored to the run's terminal " + "end_time).", + self.run_id, + run_state.result_state, + self.databricks_task_key, + self.workflow_repair_timeout, + datetime.fromtimestamp(repair_deadline, tz=timezone.utc).isoformat(), + ) + if time.time() >= repair_deadline: + yield TriggerEvent( + { + "status": "parent_failed", + "parent_run_id": self.run_id, + "databricks_task_key": self.databricks_task_key, + "parent_run_state": run_state.to_json(), + "run_page_url": self.run_page_url, + } + ) + return + else: + # Not in terminal failure — clear the deadline so a later failure restarts it. + repair_deadline = None + + await asyncio.sleep(self.polling_period_seconds) + + class DatabricksSQLStatementExecutionTrigger(BaseTrigger): """ The trigger handles the logic of async communication with DataBricks SQL Statements API. diff --git a/providers/databricks/src/airflow/providers/databricks/utils/databricks.py b/providers/databricks/src/airflow/providers/databricks/utils/databricks.py index 05b6b17710ef8..07fe1a9175511 100644 --- a/providers/databricks/src/airflow/providers/databricks/utils/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/utils/databricks.py @@ -17,6 +17,9 @@ # under the License. from __future__ import annotations +import time +from typing import Any + from airflow.providers.common.compat.sdk import AirflowException, XComArg from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState @@ -105,6 +108,69 @@ async def extract_failed_task_errors_async( return failed_tasks +def find_new_workflow_task_attempt( + tasks: list[dict[str, Any]], + task_key: str, + original_sub_run_id: int, + original_start_time: int | None, +) -> dict[str, Any] | None: + """ + Return the newest started attempt for ``task_key`` that is not the original sub-run. 
+ + Requires a populated ``start_time`` so a not-yet-started attempt yields ``None`` (keep polling); + ``run_id`` exclusion is the hard guard, ``start_time`` ranks attempts and time-filters when + ``original_start_time`` is known. + """ + candidates = [ + task + for task in tasks + if task.get("task_key") == task_key + and task.get("run_id") != original_sub_run_id + and task.get("start_time") + and (original_start_time is None or task["start_time"] > original_start_time) + ] + if not candidates: + return None + return max(candidates, key=lambda task: task["start_time"]) + + +def compute_repair_deadline(run_info: dict[str, Any], workflow_repair_timeout: int) -> float: + """ + Return the wall-clock deadline (epoch seconds) for a repair to be reflected. + + Anchored to the run's terminal ``end_time`` (epoch ms) so all waiters converge on the same + give-up instant. Falls back to now if ``end_time`` is not populated. + """ + end_time_ms = run_info.get("end_time") or 0 + anchor = end_time_ms / 1000 if end_time_ms else time.time() + return anchor + workflow_repair_timeout + + +def is_repair_reflected(run_info: dict[str, Any], repair_id: int | None) -> bool: + """Return ``True`` once ``repair_id`` appears in the run's ``repair_history``.""" + if repair_id is None: + return False + return any(entry.get("id") == repair_id for entry in run_info.get("repair_history", [])) + + +def build_repair_run_json( + run_id: int, + latest_repair_id: int | None, + overriding_parameters: Any = None, +) -> dict[str, Any]: + """Build the ``DatabricksHook.repair_run`` payload for ``rerun_all_failed_tasks`` repair.""" + repair_json: dict[str, Any] = { + "run_id": run_id, + "rerun_all_failed_tasks": True, + "rerun_dependent_tasks": True, + } + if latest_repair_id is not None: + repair_json["latest_repair_id"] = latest_repair_id + if overriding_parameters is not None: + repair_json["overriding_parameters"] = overriding_parameters + return repair_json + + def validate_trigger_event(event: dict): 
@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Workflow repair flow is Airflow 3+ only")
class TestExecuteCompleteWorkflowRepair:
    """Tests for how ``DatabricksNotebookOperator`` reacts to a failed sub-run when repair is enabled."""

    PARENT_RUN_ID = 100
    ORIGINAL_SUB_RUN_ID = 500
    NEW_SUB_RUN_ID = 700
    # Parent run terminal end_time (epoch ms) the sync waiter anchors its give-up deadline to.
    PARENT_END_TIME_MS = 1_700_000_000_000
    ANCHOR_S = PARENT_END_TIME_MS / 1000

    @staticmethod
    def _terminal_failure_event(sub_run_id: int) -> dict[str, Any]:
        """Build a trigger event payload describing a TERMINATED/FAILED sub-run."""
        return {
            "run_id": sub_run_id,
            "run_page_url": "https://example",
            "run_state": RunState(
                life_cycle_state="TERMINATED",
                state_message="boom",
                result_state="FAILED",
            ).to_json(),
            "errors": [{"task_key": "tk", "run_id": sub_run_id, "error": "boom"}],
        }

    def _operator_with_workflow_tg(self, workflow_repair_attempts: int) -> DatabricksNotebookOperator:
        """Return a notebook operator attached to a mocked Databricks workflow task group."""
        # Pass workflow_run_metadata (templated field) rather than setting
        # databricks_run_id directly. Airflow re-instantiates the operator at deferrable
        # resume time, so execute_complete must rederive the parent run id from the
        # rendered template, not from any attribute set during execute().
        operator = DatabricksNotebookOperator(
            task_id="task1",
            notebook_path="path",
            source="WORKSPACE",
            databricks_conn_id="databricks_default",
            workflow_run_metadata={
                "conn_id": "databricks_default",
                "job_id": 1,
                "run_id": self.PARENT_RUN_ID,
            },
        )

        # _databricks_workflow_task_group walks up `task_group` looking for is_databricks=True.
        # Hand it a single-hop chain so it finds our mocked group directly.
        tg = MagicMock()
        tg.is_databricks = True
        tg.workflow_repair_attempts = workflow_repair_attempts
        tg.workflow_repair_polling_period = 15
        tg.workflow_repair_timeout = 120
        tg.task_group = None
        operator.task_group = tg
        return operator

    def test_execute_complete_failure_defers_on_wait_trigger_when_workflow_repair_attempts_set(self):
        operator = self._operator_with_workflow_tg(workflow_repair_attempts=2)

        with pytest.raises(TaskDeferred) as exc:
            operator.execute_complete(
                context=None,
                event=self._terminal_failure_event(self.ORIGINAL_SUB_RUN_ID),
            )

        # The wait trigger must inherit the polling period and timeout from the task group.
        trigger = exc.value.trigger
        assert isinstance(trigger, DatabricksWorkflowRepairWaitTrigger)
        assert trigger.run_id == self.PARENT_RUN_ID
        assert trigger.databricks_task_key == operator.databricks_task_key
        assert trigger.original_sub_run_id == self.ORIGINAL_SUB_RUN_ID
        assert trigger.polling_period_seconds == 15
        assert trigger.workflow_repair_timeout == 120
        assert exc.value.method_name == "execute_complete_after_repair_wait"

    def test_execute_complete_failure_pulls_workflow_metadata_from_xcom(self):
        # Mirrors the deferrable-resume path where execute() never ran (so
        # workflow_run_metadata was not pre-populated): the resolver must fall back
        # to xcom_pull from the upstream .launch task to recover the parent run id.
        operator = DatabricksNotebookOperator(
            task_id="task1",
            notebook_path="path",
            source="WORKSPACE",
            databricks_conn_id="databricks_default",
        )
        tg = MagicMock()
        tg.is_databricks = True
        tg.workflow_repair_attempts = 2
        tg.workflow_repair_polling_period = 15
        tg.workflow_repair_timeout = 120
        tg.task_group = None
        operator.task_group = tg
        operator.upstream_task_ids = {"workflow.launch"}

        ti = MagicMock()
        ti.xcom_pull.return_value = {
            "conn_id": "databricks_default",
            "job_id": 1,
            "run_id": self.PARENT_RUN_ID,
        }
        with pytest.raises(TaskDeferred) as exc:
            operator.execute_complete(
                context={"ti": ti},
                event=self._terminal_failure_event(self.ORIGINAL_SUB_RUN_ID),
            )

        ti.xcom_pull.assert_called_once_with(task_ids="workflow.launch")
        trigger = exc.value.trigger
        assert isinstance(trigger, DatabricksWorkflowRepairWaitTrigger)
        assert trigger.run_id == self.PARENT_RUN_ID

    def test_execute_complete_after_repair_wait_new_attempt_defers_on_execution_trigger(self):
        operator = self._operator_with_workflow_tg(workflow_repair_attempts=2)

        with pytest.raises(TaskDeferred) as exc:
            operator.execute_complete_after_repair_wait(
                context=None,
                event={
                    "status": "new_attempt",
                    "parent_run_id": self.PARENT_RUN_ID,
                    "databricks_task_key": operator.databricks_task_key,
                    "new_sub_run_id": self.NEW_SUB_RUN_ID,
                    "run_page_url": "https://example",
                },
            )

        # The operator goes back to monitoring, now pointed at the new sub-run.
        trigger = exc.value.trigger
        assert isinstance(trigger, DatabricksExecutionTrigger)
        assert trigger.run_id == self.NEW_SUB_RUN_ID
        assert exc.value.method_name == "execute_complete"

    def test_execute_complete_after_repair_wait_parent_failed_raises(self):
        operator = self._operator_with_workflow_tg(workflow_repair_attempts=2)

        with pytest.raises(AirflowException, match="repair budget is exhausted"):
            operator.execute_complete_after_repair_wait(
                context=None,
                event={
                    "status": "parent_failed",
                    "parent_run_id": self.PARENT_RUN_ID,
                    "databricks_task_key": operator.databricks_task_key,
                    "parent_run_state": RunState(
                        life_cycle_state="TERMINATED",
                        state_message=None,
                        result_state="FAILED",
                    ).to_json(),
                    "run_page_url": "https://example",
                },
            )

    @mock.patch("airflow.providers.databricks.operators.databricks.time.sleep")
    def test_sync_wait_for_new_sub_run_attempt_returns_new_attempt(self, mock_sleep):
        operator = self._operator_with_workflow_tg(workflow_repair_attempts=2)
        hook = MagicMock()
        operator.__dict__["_hook"] = hook
        operator.databricks_run_id = self.PARENT_RUN_ID
        # First poll: only the original attempt. Second poll: the new attempt has appeared.
        hook.get_run.side_effect = [
            {
                "state": {"life_cycle_state": "RUNNING", "result_state": None, "state_message": None},
                "tasks": [
                    {
                        "run_id": self.ORIGINAL_SUB_RUN_ID,
                        "task_key": operator.databricks_task_key,
                        "start_time": 1000,
                    },
                ],
            },
            {
                "state": {"life_cycle_state": "RUNNING", "result_state": None, "state_message": None},
                "tasks": [
                    {
                        "run_id": self.ORIGINAL_SUB_RUN_ID,
                        "task_key": operator.databricks_task_key,
                        "start_time": 1000,
                    },
                    {
                        "run_id": self.NEW_SUB_RUN_ID,
                        "task_key": operator.databricks_task_key,
                        "start_time": 2000,
                    },
                ],
            },
        ]

        result = operator._sync_wait_for_new_sub_run_attempt(
            original_sub_run_id=self.ORIGINAL_SUB_RUN_ID,
            original_start_time=1000,
            tg=operator.task_group,
        )

        assert result == self.NEW_SUB_RUN_ID
        mock_sleep.assert_called_once_with(15)

    @mock.patch("airflow.providers.databricks.operators.databricks.time.sleep")
    def test_sync_wait_for_new_sub_run_attempt_returns_none_after_repair_timeout(self, mock_sleep):
        # Gives up only once the deadline (end_time + timeout) passes. 120s timeout, 15s polling = 8 sleeps.
        operator = self._operator_with_workflow_tg(workflow_repair_attempts=2)
        hook = MagicMock()
        operator.__dict__["_hook"] = hook
        operator.databricks_run_id = self.PARENT_RUN_ID
        original_task = {
            "run_id": self.ORIGINAL_SUB_RUN_ID,
            "task_key": operator.databricks_task_key,
            "start_time": 1000,
        }
        terminal_state = {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": None}
        hook.get_run.return_value = {
            "state": terminal_state,
            "tasks": [original_task],
            "end_time": self.PARENT_END_TIME_MS,
        }

        # The clock is frozen at the run's end_time; each mocked sleep advances it by 15 seconds.
        with time_machine.travel(self.ANCHOR_S, tick=False) as traveller:
            mock_sleep.side_effect = lambda _seconds: traveller.move_to(time.time() + 15)
            result = operator._sync_wait_for_new_sub_run_attempt(
                original_sub_run_id=self.ORIGINAL_SUB_RUN_ID,
                original_start_time=1000,
                tg=operator.task_group,
            )

        assert result is None
        # 8 sleeps advance the clock to the deadline; the 9th poll observes it has passed.
        assert hook.get_run.call_count == 9
        assert mock_sleep.call_count == 8

    @mock.patch("airflow.providers.databricks.operators.databricks.time.sleep")
    def test_sync_wait_deadline_resets_when_run_leaves_terminal(self, mock_sleep):
        # When the repair takes effect the run leaves terminal, resetting the deadline, and the
        # waiter picks up the repaired attempt instead of returning None.
        operator = self._operator_with_workflow_tg(workflow_repair_attempts=2)
        hook = MagicMock()
        operator.__dict__["_hook"] = hook
        operator.databricks_run_id = self.PARENT_RUN_ID
        original_task = {
            "run_id": self.ORIGINAL_SUB_RUN_ID,
            "task_key": operator.databricks_task_key,
            "start_time": 1000,
        }
        new_task = {
            "run_id": self.NEW_SUB_RUN_ID,
            "task_key": operator.databricks_task_key,
            "start_time": 2000,
        }
        terminal_state = {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": None}
        running_state = {"life_cycle_state": "RUNNING", "result_state": None, "state_message": None}
        hook.get_run.side_effect = [
            {"state": terminal_state, "tasks": [original_task], "end_time": self.PARENT_END_TIME_MS},
            {"state": terminal_state, "tasks": [original_task], "end_time": self.PARENT_END_TIME_MS},
            # Repair lands: the run is RUNNING again → deadline resets.
            {"state": running_state, "tasks": [original_task]},
            # The repaired attempt for the watched task_key appears.
            {"state": running_state, "tasks": [original_task, new_task]},
        ]

        with time_machine.travel(self.ANCHOR_S, tick=False) as traveller:
            mock_sleep.side_effect = lambda _seconds: traveller.move_to(time.time() + 15)
            result = operator._sync_wait_for_new_sub_run_attempt(
                original_sub_run_id=self.ORIGINAL_SUB_RUN_ID,
                original_start_time=1000,
                tg=operator.task_group,
            )

        assert result == self.NEW_SUB_RUN_ID
        assert hook.get_run.call_count == 4
class TestDatabricksWorkflowRepairCoordinatorOperator:
    """Tests for ``_DatabricksWorkflowRepairCoordinatorOperator`` in deferrable and sync modes."""

    LAUNCH_TASK_ID = "wf.launch"
    # XCom payload the tests hand back as the launch task's return value.
    LAUNCH_RETURN = {"conn_id": "databricks_default", "job_id": 42, "run_id": 100}

    def _make_operator(
        self,
        workflow_repair_attempts: int = 2,
        deferrable: bool = True,
    ) -> _DatabricksWorkflowRepairCoordinatorOperator:
        """Build a coordinator operator with a 10-second repair polling period."""
        return _DatabricksWorkflowRepairCoordinatorOperator(
            task_id="repair_coordinator",
            databricks_conn_id="databricks_default",
            launch_task_id=self.LAUNCH_TASK_ID,
            workflow_repair_attempts=workflow_repair_attempts,
            workflow_repair_polling_period=10,
            deferrable=deferrable,
        )

    def test_execute_raises_when_launch_xcom_missing(self):
        operator = self._make_operator()
        ctx = {"ti": MagicMock()}
        ctx["ti"].xcom_pull.return_value = None

        with pytest.raises(AirflowException, match="did not publish workflow run metadata"):
            operator.execute(ctx)

    def test_execute_defers_on_coordinator_trigger(self):
        operator = self._make_operator(workflow_repair_attempts=3)
        ctx = {"ti": MagicMock()}
        ctx["ti"].xcom_pull.return_value = self.LAUNCH_RETURN

        with pytest.raises(TaskDeferred) as exc:
            operator.execute(ctx)

        ctx["ti"].xcom_push.assert_not_called()
        assert exc.value.method_name == "execute_complete"
        # The first deferral starts with a fresh repair counter and no previous repair id.
        trigger = exc.value.trigger
        assert isinstance(trigger, DatabricksWorkflowRepairCoordinatorTrigger)
        assert trigger.run_id == self.LAUNCH_RETURN["run_id"]
        assert trigger.workflow_repair_attempts == 3
        assert trigger.repair_attempts == 0
        assert trigger.latest_repair_id is None
        assert trigger.polling_period_seconds == 10

    def test_execute_complete_repaired_redefers_without_xcom_push(self):
        operator = self._make_operator(workflow_repair_attempts=3)
        ctx = {"ti": MagicMock()}

        with pytest.raises(TaskDeferred) as exc:
            operator.execute_complete(
                ctx,
                event={
                    "status": "repaired",
                    "run_id": 100,
                    "repair_attempts": 1,
                    "latest_repair_id": 555,
                },
            )

        ctx["ti"].xcom_push.assert_not_called()
        # The new trigger carries the counter and repair id forward from the event.
        trigger = exc.value.trigger
        assert isinstance(trigger, DatabricksWorkflowRepairCoordinatorTrigger)
        assert trigger.run_id == 100
        assert trigger.repair_attempts == 1
        assert trigger.latest_repair_id == 555
        assert trigger.workflow_repair_attempts == 3

    def test_execute_complete_failed_raises_with_errors_in_message(self):
        operator = self._make_operator(workflow_repair_attempts=2)
        ctx = {"ti": MagicMock()}
        errors = [{"task_key": "t1", "run_id": 11, "error": "boom"}]

        with pytest.raises(AirflowException) as exc:
            operator.execute_complete(
                ctx,
                event={
                    "status": "failed",
                    "run_id": 100,
                    "repair_attempts": 2,
                    "latest_repair_id": 999,
                    "errors": errors,
                },
            )

        message = str(exc.value)
        assert "100" in message
        assert "workflow_repair_attempts=2" in message
        assert "boom" in message
        ctx["ti"].xcom_push.assert_not_called()

    @patch("airflow.providers.databricks.operators.databricks_workflow.time.sleep")
    def test_sync_run_repairs_failed_run_and_returns_success(self, mock_sleep):
        operator = self._make_operator(workflow_repair_attempts=2, deferrable=False)
        hook = MagicMock()
        operator.__dict__["_hook"] = hook
        # 1. top of outer loop: terminal+failed → repair_run
        # 2. reflection poll: run still terminal, but repair_id is in repair_history → reflected →
        #    break reflection loop (the fast-repair case a terminal-state-only check would miss)
        # 3. top of outer loop: terminal+success → return
        hook.get_run_state.side_effect = [
            RunState("TERMINATED", "FAILED", ""),
            RunState("TERMINATED", "SUCCESS", ""),
        ]
        hook.get_run.side_effect = [
            {
                "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": ""},
                "tasks": [],
                "overriding_parameters": {"notebook_params": {"date": "2024-01-01"}},
            },
            {
                "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": ""},
                "repair_history": [{"id": 555}],
            },
        ]
        hook.repair_run.return_value = 555

        result = operator._run_sync(run_id=100)

        hook.repair_run.assert_called_once_with(
            {
                "run_id": 100,
                "rerun_all_failed_tasks": True,
                "rerun_dependent_tasks": True,
                "overriding_parameters": {"notebook_params": {"date": "2024-01-01"}},
            }
        )
        assert result == {"run_id": 100, "repair_attempts": 1, "latest_repair_id": 555}
        # Exactly one sleep: the mocked run never leaves its terminal state here; the repair is
        # detected through repair_history (id 555) on the single reflection poll.
        assert mock_sleep.call_count == 1

    @patch("airflow.providers.databricks.operators.databricks_workflow.time.sleep")
    def test_sync_run_raises_when_repair_not_reflected_within_timeout(self, mock_sleep):
        operator = self._make_operator(workflow_repair_attempts=2, deferrable=False)
        operator.workflow_repair_timeout = 0
        hook = MagicMock()
        operator.__dict__["_hook"] = hook
        # terminal+failed → repair_run; reflection poll still terminal with no repair_history entry
        # → deadline trips → raise.
        # timeout=0 makes the deadline the run's end_time, already past, so the loop bails out.
        hook.get_run_state.side_effect = [
            RunState("TERMINATED", "FAILED", ""),
        ]
        hook.get_run.return_value = {
            "state": {"life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": ""},
            "tasks": [],
            "overriding_parameters": {},
        }
        hook.repair_run.return_value = 555

        with pytest.raises(AirflowException) as exc:
            operator._run_sync(run_id=100)

        message = str(exc.value)
        assert "did not reflect repair_id=555" in message
        assert "run 100" in message
        assert "workflow_repair_timeout" in message
        # Only the original repair_run — the raise must prevent a duplicate.
        hook.repair_run.assert_called_once()
        # One reflection-loop sleep fired before the deadline check tripped.
        assert mock_sleep.call_count == 1


class TestDatabricksWorkflowTaskGroupCoordinatorInjection:
    """Tests that the task group adds the repair coordinator task when repair is enabled."""

    @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Coordinator task is only injected on Airflow 3+")
    def test_workflow_repair_attempts_positive_injects_coordinator_with_launch_upstream(self):
        with DAG(dag_id="dwf_with_coord", schedule=None, start_date=DEFAULT_DATE):
            with DatabricksWorkflowTaskGroup(
                group_id="wf",
                databricks_conn_id="databricks_conn",
                workflow_repair_attempts=2,
                workflow_repair_polling_period=15,
                workflow_repair_timeout=120,
            ) as tg:
                # A mocked child task stands in for a real Databricks operator.
                task = MagicMock(task_id="task1")
                task._convert_to_databricks_workflow_task = MagicMock(return_value={})
                tg.add(task)

        # The coordinator is a child of the group, configured from the group's repair settings
        # and wired downstream of the launch task.
        coordinator = tg.children["wf.repair_coordinator"]
        assert isinstance(coordinator, _DatabricksWorkflowRepairCoordinatorOperator)
        assert coordinator.workflow_repair_attempts == 2
        assert coordinator.workflow_repair_polling_period == 15
        assert coordinator.workflow_repair_timeout == 120
        assert coordinator.launch_task_id == "wf.launch"
        assert "wf.launch" in coordinator.upstream_task_ids
8854eb03fb5bc..2dacb611c0f80 100644 --- a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py @@ -17,9 +17,12 @@ # under the License. from __future__ import annotations +import time +from typing import Any from unittest import mock import pytest +import time_machine from tenacity import stop_after_attempt, wait_incrementing from airflow.models import Connection @@ -27,6 +30,8 @@ from airflow.providers.databricks.triggers.databricks import ( DatabricksExecutionTrigger, DatabricksSQLStatementExecutionTrigger, + DatabricksWorkflowRepairCoordinatorTrigger, + DatabricksWorkflowRepairWaitTrigger, ) from airflow.triggers.base import TriggerEvent @@ -228,6 +233,7 @@ async def test_run_return_success( life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="SUCCESS" ).to_json(), "run_page_url": RUN_PAGE_URL, + "run_start_time": None, "repair_run": False, "errors": [], } @@ -259,6 +265,7 @@ async def test_run_return_failure( life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="FAILED" ).to_json(), "run_page_url": RUN_PAGE_URL, + "run_start_time": None, "repair_run": False, "errors": [ {"task_key": TASK_RUN_ID1_KEY, "run_id": TASK_RUN_ID1, "error": ERROR_MESSAGE}, @@ -299,6 +306,7 @@ async def test_sleep_between_retries( life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="SUCCESS" ).to_json(), "run_page_url": RUN_PAGE_URL, + "run_start_time": None, "repair_run": False, "errors": [], } @@ -432,3 +440,370 @@ async def test_sleep_between_retries(self, mock_a_get_sql_statement_state, mock_ async def test_on_kill_cancels_statement(self, mock_cancel_sql_statement): await self.trigger.on_kill() mock_cancel_sql_statement.assert_called_once_with(STATEMENT_ID) + + +class TestDatabricksWorkflowRepairCoordinatorTrigger: + @pytest.fixture(autouse=True) + def setup_connections(self, create_connection_without_db): + 
create_connection_without_db( + Connection( + conn_id=DEFAULT_CONN_ID, + conn_type="databricks", + host=HOST, + login=LOGIN, + password=PASSWORD, + extra=None, + ) + ) + + def _make_trigger( + self, + workflow_repair_attempts: int = 2, + repair_attempts: int = 0, + latest_repair_id: int | None = None, + workflow_repair_timeout: int = 300, + ) -> DatabricksWorkflowRepairCoordinatorTrigger: + return DatabricksWorkflowRepairCoordinatorTrigger( + run_id=RUN_ID, + databricks_conn_id=DEFAULT_CONN_ID, + workflow_repair_attempts=workflow_repair_attempts, + repair_attempts=repair_attempts, + latest_repair_id=latest_repair_id, + polling_period_seconds=POLLING_INTERVAL_SECONDS, + workflow_repair_timeout=workflow_repair_timeout, + run_page_url=RUN_PAGE_URL, + ) + + def test_serialize_round_trips_state(self): + trigger = self._make_trigger(workflow_repair_attempts=3, repair_attempts=1, latest_repair_id=42) + + path, kwargs = trigger.serialize() + restored = DatabricksWorkflowRepairCoordinatorTrigger(**kwargs) + + assert ( + path + == "airflow.providers.databricks.triggers.databricks.DatabricksWorkflowRepairCoordinatorTrigger" + ) + assert restored.serialize() == (path, kwargs) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state") + async def test_emits_completed_when_run_succeeds(self, mock_get_run_state, mock_get_run): + mock_get_run_state.return_value = RunState( + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + state_message="", + result_state="SUCCESS", + ) + mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED + + trigger = self._make_trigger(workflow_repair_attempts=2, repair_attempts=0, latest_repair_id=None) + events = [event async for event in trigger.run()] + + assert len(events) == 1 + assert events[0].payload["status"] == "completed" + assert events[0].payload["run_id"] == RUN_ID + assert 
events[0].payload["repair_attempts"] == 0 + assert events[0].payload["latest_repair_id"] is None + assert events[0].payload["errors"] == [] + + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.triggers.databricks.asyncio.sleep") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.repair_run") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state") + async def test_first_failure_within_budget_calls_repair_and_emits_repaired( + self, mock_get_run_state, mock_get_run, mock_get_run_output, mock_repair_run, mock_sleep + ): + # Outer loop: terminal+failed → trigger issues repair. + # Reflection poll: run is still terminal, but repair_id is in repair_history → reflected → + # reflection loop breaks. This is the fast-repair case that a terminal-state-only check + # would miss. 
+ mock_get_run_state.side_effect = [ + RunState( + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + state_message="", + result_state="FAILED", + ), + ] + mock_get_run.side_effect = [ + GET_RUN_RESPONSE_TERMINATED_WITH_FAILED, + {**GET_RUN_RESPONSE_TERMINATED_WITH_FAILED, "repair_history": [{"id": 101}]}, + ] + mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE + mock_repair_run.return_value = 101 + + trigger = self._make_trigger(workflow_repair_attempts=2, repair_attempts=0, latest_repair_id=None) + events = [event async for event in trigger.run()] + + assert len(events) == 1 + assert events[0].payload["status"] == "repaired" + assert events[0].payload["run_id"] == RUN_ID + assert events[0].payload["repair_attempts"] == 1 + assert events[0].payload["latest_repair_id"] == 101 + assert events[0].payload["errors"] == [ + {"task_key": TASK_RUN_ID1_KEY, "run_id": TASK_RUN_ID1, "error": ERROR_MESSAGE}, + {"task_key": TASK_RUN_ID3_KEY, "run_id": TASK_RUN_ID3, "error": ERROR_MESSAGE}, + ] + + mock_repair_run.assert_called_once() + repair_json = mock_repair_run.call_args.args[0] + assert repair_json["run_id"] == RUN_ID + assert repair_json["rerun_all_failed_tasks"] is True + # First repair: latest_repair_id was None, so the field must be omitted from the payload + assert "latest_repair_id" not in repair_json + # Reflection loop slept once before seeing repair_id 101 in repair_history (run still terminal). 
+ assert mock_sleep.call_count == 1 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.triggers.databricks.asyncio.sleep") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.repair_run") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state") + async def test_emits_repair_not_reflected_when_reflection_timeout_elapses( + self, + mock_get_run_state, + mock_get_run, + mock_get_run_output, + mock_repair_run, + mock_sleep, + ): + # terminal+failed → repair; reflection poll still terminal with no repair_history entry → + # deadline trips → repair_not_reflected. + # timeout=0 makes the deadline the run's end_time, already past, so the loop bails out. + mock_get_run_state.side_effect = [ + RunState( + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + state_message="", + result_state="FAILED", + ), + ] + mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED_WITH_FAILED + mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE + mock_repair_run.return_value = 101 + + trigger = self._make_trigger( + workflow_repair_attempts=2, + repair_attempts=0, + latest_repair_id=None, + workflow_repair_timeout=0, + ) + events = [event async for event in trigger.run()] + + assert len(events) == 1 + assert events[0].payload["status"] == "repair_not_reflected" + assert events[0].payload["run_id"] == RUN_ID + assert events[0].payload["repair_attempts"] == 1 + assert events[0].payload["latest_repair_id"] == 101 + # Only the original repair_run — yielding must prevent a duplicate next cycle. + mock_repair_run.assert_called_once() + # One reflection-loop sleep fired before the deadline check tripped. 
+ assert mock_sleep.call_count == 1 + + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.repair_run") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state") + async def test_emits_failed_when_budget_exhausted( + self, mock_get_run_state, mock_get_run, mock_get_run_output, mock_repair_run + ): + mock_get_run_state.return_value = RunState( + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + state_message="", + result_state="FAILED", + ) + mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED_WITH_FAILED + mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE + + trigger = self._make_trigger(workflow_repair_attempts=2, repair_attempts=2, latest_repair_id=202) + events = [event async for event in trigger.run()] + + assert len(events) == 1 + assert events[0].payload["status"] == "failed" + assert events[0].payload["repair_attempts"] == 2 + assert events[0].payload["latest_repair_id"] == 202 + assert events[0].payload["errors"] == [ + {"task_key": TASK_RUN_ID1_KEY, "run_id": TASK_RUN_ID1, "error": ERROR_MESSAGE}, + {"task_key": TASK_RUN_ID3_KEY, "run_id": TASK_RUN_ID3, "error": ERROR_MESSAGE}, + ] + mock_repair_run.assert_not_called() + + +class TestDatabricksWorkflowRepairWaitTrigger: + PARENT_RUN_ID = 100 + TASK_KEY = "monitored_task" + ORIGINAL_SUB_RUN_ID = 500 + NEW_SUB_RUN_ID = 700 + # Parent run terminal end_time (epoch ms) the waiter anchors its give-up deadline to. 
+ PARENT_END_TIME_MS = 1_700_000_000_000 + ANCHOR_S = PARENT_END_TIME_MS / 1000 + + @pytest.fixture(autouse=True) + def setup_connections(self, create_connection_without_db): + create_connection_without_db( + Connection( + conn_id=DEFAULT_CONN_ID, + conn_type="databricks", + host=HOST, + login=LOGIN, + password=PASSWORD, + extra=None, + ) + ) + + def _make_trigger(self, workflow_repair_timeout: int = 180) -> DatabricksWorkflowRepairWaitTrigger: + return DatabricksWorkflowRepairWaitTrigger( + run_id=self.PARENT_RUN_ID, + databricks_conn_id=DEFAULT_CONN_ID, + databricks_task_key=self.TASK_KEY, + original_sub_run_id=self.ORIGINAL_SUB_RUN_ID, + polling_period_seconds=POLLING_INTERVAL_SECONDS, + workflow_repair_timeout=workflow_repair_timeout, + run_page_url=RUN_PAGE_URL, + ) + + def _run_payload( + self, + result_state: str | None, + life_cycle_state: str = LIFE_CYCLE_STATE_TERMINATED, + tasks: list[dict] | None = None, + end_time: int | None = None, + ) -> dict: + payload: dict[str, Any] = { + "run_page_url": RUN_PAGE_URL, + "state": { + "life_cycle_state": life_cycle_state, + "state_message": None, + "result_state": result_state, + }, + "tasks": tasks or [], + } + if end_time is not None: + payload["end_time"] = end_time + return payload + + def test_serialize_round_trips_state(self): + trigger = self._make_trigger() + + path, kwargs = trigger.serialize() + restored = DatabricksWorkflowRepairWaitTrigger(**kwargs) + + assert path == "airflow.providers.databricks.triggers.databricks.DatabricksWorkflowRepairWaitTrigger" + assert restored.serialize() == (path, kwargs) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + async def test_emits_new_attempt_when_new_sub_run_appears(self, mock_get_run): + mock_get_run.return_value = self._run_payload( + result_state=None, + life_cycle_state="RUNNING", + tasks=[ + { + "run_id": self.ORIGINAL_SUB_RUN_ID, + "task_key": self.TASK_KEY, + "start_time": 1000, + }, + { + 
"run_id": self.NEW_SUB_RUN_ID, + "task_key": self.TASK_KEY, + "start_time": 2000, + }, + ], + ) + + trigger = self._make_trigger() + events = [event async for event in trigger.run()] + + assert len(events) == 1 + assert events[0].payload == { + "status": "new_attempt", + "parent_run_id": self.PARENT_RUN_ID, + "databricks_task_key": self.TASK_KEY, + "new_sub_run_id": self.NEW_SUB_RUN_ID, + "run_page_url": RUN_PAGE_URL, + } + + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + async def test_emits_parent_failed_after_repair_timeout(self, mock_get_run): + # Gives up only once the deadline (end_time + timeout) passes, not after a fixed poll + # count. 180s timeout, 30s polling = 6 sleeps. + mock_get_run.return_value = self._run_payload( + result_state="FAILED", + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + tasks=[{"run_id": self.ORIGINAL_SUB_RUN_ID, "task_key": self.TASK_KEY, "start_time": 1000}], + end_time=self.PARENT_END_TIME_MS, + ) + + with time_machine.travel(self.ANCHOR_S, tick=False) as traveller: + + def advance(_seconds): + traveller.move_to(time.time() + POLLING_INTERVAL_SECONDS) + + with mock.patch("asyncio.sleep", side_effect=advance) as mock_sleep: + trigger = self._make_trigger(workflow_repair_timeout=180) + events = [event async for event in trigger.run()] + + # 6 sleeps advance the clock to the deadline; parent_failed fires on the 7th poll. 
+ assert mock_get_run.call_count == 7 + assert mock_sleep.call_count == 6 + assert len(events) == 1 + payload = events[0].payload + assert payload["status"] == "parent_failed" + assert payload["parent_run_id"] == self.PARENT_RUN_ID + assert payload["databricks_task_key"] == self.TASK_KEY + assert RunState.from_json(payload["parent_run_state"]).result_state == "FAILED" + + @pytest.mark.asyncio + @mock.patch("asyncio.sleep") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + async def test_parent_failed_fires_immediately_when_deadline_already_past(self, mock_get_run, mock_sleep): + # Deadline is anchored to end_time, not to when polling started: a run that went terminal + # long ago is already past its deadline, so the first poll fails without waiting. + mock_get_run.return_value = self._run_payload( + result_state="FAILED", + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + tasks=[{"run_id": self.ORIGINAL_SUB_RUN_ID, "task_key": self.TASK_KEY, "start_time": 1000}], + end_time=self.PARENT_END_TIME_MS, + ) + + with time_machine.travel(self.ANCHOR_S + 1000, tick=False): + trigger = self._make_trigger(workflow_repair_timeout=180) + events = [event async for event in trigger.run()] + + assert mock_get_run.call_count == 1 + assert mock_sleep.call_count == 0 + assert events[0].payload["status"] == "parent_failed" + + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + async def test_deadline_resets_when_run_leaves_terminal(self, mock_get_run): + # When the repair takes effect the run leaves terminal, resetting the deadline, and the + # waiter picks up the repaired attempt instead of failing. 
+ original_task = {"run_id": self.ORIGINAL_SUB_RUN_ID, "task_key": self.TASK_KEY, "start_time": 1000} + new_task = {"run_id": self.NEW_SUB_RUN_ID, "task_key": self.TASK_KEY, "start_time": 2000} + mock_get_run.side_effect = [ + self._run_payload(result_state="FAILED", tasks=[original_task], end_time=self.PARENT_END_TIME_MS), + self._run_payload(result_state="FAILED", tasks=[original_task], end_time=self.PARENT_END_TIME_MS), + # Repair lands: the run is RUNNING again → deadline resets. + self._run_payload(result_state=None, life_cycle_state="RUNNING", tasks=[original_task]), + # The repaired attempt for the watched task_key appears. + self._run_payload(result_state=None, life_cycle_state="RUNNING", tasks=[original_task, new_task]), + ] + + with time_machine.travel(self.ANCHOR_S, tick=False) as traveller: + + def advance(_seconds): + traveller.move_to(time.time() + POLLING_INTERVAL_SECONDS) + + with mock.patch("asyncio.sleep", side_effect=advance): + trigger = self._make_trigger(workflow_repair_timeout=180) + events = [event async for event in trigger.run()] + + assert mock_get_run.call_count == 4 + assert len(events) == 1 + assert events[0].payload["status"] == "new_attempt" + assert events[0].payload["new_sub_run_id"] == self.NEW_SUB_RUN_ID diff --git a/providers/databricks/tests/unit/databricks/utils/test_databricks.py b/providers/databricks/tests/unit/databricks/utils/test_databricks.py index 33716879713ed..5cdb9c8aa988e 100644 --- a/providers/databricks/tests/unit/databricks/utils/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/utils/test_databricks.py @@ -25,8 +25,12 @@ from airflow.providers.common.compat.sdk import AirflowException from airflow.providers.databricks.hooks.databricks import RunState from airflow.providers.databricks.utils.databricks import ( + build_repair_run_json, + compute_repair_deadline, extract_failed_task_errors, extract_failed_task_errors_async, + find_new_workflow_task_attempt, + is_repair_reflected, 
normalise_json_content, validate_trigger_event, ) @@ -95,6 +99,78 @@ def test_validate_trigger_event_failure(self): with pytest.raises(AirflowException): validate_trigger_event(event) + def test_find_new_workflow_task_attempt_picks_newest_matching_attempt(self): + tasks = [ + {"run_id": TASK_RUN_ID_1, "task_key": TASK_KEY_1, "start_time": 1000}, + {"run_id": 201, "task_key": TASK_KEY_1, "start_time": 1500}, + {"run_id": 202, "task_key": TASK_KEY_1, "start_time": 2500}, + {"run_id": 203, "task_key": TASK_KEY_2, "start_time": 3000}, + ] + + result = find_new_workflow_task_attempt( + tasks=tasks, + task_key=TASK_KEY_1, + original_sub_run_id=TASK_RUN_ID_1, + original_start_time=1000, + ) + + assert result == {"run_id": 202, "task_key": TASK_KEY_1, "start_time": 2500} + + def test_find_new_workflow_task_attempt_skips_not_yet_started_attempt(self): + # A repaired attempt that has not started yet (missing/0/None start_time) is not selected; + # the caller keeps polling until it starts. + tasks = [ + {"run_id": TASK_RUN_ID_1, "task_key": TASK_KEY_1, "start_time": 1000}, + {"run_id": 201, "task_key": TASK_KEY_1, "start_time": 0}, + {"run_id": 202, "task_key": TASK_KEY_1, "start_time": None}, + {"run_id": 203, "task_key": TASK_KEY_1}, + ] + + result = find_new_workflow_task_attempt( + tasks=tasks, + task_key=TASK_KEY_1, + original_sub_run_id=TASK_RUN_ID_1, + original_start_time=None, + ) + + assert result is None + + def test_build_repair_run_json_includes_optional_fields_only_when_present(self): + assert build_repair_run_json(run_id=RUN_ID, latest_repair_id=None) == { + "run_id": RUN_ID, + "rerun_all_failed_tasks": True, + "rerun_dependent_tasks": True, + } + + assert build_repair_run_json( + run_id=RUN_ID, + latest_repair_id=42, + overriding_parameters={"notebook_params": {"date": "2024-01-01"}}, + ) == { + "run_id": RUN_ID, + "rerun_all_failed_tasks": True, + "rerun_dependent_tasks": True, + "latest_repair_id": 42, + "overriding_parameters": {"notebook_params": {"date": 
"2024-01-01"}}, + } + + def test_compute_repair_deadline_anchors_to_terminal_end_time(self): + # end_time (epoch ms) in seconds plus the timeout, independent of when polling started. + assert compute_repair_deadline({"end_time": 1_700_000_000_000}, 180) == 1_700_000_000.0 + 180 + + def test_is_repair_reflected_detects_repair_id_in_history(self): + run_info = {"repair_history": [{"id": 111}, {"id": 222}]} + # The repair is reflected as soon as its repair_id is in repair_history, even while the run + # is still terminal — a fast repair never has to be caught mid-flight. + assert is_repair_reflected(run_info, 222) is True + assert is_repair_reflected(run_info, 111) is True + + def test_is_repair_reflected_false_when_not_present_or_unknown(self): + # repair_id absent from history, history missing entirely, or no repair_id yet. + assert is_repair_reflected({"repair_history": [{"id": 111}]}, 999) is False + assert is_repair_reflected({}, 111) is False + assert is_repair_reflected({"repair_history": [{"id": 111}]}, None) is False + class TestExtractFailedTaskErrors: """Test cases for the extract_failed_task_errors utility function (synchronous version)"""