From f5b43ec0077ec561f9889147488d96927b8fb0a5 Mon Sep 17 00:00:00 2001 From: vgkowski Date: Wed, 1 Jul 2026 11:23:29 +0200 Subject: [PATCH] Add EMR Serverless interactive session support to Amazon provider Add operators, sensor and triggers to manage the lifecycle of EMR Serverless interactive (Spark Connect) sessions: start, wait, fetch the Spark Connect endpoint, and terminate. Start/Stop operators and the sensor support deferrable mode. The GetSessionEndpoint operator masks the short-lived auth token. Includes waiters, docs, a system-test example DAG, and unit tests. Requires botocore 1.43.0+ (declared separately). Document botocore/aiobotocore requirements for EMR Serverless sessions Interactive sessions need botocore 1.43.0+ (and aiobotocore 3.6.0+ for deferrable mode) plus EMR 7.13+. Note this in the docs instead of raising the provider's dependency floor, so users who do not use interactive sessions are not forced onto newer AWS SDK versions. --- .../docs/operators/emr/emr_serverless.rst | 66 +++++ .../airflow/providers/amazon/aws/hooks/emr.py | 54 ++++ .../providers/amazon/aws/operators/emr.py | 253 +++++++++++++++++- .../providers/amazon/aws/sensors/emr.py | 79 ++++++ .../providers/amazon/aws/triggers/emr.py | 76 ++++++ .../amazon/aws/waiters/emr-serverless.json | 56 ++++ .../aws/example_emr_serverless_session.py | 126 +++++++++ .../amazon/aws/hooks/test_emr_serverless.py | 67 +++++ .../operators/test_emr_serverless_session.py | 194 ++++++++++++++ .../sensors/test_emr_serverless_session.py | 88 ++++++ .../unit/amazon/aws/triggers/test_emr.py | 51 ++++ .../airflow/providers/common/compat/sdk.py | 5 + 12 files changed, 1114 insertions(+), 1 deletion(-) create mode 100644 providers/amazon/tests/system/amazon/aws/example_emr_serverless_session.py create mode 100644 providers/amazon/tests/unit/amazon/aws/operators/test_emr_serverless_session.py create mode 100644 providers/amazon/tests/unit/amazon/aws/sensors/test_emr_serverless_session.py diff --git a/providers/amazon/docs/operators/emr/emr_serverless.rst b/providers/amazon/docs/operators/emr/emr_serverless.rst index 761713c1b08cb..b67ad8d75e9cd 100644 --- a/providers/amazon/docs/operators/emr/emr_serverless.rst +++ b/providers/amazon/docs/operators/emr/emr_serverless.rst @@ -151,6 +151,72 @@ To monitor the state of an EMR Serverless Application you can use :start-after: [START howto_sensor_emr_serverless_application] :end-before: [END howto_sensor_emr_serverless_application] +.. _howto/operator:EmrServerlessStartSessionOperator: + +Start an EMR Serverless interactive session +=========================================== + +To start an EMR Serverless interactive session that a Spark Connect client can attach to, use +:class:`~airflow.providers.amazon.aws.operators.emr.EmrServerlessStartSessionOperator`. +Set ``deferrable=True`` to release the worker slot while the session warms up. + +.. note:: + Interactive sessions require Amazon EMR release ``emr-7.13.0`` or later, and the session APIs + (``StartSession``, ``GetSession``, ``GetSessionEndpoint``, ``TerminateSession``) are only + available in ``botocore>=1.43.0``. Deferrable mode additionally needs ``aiobotocore>=3.6.0``, + the first release whose ``botocore`` pin allows 1.43.0. The Amazon provider keeps a lower + minimum for these libraries, so install compatible versions to use interactive sessions. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_emr_serverless_session.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_emr_serverless_start_session] + :end-before: [END howto_operator_emr_serverless_start_session] + +.. _howto/operator:EmrServerlessGetSessionEndpointOperator: + +Get an interactive session endpoint +=================================== + +To resolve the Spark Connect endpoint and a short-lived auth token for a running session, use +:class:`~airflow.providers.amazon.aws.operators.emr.EmrServerlessGetSessionEndpointOperator`. +The token expires (about one hour), so run this immediately before connecting. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_emr_serverless_session.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_emr_serverless_get_session_endpoint] + :end-before: [END howto_operator_emr_serverless_get_session_endpoint] + +.. _howto/operator:EmrServerlessStopSessionOperator: + +Stop an EMR Serverless interactive session +========================================== + +To terminate an interactive session, use +:class:`~airflow.providers.amazon.aws.operators.emr.EmrServerlessStopSessionOperator`. +Set ``trigger_rule=ALL_DONE`` so it runs even if a downstream task fails. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_emr_serverless_session.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_emr_serverless_stop_session] + :end-before: [END howto_operator_emr_serverless_stop_session] + +.. _howto/sensor:EmrServerlessSessionSensor: + +Wait on an EMR Serverless interactive session state +==================================================== + +To wait until an interactive session reaches a ready state (``STARTED`` or ``IDLE``), use +:class:`~airflow.providers.amazon.aws.sensors.emr.EmrServerlessSessionSensor`. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_emr_serverless_session.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_emr_serverless_session] + :end-before: [END howto_sensor_emr_serverless_session] + Reference --------- diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/emr.py index 87c931d134f40..65100d04660c7 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/emr.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/emr.py @@ -263,6 +263,10 @@ class EmrServerlessHook(AwsBaseHook): APPLICATION_FAILURE_STATES = {"STOPPED", "TERMINATED"} APPLICATION_SUCCESS_STATES = {"CREATED", "STARTED"} + SESSION_INTERMEDIATE_STATES = {"SUBMITTED", "STARTING"} + SESSION_FAILURE_STATES = {"FAILED", "TERMINATING", "TERMINATED"} + SESSION_SUCCESS_STATES = {"STARTED", "IDLE"} + def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs["client_type"] = "emr-serverless" super().__init__(*args, **kwargs) @@ -311,6 +315,56 @@ def cancel_running_jobs( return count + def start_session( + self, + application_id: str, + execution_role_arn: str, + name: str | None = None, + idle_timeout_minutes: int | None = None, + configuration_overrides: dict | None = None, + ) -> str: + """ + Start an EMR Serverless interactive session and return its id. + + :param application_id: The id of the EMR Serverless application to run the session on. + :param execution_role_arn: The IAM role ARN the session assumes to access data. + :param name: An optional name for the session. + :param idle_timeout_minutes: Auto-stop the session after this many idle minutes. + :param configuration_overrides: Optional Spark/monitoring configuration overrides. + """ + params: dict[str, Any] = { + "applicationId": application_id, + "executionRoleArn": execution_role_arn, + } + if name is not None: + params["name"] = name + if idle_timeout_minutes is not None: + params["idleTimeoutMinutes"] = idle_timeout_minutes + if configuration_overrides is not None: + params["configurationOverrides"] = configuration_overrides + return self.conn.start_session(**params)["sessionId"] + + def get_session_state(self, application_id: str, session_id: str) -> str: + """Return the current state of an interactive session.""" + return self.conn.get_session(applicationId=application_id, sessionId=session_id)["session"]["state"] + + def get_session_endpoint(self, application_id: str, session_id: str) -> dict: + """ + Return the raw ``GetSessionEndpoint`` boto3 response for a session. + + The response includes the Spark Connect ``endpoint`` URL, a short-lived ``authToken`` + (valid for about one hour), and its ``authTokenExpiresAt`` timestamp. Callers should + fetch it immediately before connecting rather than caching it. + + .. seealso:: + - :external+boto3:py:meth:`EMRServerless.Client.get_session_endpoint` + """ + return self.conn.get_session_endpoint(applicationId=application_id, sessionId=session_id) + + def terminate_session(self, application_id: str, session_id: str) -> None: + """Terminate an interactive session.""" + self.conn.terminate_session(applicationId=application_id, sessionId=session_id) + def is_connection_being_updated_exception(exception: BaseException) -> bool: return ( diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py index 14d91d613509b..801c2cddcd343 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py @@ -44,9 +44,11 @@ EmrServerlessCancelJobsTrigger, EmrServerlessCreateApplicationTrigger, EmrServerlessDeleteApplicationTrigger, + EmrServerlessSessionTrigger, EmrServerlessStartApplicationTrigger, EmrServerlessStartJobTrigger, EmrServerlessStopApplicationTrigger, + EmrServerlessStopSessionTrigger, EmrTerminateJobFlowTrigger, ) from airflow.providers.amazon.aws.utils import validate_execute_complete_event @@ -62,7 +64,7 @@ inject_parent_job_information_into_emr_serverless_properties, inject_transport_information_into_emr_serverless_properties, ) -from airflow.providers.common.compat.sdk import AirflowException, conf +from airflow.providers.common.compat.sdk import AirflowException, conf, mask_secret from airflow.utils.helpers import exactly_one, prune_dict if TYPE_CHECKING: @@ -1813,3 +1815,252 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None if validated_event["status"] != "success": raise AirflowException(f"Error deleting EMR Serverless application: {validated_event}") self.log.info("EMR serverless application %s deleted successfully", self.application_id) + + +class EmrServerlessStartSessionOperator(AwsBaseOperator[EmrServerlessHook]): + """ + Start an EMR Serverless interactive session and wait until it is ready. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EmrServerlessStartSessionOperator` + + :param application_id: ID of the EMR Serverless application to run the session on. + :param execution_role_arn: ARN of the IAM role the session assumes to access data. + :param name: An optional name for the session. + :param idle_timeout_minutes: Auto-stop the session after this many idle minutes. + :param configuration_overrides: Optional Spark/monitoring configuration overrides. + :param wait_for_completion: If True, wait for the session to be ready before returning. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param waiter_max_attempts: Number of times the waiter should poll the session to check the state. + :param waiter_delay: Number of seconds between polling the state of the session. + :param deferrable: If True, the operator will wait asynchronously for the session to be ready. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False, but can be overridden in config file by setting default_deferrable to True) + """ + + aws_hook_class = EmrServerlessHook + template_fields: Sequence[str] = aws_template_fields( + "application_id", + "execution_role_arn", + "name", + "idle_timeout_minutes", + "configuration_overrides", + ) + + def __init__( + self, + *, + application_id: str, + execution_role_arn: str, + name: str | None = None, + idle_timeout_minutes: int | None = None, + configuration_overrides: dict | None = None, + wait_for_completion: bool = True, + waiter_delay: int = 10, + waiter_max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.application_id = application_id + self.execution_role_arn = execution_role_arn + self.name = name + self.idle_timeout_minutes = idle_timeout_minutes + self.configuration_overrides = configuration_overrides + self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable + + def execute(self, context: Context) -> dict: + session_id = self.hook.start_session( + application_id=self.application_id, + execution_role_arn=self.execution_role_arn, + name=self.name, + idle_timeout_minutes=self.idle_timeout_minutes, + configuration_overrides=self.configuration_overrides, + ) + self.log.info("Started EMR Serverless session %s", session_id) + + if self.deferrable: + self.defer( + trigger=EmrServerlessSessionTrigger( + application_id=self.application_id, + session_id=session_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay), + method_name="execute_complete", + ) + + if self.wait_for_completion: + wait( + waiter=self.hook.get_waiter("serverless_session_ready"), + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + args={"applicationId": self.application_id, "sessionId": session_id}, + failure_message="EMR Serverless session failed to start", + status_message="EMR Serverless session status is", + status_args=["session.state", "session.stateDetails"], + ) + return {"application_id": self.application_id, "session_id": session_id} + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict: + validated_event = validate_execute_complete_event(event) + + if validated_event["status"] != "success": + raise RuntimeError(f"Error starting EMR Serverless session: {validated_event}") + self.log.info("EMR Serverless session %s started", validated_event["session_id"]) + return {"application_id": self.application_id, "session_id": validated_event["session_id"]} + + +class EmrServerlessGetSessionEndpointOperator(AwsBaseOperator[EmrServerlessHook]): + """ + Return a fresh Spark Connect endpoint and auth token for a running interactive session. + + The returned ``auth_token`` is short-lived (about one hour), so this operator should run + immediately before the task that connects to the session. The token is registered with + Airflow's secrets masker so it is redacted from task logs and the rendered XCom value. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EmrServerlessGetSessionEndpointOperator` + + :param application_id: ID of the EMR Serverless application. + :param session_id: ID of the interactive session. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + """ + + aws_hook_class = EmrServerlessHook + template_fields: Sequence[str] = aws_template_fields( + "application_id", + "session_id", + ) + + def __init__( + self, + *, + application_id: str, + session_id: str, + **kwargs, + ): + super().__init__(**kwargs) + self.application_id = application_id + self.session_id = session_id + + def execute(self, context: Context) -> dict: + response = self.hook.get_session_endpoint(self.application_id, self.session_id) + # The auth token is a short-lived credential. Register it with the secrets masker so + # it is redacted from task logs and the rendered XCom value in the UI. + if auth_token := response.get("authToken"): + mask_secret(auth_token) + self.log.info("Resolved Spark Connect endpoint for session %s", self.session_id) + return { + "endpoint": response.get("endpoint"), + "auth_token": auth_token, + "auth_token_expires_at": response.get("authTokenExpiresAt"), + } + + +class EmrServerlessStopSessionOperator(AwsBaseOperator[EmrServerlessHook]): + """ + Terminate an EMR Serverless interactive session. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:EmrServerlessStopSessionOperator` + + :param application_id: ID of the EMR Serverless application. + :param session_id: ID of the interactive session to terminate. + :param wait_for_completion: If True, wait for the session to terminate before returning. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param waiter_max_attempts: Number of times the waiter should poll the session to check the state. + :param waiter_delay: Number of seconds between polling the state of the session. + :param deferrable: If True, the operator will wait asynchronously for the session to terminate. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False, but can be overridden in config file by setting default_deferrable to True) + """ + + aws_hook_class = EmrServerlessHook + template_fields: Sequence[str] = aws_template_fields( + "application_id", + "session_id", + ) + + def __init__( + self, + *, + application_id: str, + session_id: str, + wait_for_completion: bool = True, + waiter_delay: int = 10, + waiter_max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.application_id = application_id + self.session_id = session_id + self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable + + def execute(self, context: Context) -> None: + self.log.info("Terminating EMR Serverless session %s", self.session_id) + self.hook.terminate_session(self.application_id, self.session_id) + + if self.deferrable: + self.defer( + trigger=EmrServerlessStopSessionTrigger( + application_id=self.application_id, + session_id=self.session_id, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay), + method_name="execute_complete", + ) + + if self.wait_for_completion: + wait( + waiter=self.hook.get_waiter("serverless_session_terminated"), + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + args={"applicationId": self.application_id, "sessionId": self.session_id}, + failure_message="EMR Serverless session failed to terminate", + status_message="EMR Serverless session status is", + status_args=["session.state", "session.stateDetails"], + ) + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: + validated_event = validate_execute_complete_event(event) + + if validated_event["status"] != "success": + raise RuntimeError(f"Error terminating EMR Serverless session: {validated_event}") + self.log.info("EMR Serverless session %s terminated", self.session_id) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/emr.py index f9a632893882e..cbf41eb41e77c 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/emr.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/emr.py @@ -26,6 +26,7 @@ from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor from airflow.providers.amazon.aws.triggers.emr import ( EmrContainerTrigger, + EmrServerlessSessionTrigger, EmrStepSensorTrigger, EmrTerminateJobFlowTrigger, ) @@ -664,3 +665,81 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None raise AirflowException(f"Error while running job: {validated_event}") self.log.info("Job %s completed.", self.job_flow_id) + + +class EmrServerlessSessionSensor(AwsBaseSensor[EmrServerlessHook]): + """ + Poll the state of an interactive session until it reaches a target state; fails if it fails. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:EmrServerlessSessionSensor` + + :param application_id: application_id of the session to check the state of + :param session_id: session_id to check the state of + :param target_states: a set of states to wait for, defaults to {'STARTED', 'IDLE'} + :param max_attempts: Maximum number of poll attempts when running in deferrable mode. + :param deferrable: Run sensor in the deferrable mode. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.h + """ + + aws_hook_class = EmrServerlessHook + template_fields: Sequence[str] = aws_template_fields( + "application_id", + "session_id", + ) + + def __init__( + self, + *, + application_id: str, + session_id: str, + target_states: set | frozenset = frozenset(EmrServerlessHook.SESSION_SUCCESS_STATES), + max_attempts: int = 60, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs: Any, + ) -> None: + self.target_states = target_states + self.application_id = application_id + self.session_id = session_id + self.max_attempts = max_attempts + self.deferrable = deferrable + super().__init__(**kwargs) + + def poke(self, context: Context) -> bool: + state = self.hook.get_session_state(self.application_id, self.session_id) + + if state in EmrServerlessHook.SESSION_FAILURE_STATES: + raise RuntimeError(f"EMR Serverless session entered failure state: {state}") + + return state in self.target_states + + def execute(self, context: Context) -> None: + if not self.deferrable: + super().execute(context=context) + elif not self.poke(context): + self.defer( + timeout=timedelta(seconds=self.max_attempts * self.poke_interval), + trigger=EmrServerlessSessionTrigger( + application_id=self.application_id, + session_id=self.session_id, + waiter_delay=int(self.poke_interval), + waiter_max_attempts=self.max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: + validated_event = validate_execute_complete_event(event) + + if validated_event["status"] != "success": + raise RuntimeError(f"Error while waiting for EMR Serverless session: {validated_event}") + self.log.info("EMR Serverless session %s reached a ready state", self.session_id) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py index 48ddb4b0c3197..bfdc0b67e6a92 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py @@ -623,3 +623,79 @@ def hook(self) -> AwsGenericHook: def hook_instance(self) -> AwsGenericHook: """This property is added for backward compatibility.""" return self.hook() + + +class EmrServerlessSessionTrigger(AwsBaseWaiterTrigger): + """ + Poll an EMR Serverless interactive session until it reaches a ready state. + + :param application_id: The ID of the EMR Serverless application. + :param session_id: The ID of the interactive session being polled. + :param waiter_delay: polling period in seconds to check for the status + :param waiter_max_attempts: The maximum number of attempts to be made + :param aws_conn_id: Reference to AWS connection id + """ + + def __init__( + self, + *, + application_id: str, + session_id: str, + waiter_delay: int = 10, + waiter_max_attempts: int = 60, + aws_conn_id: str | None = "aws_default", + ) -> None: + super().__init__( + serialized_fields={"application_id": application_id, "session_id": session_id}, + waiter_name="serverless_session_ready", + waiter_args={"applicationId": application_id, "sessionId": session_id}, + failure_message="EMR Serverless session failed to start", + status_message="EMR Serverless session status is", + status_queries=["session.state"], + return_key="session_id", + return_value=session_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + + def hook(self) -> AwsGenericHook: + return EmrServerlessHook(self.aws_conn_id) + + +class EmrServerlessStopSessionTrigger(AwsBaseWaiterTrigger): + """ + Poll an EMR Serverless interactive session until it is terminated. + + :param application_id: The ID of the EMR Serverless application. + :param session_id: The ID of the interactive session being polled. + :param waiter_delay: polling period in seconds to check for the status + :param waiter_max_attempts: The maximum number of attempts to be made + :param aws_conn_id: Reference to AWS connection id + """ + + def __init__( + self, + *, + application_id: str, + session_id: str, + waiter_delay: int = 10, + waiter_max_attempts: int = 60, + aws_conn_id: str | None = "aws_default", + ) -> None: + super().__init__( + serialized_fields={"application_id": application_id, "session_id": session_id}, + waiter_name="serverless_session_terminated", + waiter_args={"applicationId": application_id, "sessionId": session_id}, + failure_message="EMR Serverless session failed to terminate", + status_message="EMR Serverless session status is", + status_queries=["session.state"], + return_key="session_id", + return_value=session_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + + def hook(self) -> AwsGenericHook: + return EmrServerlessHook(self.aws_conn_id) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/waiters/emr-serverless.json b/providers/amazon/src/airflow/providers/amazon/aws/waiters/emr-serverless.json index 4066109382a6a..64ef8306519fd 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/waiters/emr-serverless.json +++ b/providers/amazon/src/airflow/providers/amazon/aws/waiters/emr-serverless.json @@ -152,6 +152,62 @@ "state": "success" } ] + }, + "serverless_session_ready": { + "operation": "GetSession", + "delay": 10, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "path", + "argument": "session.state", + "expected": "STARTED", + "state": "success" + }, + { + "matcher": "path", + "argument": "session.state", + "expected": "IDLE", + "state": "success" + }, + { + "matcher": "path", + "argument": "session.state", + "expected": "FAILED", + "state": "failure" + }, + { + "matcher": "path", + "argument": "session.state", + "expected": "TERMINATING", + "state": "failure" + }, + { + "matcher": "path", + "argument": "session.state", + "expected": "TERMINATED", + "state": "failure" + } + ] + }, + "serverless_session_terminated": { + "operation": "GetSession", + "delay": 10, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "path", + "argument": "session.state", + "expected": "TERMINATED", + "state": "success" + }, + { + "matcher": "path", + "argument": "session.state", + "expected": "FAILED", + "state": "failure" + } + ] } } } diff --git a/providers/amazon/tests/system/amazon/aws/example_emr_serverless_session.py b/providers/amazon/tests/system/amazon/aws/example_emr_serverless_session.py new file mode 100644 index 0000000000000..7a8de9ab9e724 --- /dev/null +++ b/providers/amazon/tests/system/amazon/aws/example_emr_serverless_session.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +from airflow.providers.amazon.aws.operators.emr import ( + EmrServerlessCreateApplicationOperator, + EmrServerlessDeleteApplicationOperator, + EmrServerlessGetSessionEndpointOperator, + EmrServerlessStartSessionOperator, + EmrServerlessStopSessionOperator, +) +from airflow.providers.amazon.aws.sensors.emr import EmrServerlessSessionSensor +from airflow.providers.common.compat.sdk import DAG, TriggerRule, chain + +from system.amazon.aws.utils import SystemTestContextBuilder + +DAG_ID = "example_emr_serverless_session" + +# Externally fetched variables: +ROLE_ARN_KEY = "ROLE_ARN" + +sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build() + +with DAG( + dag_id=DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example", "emr-serverless"], +) as dag: + test_context = sys_test_context_task() + role_arn = test_context[ROLE_ARN_KEY] + + create_app = EmrServerlessCreateApplicationOperator( + task_id="create_app", + release_label="emr-7.13.0", + job_type="SPARK", + config={ + "name": "session-systest", + # Interactive sessions (required for the StartSession API) are only available + # on emr-7.13.0+ and must be explicitly enabled on the application. + "interactiveConfiguration": {"sessionEnabled": True}, + }, + ) + application_id = create_app.output + + # [START howto_operator_emr_serverless_start_session] + start_session = EmrServerlessStartSessionOperator( + task_id="start_session", + application_id=application_id, + execution_role_arn=role_arn, + idle_timeout_minutes=5, + ) + # [END howto_operator_emr_serverless_start_session] + session_id = start_session.output["session_id"] + + # [START howto_sensor_emr_serverless_session] + wait_for_session = EmrServerlessSessionSensor( + task_id="wait_for_session", + application_id=application_id, + session_id=session_id, + ) + # [END howto_sensor_emr_serverless_session] + + # [START howto_operator_emr_serverless_get_session_endpoint] + get_endpoint = EmrServerlessGetSessionEndpointOperator( + task_id="get_endpoint", + application_id=application_id, + session_id=session_id, + ) + # [END howto_operator_emr_serverless_get_session_endpoint] + + # [START howto_operator_emr_serverless_stop_session] + stop_session = EmrServerlessStopSessionOperator( + task_id="stop_session", + application_id=application_id, + session_id=session_id, + trigger_rule=TriggerRule.ALL_DONE, + ) + # [END howto_operator_emr_serverless_stop_session] + + delete_app = EmrServerlessDeleteApplicationOperator( + task_id="delete_app", + application_id=application_id, + trigger_rule=TriggerRule.ALL_DONE, + ) + + chain( + # TEST SETUP + test_context, + create_app, + # TEST BODY + start_session, + wait_for_session, + get_endpoint, + # TEST TEARDOWN + stop_session, + delete_app, + ) + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: contributing-docs/testing/system_tests.rst) +test_run = get_test_run(dag) diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_emr_serverless.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_emr_serverless.py index 75e4c5f897425..5502e7c297c67 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_emr_serverless.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_emr_serverless.py @@ -75,3 +75,70 @@ def test_cancel_jobs_but_no_jobs(self, conn_mock: MagicMock): # nothing very interesting should happen conn_mock.assert_called_once() + + +class TestEmrServerlessHookSession: + @patch.object(EmrServerlessHook, "conn", new_callable=PropertyMock) + def test_start_session_minimal(self, conn_mock: MagicMock): + conn_mock().start_session.return_value = {"sessionId": "sess-1"} + hook = EmrServerlessHook(aws_conn_id="aws_default") + + session_id = hook.start_session(application_id="app", execution_role_arn="role") + + assert session_id == "sess-1" + conn_mock().start_session.assert_called_once_with(applicationId="app", executionRoleArn="role") + + @patch.object(EmrServerlessHook, "conn", new_callable=PropertyMock) + def test_start_session_with_optional_params(self, conn_mock: MagicMock): + conn_mock().start_session.return_value = {"sessionId": "sess-2"} + hook = EmrServerlessHook(aws_conn_id="aws_default") + + session_id = hook.start_session( + application_id="app", + execution_role_arn="role", + name="my-session", + idle_timeout_minutes=15, + configuration_overrides={"applicationConfiguration": []}, + ) + + assert session_id == "sess-2" + conn_mock().start_session.assert_called_once_with( + applicationId="app", + executionRoleArn="role", + name="my-session", + idleTimeoutMinutes=15, + configurationOverrides={"applicationConfiguration": []}, + ) + + @patch.object(EmrServerlessHook, "conn", new_callable=PropertyMock) + def test_get_session_state(self, conn_mock: MagicMock): + conn_mock().get_session.return_value = {"session": {"state": "STARTED"}} + hook = EmrServerlessHook(aws_conn_id="aws_default") + + assert hook.get_session_state("app", "sess-1") == "STARTED" + conn_mock().get_session.assert_called_once_with(applicationId="app", sessionId="sess-1") + + @patch.object(EmrServerlessHook, "conn", new_callable=PropertyMock) + def test_get_session_endpoint_returns_raw_response(self, conn_mock: MagicMock): + raw = { + "applicationId": "app", + "sessionId": "sess-1", + "endpoint": "https://sess-1.example.amazonaws.com", + "authToken": "secret-token", + "authTokenExpiresAt": "2026-01-01T00:00:00Z", + } + conn_mock().get_session_endpoint.return_value = raw + hook = EmrServerlessHook(aws_conn_id="aws_default") + + result = hook.get_session_endpoint("app", "sess-1") + + assert result is raw + conn_mock().get_session_endpoint.assert_called_once_with(applicationId="app", sessionId="sess-1") + + @patch.object(EmrServerlessHook, "conn", new_callable=PropertyMock) + def test_terminate_session(self, conn_mock: MagicMock): + hook = EmrServerlessHook(aws_conn_id="aws_default") + + hook.terminate_session("app", "sess-1") + + conn_mock().terminate_session.assert_called_once_with(applicationId="app", sessionId="sess-1") diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_serverless_session.py b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_serverless_session.py new file mode 100644 index 0000000000000..a8d2280dc1c13 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_serverless_session.py @@ -0,0 +1,194 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.exceptions import TaskDeferred +from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook +from airflow.providers.amazon.aws.operators.emr import ( + EmrServerlessGetSessionEndpointOperator, + EmrServerlessStartSessionOperator, + EmrServerlessStopSessionOperator, +) +from airflow.providers.amazon.aws.triggers.emr import ( + EmrServerlessSessionTrigger, + EmrServerlessStopSessionTrigger, +) + +APP_ID = "app-123" +SESSION_ID = "sess-abc" +ROLE = "arn:aws:iam::111122223333:role/emr-exec" +WAIT = "airflow.providers.amazon.aws.operators.emr.wait" + + +class TestEmrServerlessStartSessionOperator: + @mock.patch(WAIT) + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "start_session") + def test_start_and_wait(self, start_session, get_waiter, wait_mock): + start_session.return_value = SESSION_ID + op = EmrServerlessStartSessionOperator( + task_id="start", + application_id=APP_ID, + execution_role_arn=ROLE, + idle_timeout_minutes=15, + ) + result = op.execute({}) + + start_session.assert_called_once_with( + application_id=APP_ID, + execution_role_arn=ROLE, + name=None, + idle_timeout_minutes=15, + configuration_overrides=None, + ) + wait_mock.assert_called_once() + get_waiter.assert_called_once_with("serverless_session_ready") + assert result == {"application_id": APP_ID, "session_id": SESSION_ID} + + @mock.patch(WAIT) + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "start_session") + def test_no_wait(self, start_session, get_waiter, wait_mock): + start_session.return_value = SESSION_ID + op = EmrServerlessStartSessionOperator( + task_id="start", + application_id=APP_ID, + execution_role_arn=ROLE, + wait_for_completion=False, + ) + op.execute({}) + wait_mock.assert_not_called() + + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "start_session") + def test_deferrable_defers(self, start_session, get_waiter): + start_session.return_value = SESSION_ID + op = EmrServerlessStartSessionOperator( + task_id="start", + application_id=APP_ID, + execution_role_arn=ROLE, + deferrable=True, + ) + with pytest.raises(TaskDeferred) as deferred: + op.execute({}) + + assert isinstance(deferred.value.trigger, EmrServerlessSessionTrigger) + get_waiter.assert_not_called() + + def test_execute_complete_success(self): + op = EmrServerlessStartSessionOperator( + task_id="start", application_id=APP_ID, execution_role_arn=ROLE + ) + result = op.execute_complete({}, {"status": "success", "session_id": SESSION_ID}) + assert result == {"application_id": APP_ID, "session_id": SESSION_ID} + + def test_execute_complete_failure_raises(self): + op = EmrServerlessStartSessionOperator( + task_id="start", application_id=APP_ID, execution_role_arn=ROLE + ) + with pytest.raises(RuntimeError): + op.execute_complete({}, {"status": "failure", "session_id": SESSION_ID}) + + +class TestEmrServerlessGetSessionEndpointOperator: + RAW_RESPONSE = { + "applicationId": APP_ID, + "sessionId": SESSION_ID, + "endpoint": "https://sess-abc.s.emr-serverless-services.us-east-1.amazonaws.com", + "authToken": "tok", + "authTokenExpiresAt": None, + } + + @mock.patch.object(EmrServerlessHook, "get_session_endpoint") + def test_returns_transformed_endpoint(self, get_session_endpoint): + get_session_endpoint.return_value = self.RAW_RESPONSE + op = EmrServerlessGetSessionEndpointOperator( + task_id="ep", application_id=APP_ID, session_id=SESSION_ID + ) + out = op.execute({}) + get_session_endpoint.assert_called_once_with(APP_ID, SESSION_ID) + assert out == { + "endpoint": "https://sess-abc.s.emr-serverless-services.us-east-1.amazonaws.com", + "auth_token": "tok", + "auth_token_expires_at": None, + } + + @mock.patch("airflow.providers.amazon.aws.operators.emr.mask_secret") + @mock.patch.object(EmrServerlessHook, "get_session_endpoint") + def test_auth_token_is_masked(self, get_session_endpoint, mask_secret_mock): + get_session_endpoint.return_value = self.RAW_RESPONSE + op = EmrServerlessGetSessionEndpointOperator( + task_id="ep", application_id=APP_ID, session_id=SESSION_ID + ) + op.execute({}) + mask_secret_mock.assert_called_once_with("tok") + + +class TestEmrServerlessStopSessionOperator: + @mock.patch(WAIT) + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "terminate_session") + def test_terminate_and_wait(self, terminate_session, get_waiter, wait_mock): + op = EmrServerlessStopSessionOperator(task_id="stop", application_id=APP_ID, session_id=SESSION_ID) + op.execute({}) + + terminate_session.assert_called_once_with(APP_ID, SESSION_ID) + wait_mock.assert_called_once() + get_waiter.assert_called_once_with("serverless_session_terminated") + + @mock.patch(WAIT) + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "terminate_session") + def test_terminate_no_wait(self, terminate_session, get_waiter, wait_mock): + op = EmrServerlessStopSessionOperator( + task_id="stop", + application_id=APP_ID, + session_id=SESSION_ID, + wait_for_completion=False, + ) + op.execute({}) + wait_mock.assert_not_called() + + @mock.patch(WAIT) + @mock.patch.object(EmrServerlessHook, "get_waiter") + @mock.patch.object(EmrServerlessHook, "terminate_session") + def test_deferrable_defers(self, terminate_session, get_waiter, wait_mock): + op = EmrServerlessStopSessionOperator( + task_id="stop", + application_id=APP_ID, + session_id=SESSION_ID, + deferrable=True, + ) + with pytest.raises(TaskDeferred) as deferred: + op.execute({}) + + terminate_session.assert_called_once_with(APP_ID, SESSION_ID) + assert isinstance(deferred.value.trigger, EmrServerlessStopSessionTrigger) + wait_mock.assert_not_called() + + def test_execute_complete_success(self): + op = EmrServerlessStopSessionOperator(task_id="stop", application_id=APP_ID, session_id=SESSION_ID) + assert op.execute_complete({}, {"status": "success", "session_id": SESSION_ID}) is None + + def test_execute_complete_failure_raises(self): + op = EmrServerlessStopSessionOperator(task_id="stop", application_id=APP_ID, session_id=SESSION_ID) + with pytest.raises(RuntimeError): + op.execute_complete({}, {"status": "failure", "session_id": SESSION_ID}) diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_emr_serverless_session.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_emr_serverless_session.py new file mode 100644 index 0000000000000..442b736cce3b7 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_emr_serverless_session.py @@ -0,0 +1,88 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from airflow.exceptions import TaskDeferred +from airflow.providers.amazon.aws.sensors.emr import EmrServerlessSessionSensor +from airflow.providers.amazon.aws.triggers.emr import EmrServerlessSessionTrigger + + +class TestEmrServerlessSessionSensor: + def setup_method(self): + self.app_id = "app-123" + self.session_id = "sess-abc" + self.sensor = EmrServerlessSessionSensor( + task_id="test_emr_serverless_session_sensor", + application_id=self.app_id, + session_id=self.session_id, + aws_conn_id="aws_default", + ) + + def set_session_state(self, state: str): + self.mock_hook = MagicMock() + self.mock_hook.get_session_state.return_value = state + self.sensor.hook = self.mock_hook + + def assert_get_session_state_called_once(self): + self.mock_hook.get_session_state.assert_called_once_with(self.app_id, self.session_id) + + @pytest.mark.parametrize( + ("state", "expected_result"), + [ + ("SUBMITTED", False), + ("STARTING", False), + ("STARTED", True), + ("IDLE", True), + ], + ) + def test_poke_returns_expected_result_for_states(self, state, expected_result): + self.set_session_state(state) + assert self.sensor.poke(None) == expected_result + self.assert_get_session_state_called_once() + + @pytest.mark.parametrize("state", ["FAILED", "TERMINATING", "TERMINATED"]) + def test_poke_raises_for_failure_states(self, state): + self.set_session_state(state) + with pytest.raises(RuntimeError, match=f"failure state: {state}"): + self.sensor.poke(None) + self.assert_get_session_state_called_once() + + def test_deferrable_defers(self): + sensor = EmrServerlessSessionSensor( + task_id="deferred_session_sensor", + application_id=self.app_id, + session_id=self.session_id, + aws_conn_id="aws_default", + deferrable=True, + ) + mock_hook = MagicMock() + mock_hook.get_session_state.return_value = "STARTING" + sensor.hook = mock_hook + + with pytest.raises(TaskDeferred) as deferred: + sensor.execute(None) + + assert isinstance(deferred.value.trigger, EmrServerlessSessionTrigger) + + def test_execute_complete_failure_raises(self): + with pytest.raises(RuntimeError): + self.sensor.execute_complete({}, {"status": "failure"}) diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py index 2e8ea174c1b5e..d974c2dbd031c 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py @@ -30,9 +30,11 @@ EmrServerlessCancelJobsTrigger, EmrServerlessCreateApplicationTrigger, EmrServerlessDeleteApplicationTrigger, + EmrServerlessSessionTrigger, EmrServerlessStartApplicationTrigger, EmrServerlessStartJobTrigger, EmrServerlessStopApplicationTrigger, + EmrServerlessStopSessionTrigger, EmrStepSensorTrigger, EmrTerminateJobFlowTrigger, ) @@ -572,3 +574,52 @@ def test_serialization(self): "waiter_max_attempts": 60, "aws_conn_id": "aws_default", } + + +class TestEmrServerlessSessionTrigger: + def test_serialization(self): + trigger = EmrServerlessSessionTrigger( + application_id="test_application_id", + session_id="test_session_id", + waiter_delay=10, + waiter_max_attempts=60, + aws_conn_id="aws_default", + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.emr.EmrServerlessSessionTrigger" + assert kwargs == { + "application_id": "test_application_id", + "session_id": "test_session_id", + "waiter_delay": 10, + "waiter_max_attempts": 60, + "aws_conn_id": "aws_default", + } + + def test_hook_returns_serverless_hook(self): + from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook + + trigger = EmrServerlessSessionTrigger( + application_id="test_application_id", + session_id="test_session_id", + ) + assert isinstance(trigger.hook(), EmrServerlessHook) + + +class TestEmrServerlessStopSessionTrigger: + def test_serialization(self): + trigger = EmrServerlessStopSessionTrigger( + application_id="test_application_id", + session_id="test_session_id", + waiter_delay=10, + waiter_max_attempts=60, + aws_conn_id="aws_default", + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.emr.EmrServerlessStopSessionTrigger" + assert kwargs == { + "application_id": "test_application_id", + "session_id": "test_session_id", + "waiter_delay": 10, + "waiter_max_attempts": 60, + "aws_conn_id": "aws_default", + } diff --git a/providers/common/compat/src/airflow/providers/common/compat/sdk.py b/providers/common/compat/src/airflow/providers/common/compat/sdk.py index 93174df7b2a28..9d646e218ae2c 100644 --- a/providers/common/compat/src/airflow/providers/common/compat/sdk.py +++ b/providers/common/compat/src/airflow/providers/common/compat/sdk.py @@ -298,6 +298,11 @@ "airflow.sdk.execution_time.secrets_masker", "airflow.utils.log.secrets_masker", ), + "mask_secret": ( + "airflow.sdk._shared.secrets_masker", + "airflow.sdk.execution_time.secrets_masker", + "airflow.utils.log.secrets_masker", + ), # ============================================================================ # Listeners # ============================================================================