
Commit 0bec208

Narrohag and aviruthen authored
feat: add evaluator tagging for jumpstart models (#5413)
* add evaluator tagging for jumpstart models
* fix bug for extending tags
* bug fix for js tags
* add unit test for js evaluator tagging

Co-authored-by: aviruthen <[email protected]>
1 parent d089d40 commit 0bec208

File tree

3 files changed (+64 -11 lines changed)


sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py

Lines changed: 17 additions & 2 deletions
@@ -9,10 +9,11 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from pydantic import BaseModel, validator
 
+from sagemaker.core.common_utils import TagsDict
 from sagemaker.core.resources import ModelPackageGroup, ModelPackage
 from sagemaker.core.shapes import VpcConfig
 
@@ -413,6 +414,13 @@ def _source_model_package_arn(self) -> Optional[str]:
         """Get the resolved source model package ARN (None for JumpStart models)."""
         info = self._get_resolved_model_info()
         return info.source_model_package_arn if info else None
+
+    @property
+    def _is_jumpstart_model(self) -> bool:
+        """Determine if the model is a JumpStart model."""
+        from sagemaker.train.common_utils.model_resolution import _ModelType
+        info = self._get_resolved_model_info()
+        return info.model_type == _ModelType.JUMPSTART
 
     def _infer_model_package_group_arn(self) -> Optional[str]:
         """Infer model package group ARN from source model package ARN.
@@ -797,6 +805,12 @@ def _start_execution(
             EvaluationPipelineExecution: Started execution object
         """
         from .execution import EvaluationPipelineExecution
+
+        tags: List[TagsDict] = []
+
+        if self._is_jumpstart_model:
+            from sagemaker.core.jumpstart.utils import add_jumpstart_model_info_tags
+            tags = add_jumpstart_model_info_tags(tags, self.model, "*")
 
         execution = EvaluationPipelineExecution.start(
             eval_type=eval_type,
@@ -805,7 +819,8 @@ def _start_execution(
             role_arn=role_arn,
             s3_output_path=self.s3_output_path,
             session=self.sagemaker_session.boto_session if hasattr(self.sagemaker_session, 'boto_session') else None,
-            region=region
+            region=region,
+            tags=tags
         )
 
         return execution
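
A note on the base_evaluator.py change: _start_execution now seeds an empty tag list and, only when _is_jumpstart_model is true, routes it through add_jumpstart_model_info_tags(tags, self.model, "*"). The sketch below is a hypothetical stand-in for that helper, inferred from the sagemaker-sdk:jumpstart-model-id tag the new unit test expects; it is not the real implementation in sagemaker.core.jumpstart.utils.

# Hypothetical stand-in for add_jumpstart_model_info_tags, inferred from
# this diff and its unit test; the real helper may attach more tags.
from typing import Dict, List

TagsDict = Dict[str, str]  # stand-in for sagemaker.core.common_utils.TagsDict

def add_jumpstart_model_info_tags_stub(
    tags: List[TagsDict], model_id: str, model_version: str
) -> List[TagsDict]:
    # Returns a new list so the caller rebinds `tags`, matching the
    # `tags = add_jumpstart_model_info_tags(tags, ...)` call in the diff.
    return tags + [{"key": "sagemaker-sdk:jumpstart-model-id", "value": model_id}]

print(add_jumpstart_model_info_tags_stub([], "dummy-js-model-id", "*"))
# [{'key': 'sagemaker-sdk:jumpstart-model-id', 'value': 'dummy-js-model-id'}]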

sagemaker-train/src/sagemaker/train/evaluate/execution.py

Lines changed: 16 additions & 8 deletions
@@ -16,6 +16,7 @@
 # Third-party imports
 from botocore.exceptions import ClientError
 from pydantic import BaseModel, Field
+from sagemaker.core.common_utils import TagsDict
 from sagemaker.core.helper.session_helper import Session
 from sagemaker.core.resources import Pipeline, PipelineExecution, Tag
 from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
@@ -38,6 +39,7 @@ def _create_evaluation_pipeline(
     pipeline_definition: str,
     session: Optional[Any] = None,
     region: Optional[str] = None,
+    tags: Optional[List[TagsDict]] = [],
 ) -> Any:
     """Helper method to create a SageMaker pipeline for evaluation.
 
@@ -49,6 +51,7 @@
         pipeline_definition (str): JSON pipeline definition (Jinja2 template).
         session (Optional[Any]): SageMaker session object.
         region (Optional[str]): AWS region.
+        tags (Optional[List[TagsDict]]): List of tags to include in pipeline
 
     Returns:
         Any: Created Pipeline instance (ready for execution).
@@ -65,9 +68,9 @@
     resolved_pipeline_definition = template.render(pipeline_name=pipeline_name)
 
     # Create tags for the pipeline
-    tags = [
+    tags.extend([
         {"key": _TAG_SAGEMAKER_MODEL_EVALUATION, "value": "true"}
-    ]
+    ])
 
     pipeline = Pipeline.create(
         pipeline_name=pipeline_name,
@@ -163,7 +166,8 @@ def _get_or_create_pipeline(
     pipeline_definition: str,
     role_arn: str,
     session: Optional[Session] = None,
-    region: Optional[str] = None
+    region: Optional[str] = None,
+    create_tags: Optional[List[TagsDict]] = [],
 ) -> Pipeline:
     """Get existing pipeline or create/update it.
 
@@ -177,6 +181,7 @@
         role_arn: IAM role ARN for pipeline execution
         session: Boto3 session (optional)
         region: AWS region (optional)
+        create_tags (Optional[List[TagsDict]]): List of tags to include in pipeline
 
     Returns:
         Pipeline instance (existing updated or newly created)
@@ -225,19 +230,19 @@
 
         # No matching pipeline found, create new one
         logger.info(f"No existing pipeline found with prefix {pipeline_name_prefix}, creating new one")
-        return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region)
+        return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region, create_tags)
 
     except ClientError as e:
         error_code = e.response['Error']['Code']
         if "ResourceNotFound" in error_code:
-            return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region)
+            return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region, create_tags)
         else:
             raise
 
     except Exception as e:
         # If search fails for other reasons, try to create
         logger.info(f"Error searching for pipeline ({str(e)}), attempting to create new pipeline")
-        return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region)
+        return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region, create_tags)
 
 
 def _start_pipeline_execution(
@@ -505,7 +510,8 @@ def start(
         role_arn: str,
         s3_output_path: Optional[str] = None,
         session: Optional[Session] = None,
-        region: Optional[str] = None
+        region: Optional[str] = None,
+        tags: Optional[List[TagsDict]] = [],
     ) -> 'EvaluationPipelineExecution':
         """Create sagemaker pipeline execution. Optionally creates pipeline.
 
@@ -517,6 +523,7 @@
             s3_output_path (Optional[str]): S3 location where evaluation results are stored.
             session (Optional[Session]): Boto3 session for API calls.
             region (Optional[str]): AWS region for the pipeline.
+            tags (Optional[List[TagsDict]]): List of tags to include in pipeline
 
         Returns:
             EvaluationPipelineExecution: Started pipeline execution instance.
@@ -547,7 +554,8 @@
             pipeline_definition=pipeline_definition,
             role_arn=role_arn,
             session=session,
-            region=region
+            region=region,
+            create_tags=tags,
         )
 
         # Start pipeline execution via boto3
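
One Python detail in the execution.py signatures above: tags: Optional[List[TagsDict]] = [] is a mutable default, and _create_evaluation_pipeline mutates it via tags.extend(...). Because Python evaluates the default once at definition time, calls that omit tags share a single list, so the evaluation tag can accumulate across invocations; .extend likewise mutates any list a caller passes in. A minimal sketch of the behavior, using a placeholder value for _TAG_SAGEMAKER_MODEL_EVALUATION:

from typing import Dict, List, Optional

TagsDict = Dict[str, str]  # stand-in type for illustration

def create_tags(tags: Optional[List[TagsDict]] = []) -> List[TagsDict]:
    # Mirrors the pattern in _create_evaluation_pipeline: extend the argument.
    tags.extend([{"key": "sagemaker-model-evaluation", "value": "true"}])  # placeholder key
    return tags

print(len(create_tags()))  # 1
print(len(create_tags()))  # 2 -- the shared default list accumulated a duplicate tag

def create_tags_safe(tags: Optional[List[TagsDict]] = None) -> List[TagsDict]:
    # Conventional fix: None sentinel plus a copy, so neither the default
    # nor the caller's list is mutated.
    result = list(tags or [])
    result.append({"key": "sagemaker-model-evaluation", "value": "true"})
    return result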

sagemaker-train/tests/unit/train/evaluate/test_execution.py

Lines changed: 31 additions & 1 deletion
@@ -299,7 +299,37 @@ def test_create_pipeline_when_not_found(self, mock_pipeline_class, mock_create, mock_session):
             DEFAULT_ROLE,
             DEFAULT_PIPELINE_DEFINITION,
             mock_session,
-            DEFAULT_REGION
+            DEFAULT_REGION,
+            []
+        )
+        assert result == mock_pipeline
+
+    @patch("sagemaker.train.evaluate.execution._create_evaluation_pipeline")
+    @patch("sagemaker.train.evaluate.execution.Pipeline")
+    def test_create_pipeline_when_not_found_with_jumpstart_tags(self, mock_pipeline_class, mock_create, mock_session):
+        """Test creating pipeline with JumpStart tags when it doesn't exist."""
+        error_response = {"Error": {"Code": "ResourceNotFound"}}
+        mock_pipeline_class.get.side_effect = ClientError(error_response, "DescribePipeline")
+        mock_pipeline = MagicMock()
+        mock_create.return_value = mock_pipeline
+        create_tags = [{"key": "sagemaker-sdk:jumpstart-model-id", "value": "dummy-js-model-id"}]
+
+        result = _get_or_create_pipeline(
+            eval_type=EvalType.BENCHMARK,
+            pipeline_definition=DEFAULT_PIPELINE_DEFINITION,
+            role_arn=DEFAULT_ROLE,
+            session=mock_session,
+            region=DEFAULT_REGION,
+            create_tags=create_tags
+        )
+
+        mock_create.assert_called_once_with(
+            EvalType.BENCHMARK,
+            DEFAULT_ROLE,
+            DEFAULT_PIPELINE_DEFINITION,
+            mock_session,
+            DEFAULT_REGION,
+            create_tags
         )
         assert result == mock_pipeline
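
For reference, the mocked failure drives the create path because botocore's ClientError exposes the raw error payload on .response, which _get_or_create_pipeline checks for ResourceNotFound before falling back to _create_evaluation_pipeline. A minimal, self-contained illustration:

from botocore.exceptions import ClientError

# Same error shape the test injects via mock_pipeline_class.get.side_effect.
error_response = {"Error": {"Code": "ResourceNotFound"}}

try:
    raise ClientError(error_response, "DescribePipeline")
except ClientError as e:
    error_code = e.response["Error"]["Code"]
    assert "ResourceNotFound" in error_code  # the branch that creates a new pipeline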
