Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions providers/databricks/docs/operators/workflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,29 @@ To minimize update conflicts, we recommend that you keep parameters in the ``not
``DatabricksWorkflowTaskGroup`` and not in the ``DatabricksNotebookOperator`` whenever possible.
This is because, tasks in the ``DatabricksWorkflowTaskGroup`` are passed in on the job trigger time and
do not modify the job definition.

Automatic repair (Airflow 3+)
-----------------------------

Set ``workflow_repair_attempts=N`` to auto-repair a failed workflow run up to N times. The task
group injects a ``repair_coordinator`` sibling that waits for the run to terminate and then calls
Databricks ``repair_run`` with ``rerun_all_failed_tasks=True``. Default is ``0`` (off).

.. code-block:: python

task_group = DatabricksWorkflowTaskGroup(
group_id="Example Workflow",
databricks_conn_id="databricks_default",
workflow_repair_attempts=2,
workflow_repair_polling_period=15,
workflow_repair_timeout=300,
)

After ``repair_run`` is accepted, Databricks needs a moment to drop the parent run out of its
terminal state. The coordinator polls every ``workflow_repair_polling_period`` seconds and gives
Databricks up to ``workflow_repair_timeout`` seconds (default 300s / 5 minutes) to

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default here is described as 300s / 5 minutes, but the code default for workflow_repair_timeout is 180 (3 minutes) — see _DatabricksWorkflowRepairCoordinatorOperator.__init__ (workflow_repair_timeout: int = 180). One of them should be corrected so the docs match the code.


Drafted-by: Claude Code (Opus 4.8); reviewed by @moomindani before posting

reflect the repair before it fails the coordinator. Raise the timeout if your workspace is slow
to surface repaired runs.

Downstream task monitors stay in the same Airflow attempt across repairs, so size
``execution_timeout`` for the original run plus all repairs and set ``retries=0`` on workflow tasks.
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,7 @@ class DatabricksSqlExecutionError(AirflowException):

class DatabricksSqlExecutionTimeout(DatabricksSqlExecutionError):
"""Raised when a sql execution times out."""


class DatabricksWorkflowRepairError(AirflowException):
"""Raised when Databricks Workflow repair coordination fails."""
Original file line number Diff line number Diff line change
Expand Up @@ -560,25 +560,31 @@ def get_run_tasks(self, run_id: int) -> list[dict[str, Any]]:

return all_tasks

def get_run(self, run_id: int) -> dict[str, Any]:
def get_run(self, run_id: int, include_history: bool = False) -> dict[str, Any]:
"""
Retrieve run information.

:param run_id: id of the run
:param include_history: whether to include the run's ``repair_history`` in the response.
:return: state of the run
"""
json = {"run_id": run_id}
json: dict[str, Any] = {"run_id": run_id}
if include_history:
json["include_history"] = "true"
response = self._do_api_call(GET_RUN_ENDPOINT, json)
return response

async def a_get_run(self, run_id: int) -> dict[str, Any]:
async def a_get_run(self, run_id: int, include_history: bool = False) -> dict[str, Any]:
"""
Async version of `get_run`.

:param run_id: id of the run
:param include_history: whether to include the run's ``repair_history`` in the response.
:return: state of the run
"""
json = {"run_id": run_id}
json: dict[str, Any] = {"run_id": run_id}
if include_history:
json["include_history"] = "true"
response = await self._a_do_api_call(GET_RUN_ENDPOINT, json)
return response

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from typing import TYPE_CHECKING, Any

from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, BaseOperatorLink, XCom, conf
from airflow.providers.databricks.exceptions import DatabricksWorkflowRepairError
from airflow.providers.databricks.hooks.databricks import (
DatabricksHook,
RunLifeCycleState,
Expand All @@ -43,9 +44,12 @@
)
from airflow.providers.databricks.triggers.databricks import (
DatabricksExecutionTrigger,
DatabricksWorkflowRepairWaitTrigger,
)
from airflow.providers.databricks.utils.databricks import (
compute_repair_deadline,
extract_failed_task_errors,
find_new_workflow_task_attempt,
normalise_json_content,
validate_trigger_event,
)
Expand Down Expand Up @@ -1491,19 +1495,64 @@ def monitor_databricks_job(self) -> None:
self.databricks_task_key,
run_state.life_cycle_state,
)
if self.deferrable and not run_state.is_terminal:
self.defer(
trigger=DatabricksExecutionTrigger(
run_id=current_task_run_id,
databricks_conn_id=self.databricks_conn_id,
polling_period_seconds=self.polling_period_seconds,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
caller=self.caller,
),
method_name=DEFER_METHOD_NAME,
)
if self.deferrable:
if not run_state.is_terminal:
self.defer(
trigger=DatabricksExecutionTrigger(
run_id=current_task_run_id,
databricks_conn_id=self.databricks_conn_id,
polling_period_seconds=self.polling_period_seconds,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
caller=self.caller,
),
method_name=DEFER_METHOD_NAME,
)
elif not run_state.is_successful:
tg = self._workflow_task_group_with_repair()
if tg is not None:
self._defer_to_workflow_repair_wait(
original_sub_run_id=current_task_run_id,
original_start_time=run.get("start_time"),
tg=tg,
)
else:
tg = self._workflow_task_group_with_repair()
if tg is not None:
while True:
while not run_state.is_terminal:
time.sleep(self.polling_period_seconds)
run = self._hook.get_run(current_task_run_id)
run_state = RunState(**run["state"])
self.log.info(
"Current state of the databricks task %s is %s",
self.databricks_task_key,
run_state.life_cycle_state,
)
if run_state.is_successful:
break
new_sub_run_id = self._sync_wait_for_new_sub_run_attempt(
original_sub_run_id=current_task_run_id,
original_start_time=run.get("start_time"),
tg=tg,
)
if new_sub_run_id is None:
break
self.log.info(
"Repair coordinator produced a new attempt for task_key %s (sub-run %s).",
self.databricks_task_key,
new_sub_run_id,
)
current_task_run_id = new_sub_run_id
run = self._hook.get_run(current_task_run_id)
run_state = RunState(**run["state"])
self.log.info(
"Current state of the databricks task %s is %s",
self.databricks_task_key,
run_state.life_cycle_state,
)

while not run_state.is_terminal:
time.sleep(self.polling_period_seconds)
run = self._hook.get_run(current_task_run_id)
Expand All @@ -1523,13 +1572,7 @@ def monitor_databricks_job(self) -> None:
def execute(self, context: Context) -> None:
"""Execute the operator. Launch the job and monitor it if wait_for_termination is set to True."""
if self._databricks_workflow_task_group:
# If we are in a DatabricksWorkflowTaskGroup, we should have an upstream task launched.
if not self.workflow_run_metadata:
launch_task_id = next(task for task in self.upstream_task_ids if task.endswith(".launch"))
self.workflow_run_metadata = context["ti"].xcom_pull(task_ids=launch_task_id)
workflow_run_metadata = WorkflowRunMetadata(**self.workflow_run_metadata)
self.databricks_run_id = workflow_run_metadata.run_id
self.databricks_conn_id = workflow_run_metadata.conn_id
workflow_run_metadata = self._resolve_workflow_run_metadata(context)

# Store operator links in XCom for Airflow 3 compatibility
if AIRFLOW_V_3_0_PLUS:
Expand All @@ -1544,11 +1587,158 @@ def execute(self, context: Context) -> None:
if self.wait_for_termination:
self.monitor_databricks_job()

def _resolve_workflow_run_metadata(self, context: Context | dict | None) -> WorkflowRunMetadata:
# Called from both execute() and execute_complete(): the deferrable resume path
# re-instantiates the operator, so attributes set in execute() are lost and we
# re-resolve from the templated workflow_run_metadata field.
if not self.workflow_run_metadata and context is not None:
launch_task_id = next((task for task in self.upstream_task_ids if task.endswith(".launch")), None)
ti = context.get("ti") if isinstance(context, dict) else context["ti"]
if launch_task_id is not None and ti is not None:
self.workflow_run_metadata = ti.xcom_pull(task_ids=launch_task_id)
if not self.workflow_run_metadata:
raise ValueError("workflow_run_metadata is required to resolve the parent workflow run")
workflow_run_metadata = WorkflowRunMetadata(**self.workflow_run_metadata)
self.databricks_run_id = workflow_run_metadata.run_id
self.databricks_conn_id = workflow_run_metadata.conn_id
return workflow_run_metadata

def execute_complete(self, context: dict | None, event: dict) -> None:
run_state = RunState.from_json(event["run_state"])
errors = event.get("errors", [])

if not run_state.is_successful:
tg = self._workflow_task_group_with_repair()
if tg is not None:
self._resolve_workflow_run_metadata(context)
self._defer_to_workflow_repair_wait(
original_sub_run_id=event["run_id"],
original_start_time=event.get("run_start_time"),
tg=tg,
)

self._handle_terminal_run_state(run_state, errors)

def _workflow_task_group_with_repair(self) -> DatabricksWorkflowTaskGroup | None:
if not AIRFLOW_V_3_0_PLUS:
return None
tg = self._databricks_workflow_task_group
if tg is None or getattr(tg, "workflow_repair_attempts", 0) <= 0:
return None
return tg

def _sync_wait_for_new_sub_run_attempt(
self,
original_sub_run_id: int,
original_start_time: int | None,
tg: DatabricksWorkflowTaskGroup,
) -> int | None:
"""Sync mirror of DatabricksWorkflowRepairWaitTrigger. Returns the new sub-run id or None."""
self.log.info(
"Sub-run %s for task_key %s reached terminal failure; waiting for a repair "
"attempt issued by the repair coordinator.",
original_sub_run_id,
self.databricks_task_key,
)
polling_period_seconds = tg.workflow_repair_polling_period
workflow_repair_timeout = tg.workflow_repair_timeout
# Give-up deadline (epoch seconds), anchored to the run's end_time to match the
# coordinator. Reset when the run leaves terminal failure so a later failure restarts it.
repair_deadline: float | None = None
while True:
run_info = self._hook.get_run(self.databricks_run_id) # type: ignore[arg-type]
parent_run_state = RunState(**run_info["state"])
tasks = run_info.get("tasks", [])
new_attempt = find_new_workflow_task_attempt(
tasks=tasks,
task_key=self.databricks_task_key,
original_sub_run_id=original_sub_run_id,
original_start_time=original_start_time,
)
if new_attempt is not None:
return new_attempt["run_id"]
if parent_run_state.is_terminal and not parent_run_state.is_successful:
if repair_deadline is None:
repair_deadline = compute_repair_deadline(run_info, workflow_repair_timeout)
if time.time() >= repair_deadline:
self.log.info(
"Parent run %s reached terminal state %s without a new attempt for "
"task_key %s within %ss (anchored to the run's terminal end_time).",
self.databricks_run_id,
parent_run_state.result_state,
self.databricks_task_key,
workflow_repair_timeout,
)
return None
else:
# Not in terminal failure — clear the deadline so a later failure restarts it.
repair_deadline = None
time.sleep(polling_period_seconds)

def _defer_to_workflow_repair_wait(
self,
original_sub_run_id: int,
original_start_time: int | None,
tg: DatabricksWorkflowTaskGroup,
) -> None:
self.log.info(
"Sub-run %s for task_key %s reached terminal failure; deferring to wait for a repair "
"attempt issued by the repair coordinator.",
original_sub_run_id,
self.databricks_task_key,
)
self.defer(
trigger=DatabricksWorkflowRepairWaitTrigger(
run_id=self.databricks_run_id, # type: ignore[arg-type]
databricks_conn_id=self.databricks_conn_id,
databricks_task_key=self.databricks_task_key,
original_sub_run_id=original_sub_run_id,
original_start_time=original_start_time,
polling_period_seconds=tg.workflow_repair_polling_period,
workflow_repair_timeout=tg.workflow_repair_timeout,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
caller=self.caller,
),
method_name="execute_complete_after_repair_wait",
)

def execute_complete_after_repair_wait(self, context: dict | None, event: dict) -> None:
status = event.get("status")
if status == "new_attempt":
new_sub_run_id = event["new_sub_run_id"]
self.log.info(
"Repair coordinator produced a new attempt for task_key %s (sub-run %s); "
"deferring on a fresh DatabricksExecutionTrigger to monitor it.",
self.databricks_task_key,
new_sub_run_id,
)
self.defer(
trigger=DatabricksExecutionTrigger(
run_id=new_sub_run_id,
databricks_conn_id=self.databricks_conn_id,
polling_period_seconds=self.polling_period_seconds,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
caller=self.caller,
),
method_name=DEFER_METHOD_NAME,
)
elif status == "parent_failed":
parent_state = RunState.from_json(event["parent_run_state"])
raise DatabricksWorkflowRepairError(
f"Databricks workflow run {event.get('parent_run_id')} reached terminal failure "
f"({parent_state.result_state}) without producing a new attempt for task_key "
f"{self.databricks_task_key!r}; repair budget is exhausted or the coordinator "
f"did not issue a repair."
)
else:
raise DatabricksWorkflowRepairError(
f"DatabricksWorkflowRepairWaitTrigger emitted unexpected status {status!r}: {event}"
)


class DatabricksNotebookOperator(DatabricksTaskBaseOperator):
"""
Expand Down
Loading
Loading