19 changes: 17 additions & 2 deletions sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py
@@ -9,10 +9,11 @@

import logging
import re
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from pydantic import BaseModel, validator

from sagemaker.core.common_utils import TagsDict
from sagemaker.core.resources import ModelPackageGroup, ModelPackage
from sagemaker.core.shapes import VpcConfig

@@ -413,6 +414,13 @@ def _source_model_package_arn(self) -> Optional[str]:
"""Get the resolved source model package ARN (None for JumpStart models)."""
info = self._get_resolved_model_info()
return info.source_model_package_arn if info else None

@property
def _is_jumpstart_model(self) -> bool:
"""Determine if model is a JumpStart model"""
from sagemaker.train.common_utils.model_resolution import _ModelType
info = self._get_resolved_model_info()
return info.model_type == _ModelType.JUMPSTART

def _infer_model_package_group_arn(self) -> Optional[str]:
"""Infer model package group ARN from source model package ARN.
@@ -797,6 +805,12 @@ def _start_execution(
EvaluationPipelineExecution: Started execution object
"""
from .execution import EvaluationPipelineExecution

tags: List[TagsDict] = []

if self._is_jumpstart_model:
from sagemaker.core.jumpstart.utils import add_jumpstart_model_info_tags
tags = add_jumpstart_model_info_tags(tags, self.model, "*")

execution = EvaluationPipelineExecution.start(
eval_type=eval_type,
@@ -805,7 +819,8 @@
role_arn=role_arn,
s3_output_path=self.s3_output_path,
session=self.sagemaker_session.boto_session if hasattr(self.sagemaker_session, 'boto_session') else None,
region=region
region=region,
tags=tags
)

return execution
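Taken together, the base_evaluator.py changes make `_start_execution` collect JumpStart model-info tags before starting the pipeline execution. Below is a condensed, self-contained sketch of that tag-collection step; the `TagsDict` alias and the `build_execution_tags` helper are illustrative stand-ins, and the real code delegates to `add_jumpstart_model_info_tags(tags, self.model, "*")` from `sagemaker.core.jumpstart.utils`.

```python
from typing import Dict, List

# Stand-in for sagemaker.core.common_utils.TagsDict: a {"key": ..., "value": ...} mapping.
TagsDict = Dict[str, str]


def build_execution_tags(is_jumpstart_model: bool, model_id: str) -> List[TagsDict]:
    """Mirror of the new tag-collection step in _start_execution (illustrative only)."""
    tags: List[TagsDict] = []
    if is_jumpstart_model:
        # The PR delegates this to add_jumpstart_model_info_tags(tags, model, "*");
        # here we append an equivalent entry by hand, using the key seen in the unit test.
        tags.append({"key": "sagemaker-sdk:jumpstart-model-id", "value": model_id})
    return tags


print(build_execution_tags(True, "dummy-js-model-id"))
# [{'key': 'sagemaker-sdk:jumpstart-model-id', 'value': 'dummy-js-model-id'}]
```

The non-JumpStart path still builds an empty list, so `EvaluationPipelineExecution.start(..., tags=tags)` always receives a list rather than `None`.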
24 changes: 16 additions & 8 deletions sagemaker-train/src/sagemaker/train/evaluate/execution.py
@@ -16,6 +16,7 @@
# Third-party imports
from botocore.exceptions import ClientError
from pydantic import BaseModel, Field
from sagemaker.core.common_utils import TagsDict
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.resources import Pipeline, PipelineExecution, Tag
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
@@ -38,6 +39,7 @@ def _create_evaluation_pipeline(
pipeline_definition: str,
session: Optional[Any] = None,
region: Optional[str] = None,
tags: Optional[List[TagsDict]] = [],
) -> Any:
"""Helper method to create a SageMaker pipeline for evaluation.

@@ -49,6 +51,7 @@
pipeline_definition (str): JSON pipeline definition (Jinja2 template).
session (Optional[Any]): SageMaker session object.
region (Optional[str]): AWS region.
tags (Optional[List[TagsDict]]): List of tags to include in pipeline

Returns:
Any: Created Pipeline instance (ready for execution).
@@ -65,9 +68,9 @@
resolved_pipeline_definition = template.render(pipeline_name=pipeline_name)

# Create tags for the pipeline
tags = [
tags.extend([
{"key": _TAG_SAGEMAKER_MODEL_EVALUATION, "value": "true"}
]
])

pipeline = Pipeline.create(
pipeline_name=pipeline_name,
@@ -163,7 +166,8 @@ def _get_or_create_pipeline(
pipeline_definition: str,
role_arn: str,
session: Optional[Session] = None,
region: Optional[str] = None
region: Optional[str] = None,
create_tags: Optional[List[TagsDict]] = [],
) -> Pipeline:
"""Get existing pipeline or create/update it.

@@ -177,6 +181,7 @@
role_arn: IAM role ARN for pipeline execution
session: Boto3 session (optional)
region: AWS region (optional)
create_tags (Optional[List[TagsDict]]): List of tags to include in pipeline

Returns:
Pipeline instance (existing updated or newly created)
@@ -225,19 +230,19 @@

# No matching pipeline found, create new one
logger.info(f"No existing pipeline found with prefix {pipeline_name_prefix}, creating new one")
return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region)
return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region, create_tags)

except ClientError as e:
error_code = e.response['Error']['Code']
if "ResourceNotFound" in error_code:
return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region)
return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region, create_tags)
else:
raise

except Exception as e:
# If search fails for other reasons, try to create
logger.info(f"Error searching for pipeline ({str(e)}), attempting to create new pipeline")
return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region)
return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region, create_tags)


def _start_pipeline_execution(
@@ -505,7 +510,8 @@ def start(
role_arn: str,
s3_output_path: Optional[str] = None,
session: Optional[Session] = None,
region: Optional[str] = None
region: Optional[str] = None,
tags: Optional[List[TagsDict]] = [],
) -> 'EvaluationPipelineExecution':
"""Create sagemaker pipeline execution. Optionally creates pipeline.

@@ -517,6 +523,7 @@
s3_output_path (Optional[str]): S3 location where evaluation results are stored.
session (Optional[Session]): Boto3 session for API calls.
region (Optional[str]): AWS region for the pipeline.
tags (Optional[List[TagsDict]]): List of tags to include in pipeline

Returns:
EvaluationPipelineExecution: Started pipeline execution instance.
@@ -547,7 +554,8 @@
pipeline_definition=pipeline_definition,
role_arn=role_arn,
session=session,
region=region
region=region,
create_tags=tags,
)

# Start pipeline execution via boto3
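The execution.py changes thread the new tags from `start()` through `_get_or_create_pipeline()` (as `create_tags`) into `_create_evaluation_pipeline()`, which appends the always-present model-evaluation marker before calling `Pipeline.create`. A minimal sketch of that merge is below; the marker tag's value `"true"` comes from the diff, while the key string and the `None`-default handling (used here instead of the mutable `= []` default so one shared list is not extended across calls) are assumptions of the sketch, not the library's exact code.

```python
from typing import Dict, List, Optional

TagsDict = Dict[str, str]

# Key assumed for illustration; the real constant is _TAG_SAGEMAKER_MODEL_EVALUATION in execution.py.
_TAG_SAGEMAKER_MODEL_EVALUATION = "sagemaker-model-evaluation"


def merge_pipeline_tags(caller_tags: Optional[List[TagsDict]] = None) -> List[TagsDict]:
    """Combine caller-supplied tags with the evaluation marker tag (sketch only).

    A None default plus a copy of the input avoids the pitfall of a mutable
    default argument being extended in place on every call.
    """
    tags: List[TagsDict] = list(caller_tags or [])
    tags.append({"key": _TAG_SAGEMAKER_MODEL_EVALUATION, "value": "true"})
    return tags


js_tags = [{"key": "sagemaker-sdk:jumpstart-model-id", "value": "dummy-js-model-id"}]
print(merge_pipeline_tags(js_tags))  # JumpStart tag plus the evaluation marker
print(merge_pipeline_tags())         # evaluation marker only
```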
32 changes: 31 additions & 1 deletion sagemaker-train/tests/unit/train/evaluate/test_execution.py
@@ -299,7 +299,37 @@ def test_create_pipeline_when_not_found(self, mock_pipeline_class, mock_create,
DEFAULT_ROLE,
DEFAULT_PIPELINE_DEFINITION,
mock_session,
DEFAULT_REGION
DEFAULT_REGION,
[]
)
assert result == mock_pipeline

@patch("sagemaker.train.evaluate.execution._create_evaluation_pipeline")
@patch("sagemaker.train.evaluate.execution.Pipeline")
def test_create_pipeline_when_not_found_with_jumpstart_tags(self, mock_pipeline_class, mock_create, mock_session):
"""Test creating pipeline when it doesn't exist."""
error_response = {"Error": {"Code": "ResourceNotFound"}}
mock_pipeline_class.get.side_effect = ClientError(error_response, "DescribePipeline")
mock_pipeline = MagicMock()
mock_create.return_value = mock_pipeline
create_tags = [{"key": "sagemaker-sdk:jumpstart-model-id", "value": "dummy-js-model-id"}]

result = _get_or_create_pipeline(
eval_type=EvalType.BENCHMARK,
pipeline_definition=DEFAULT_PIPELINE_DEFINITION,
role_arn=DEFAULT_ROLE,
session=mock_session,
region=DEFAULT_REGION,
create_tags=create_tags
)

mock_create.assert_called_once_with(
EvalType.BENCHMARK,
DEFAULT_ROLE,
DEFAULT_PIPELINE_DEFINITION,
mock_session,
DEFAULT_REGION,
create_tags
)
assert result == mock_pipeline

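The new unit test verifies that `create_tags` reaches the (mocked) `_create_evaluation_pipeline` unchanged. The same forwarding contract can be exercised in isolation with a plain `MagicMock`; the `start_with_tags` shim below is a hypothetical stand-in for the `start()` → `_get_or_create_pipeline()` hand-off, not SageMaker code.

```python
from typing import Dict, List
from unittest.mock import MagicMock

TagsDict = Dict[str, str]


def start_with_tags(get_or_create_pipeline, tags: List[TagsDict]) -> None:
    """Hypothetical shim: pass caller tags through as create_tags, as start() does."""
    get_or_create_pipeline(create_tags=tags)


def test_tags_are_forwarded_unchanged():
    get_or_create_pipeline = MagicMock()
    js_tags = [{"key": "sagemaker-sdk:jumpstart-model-id", "value": "dummy-js-model-id"}]
    start_with_tags(get_or_create_pipeline, js_tags)
    get_or_create_pipeline.assert_called_once_with(create_tags=js_tags)


test_tags_are_forwarded_unchanged()
```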