Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
2adc4bd
Update image_uri_config, fw_utils and image_uris.py in sagemaker-core
zhaoqizqwang Dec 17, 2025
c08847d
Add ModelTrainer updates
zhaoqizqwang Dec 17, 2025
f20ee01
Update s3 bucket check in session_helper.py
zhaoqizqwang Dec 17, 2025
f04e341
fix: Map llama models to correct script
zhaoqizqwang Dec 17, 2025
c0c4002
fix: honor json serialization of HPs
zhaoqizqwang Dec 17, 2025
e402002
fix: clarify model monitor one time schedule bug
zhaoqizqwang Dec 17, 2025
7f97b16
fix: Allow import failure for internal _hashlib module
zhaoqizqwang Dec 17, 2025
c3871ff
Remove duplicate model_trainer.py
zhaoqizqwang Dec 17, 2025
3b07b4a
Add ignore_patterns in ModelTrainer to ignore specific files/folders
zhaoqizqwang Dec 17, 2025
47df7c0
Update instance type regex to also include hyphens
zhaoqizqwang Dec 17, 2025
ba2a6b4
chore: domain support for eu-isoe-west-1
zhaoqizqwang Dec 17, 2025
0e84f33
Fix: Object of type ModelLifeCycle is not JSON serializable
zhaoqizqwang Dec 17, 2025
5878547
fix: sanitize git clone repo input url
zhaoqizqwang Dec 17, 2025
c4ca393
Add support for MetricDefinitions in ModelTrainer
zhaoqizqwang Dec 17, 2025
e043098
feat: support pipeline versioning
zhaoqizqwang Dec 17, 2025
471a246
add eval custom lambda arn to hyperparameter
zhaoqizqwang Dec 17, 2025
3739879
Add Numpy 2.0 support
zhaoqizqwang Dec 17, 2025
b8b9aa3
fix: update get_execution_role to directly return the ExecutionRoleAr…
zhaoqizqwang Dec 17, 2025
2723f8f
HF Optimum Neuron 0.4.1 DLCs
zhaoqizqwang Dec 17, 2025
4d34e9b
Fix import error
zhaoqizqwang Dec 17, 2025
6feb6b6
Merge branch 'master' into master
zhaoqizqwang Dec 17, 2025
44d265b
Merge branch 'master' into master
aviruthen Dec 17, 2025
7acd0c7
Fix llama_v3 in sm_recipes
zhaoqizqwang Dec 17, 2025
425f868
Remove duplicate json in image_retriever
zhaoqizqwang Dec 17, 2025
7fe9139
Add todo notes in pipeline class
zhaoqizqwang Dec 18, 2025
d7944ad
Add V2 image_config_url unit tests
zhaoqizqwang Dec 18, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions requirements/extras/test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ pytest-xdist
mock
pydantic==2.11.9
pydantic_core==2.33.2
pandas
pandas>=2.3.0
numpy>=2.0.0, <3.0
scikit-learn==1.6.1
scipy
omegaconf
graphene
typing_extensions>=4.9.0
typing_extensions>=4.9.0
tensorflow>=2.16.2,<=2.19.0
3 changes: 2 additions & 1 deletion sagemaker-core/src/sagemaker/core/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
"us-isob-east-1": "sc2s.sgov.gov",
"us-isof-south-1": "csp.hci.ic.gov",
"us-isof-east-1": "csp.hci.ic.gov",
"eu-isoe-west-1": "cloud.adc-e.uk",
}

ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$"
Expand Down Expand Up @@ -1555,7 +1556,7 @@ def get_instance_type_family(instance_type: str) -> str:
"""
instance_type_family = ""
if isinstance(instance_type, str):
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
if match is not None:
instance_type_family = match[1]
return instance_type_family
Expand Down
68 changes: 56 additions & 12 deletions sagemaker-core/src/sagemaker/core/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@

from packaging import version

import sagemaker.core.common_utils as sagemaker_utils
from sagemaker.core.deprecations import deprecation_warn_base, renamed_kwargs
import sagemaker.core.common_utils as utils
from sagemaker.core.deprecations import deprecation_warn_base, renamed_kwargs, renamed_warning
from sagemaker.core.instance_group import InstanceGroup
from sagemaker.core.s3 import s3_path_join
from sagemaker.core.s3.utils import s3_path_join
from sagemaker.core.session_settings import SessionSettings
from sagemaker.core.workflow import is_pipeline_variable
from sagemaker.core.helper.pipeline_variable import PipelineVariable
from sagemaker.core.workflow.entities import PipelineVariable

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -155,6 +155,9 @@
"2.3.1",
"2.4.1",
"2.5.1",
"2.6.0",
"2.7.1",
"2.8.0",
]

TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
Expand Down Expand Up @@ -455,7 +458,7 @@ def tar_and_upload_dir(

try:
source_files = _list_files_to_compress(script, directory) + dependencies
tar_file = sagemaker_utils.create_tar_file(
tar_file = utils.create_tar_file(
source_files, os.path.join(tmp, _TAR_SOURCE_FILENAME)
)

Expand Down Expand Up @@ -516,7 +519,7 @@ def framework_name_from_image(image_uri):
- str: The image tag
- str: If the TensorFlow image is script mode
"""
sagemaker_pattern = re.compile(sagemaker_utils.ECR_URI_PATTERN)
sagemaker_pattern = re.compile(utils.ECR_URI_PATTERN)
sagemaker_match = sagemaker_pattern.match(image_uri)
if sagemaker_match is None:
return None, None, None, None
Expand Down Expand Up @@ -595,7 +598,7 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image):
"""
name_from_image = f"/model_code/{int(time.time())}"
if not is_pipeline_variable(image):
name_from_image = sagemaker_utils.name_from_image(image)
name_from_image = utils.name_from_image(image)
return s3_path_join(code_location_key_prefix, model_name or name_from_image)


Expand Down Expand Up @@ -961,7 +964,7 @@ def validate_distribution_for_instance_type(instance_type, distribution):
"""
err_msg = ""
if isinstance(instance_type, str):
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
if match and match[1].startswith("trn"):
keys = list(distribution.keys())
if len(keys) == 0:
Expand Down Expand Up @@ -1062,7 +1065,7 @@ def validate_torch_distributed_distribution(
)

# Check entry point type
if not entry_point.endswith(".py"):
if entry_point is not None and not entry_point.endswith(".py"):
err_msg += (
"Unsupported entry point type for the distribution torch_distributed.\n"
"Only python programs (*.py) are supported."
Expand All @@ -1082,7 +1085,7 @@ def _is_gpu_instance(instance_type):
bool: Whether or not the instance_type supports GPU
"""
if isinstance(instance_type, str):
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
if match:
if match[1].startswith("p") or match[1].startswith("g"):
return True
Expand All @@ -1101,7 +1104,7 @@ def _is_trainium_instance(instance_type):
bool: Whether or not the instance_type is a Trainium instance
"""
if isinstance(instance_type, str):
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
if match and match[1].startswith("trn"):
return True
return False
Expand Down Expand Up @@ -1148,7 +1151,7 @@ def _instance_type_supports_profiler(instance_type):
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
"""
if isinstance(instance_type, str):
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
match = re.match(r"^ml[\._]([a-z\d\-]+)\.?\w*$", instance_type)
if match and match[1].startswith("trn"):
return True
return False
Expand All @@ -1174,3 +1177,44 @@ def validate_version_or_image_args(framework_version, py_version, image_uri):
"framework_version or py_version was None, yet image_uri was also None. "
"Either specify both framework_version and py_version, or specify image_uri."
)


def create_image_uri(
region,
framework,
instance_type,
framework_version,
py_version=None,
account=None, # pylint: disable=W0613
accelerator_type=None,
optimized_families=None, # pylint: disable=W0613
):
"""Deprecated method. Please use sagemaker.image_uris.retrieve().
Args:
region (str): AWS region where the image is uploaded.
framework (str): framework used by the image.
instance_type (str): SageMaker instance type. Used to determine device
type (cpu/gpu/family-specific optimized).
framework_version (str): The version of the framework.
py_version (str): Optional. Python version Ex: `py38, py39, py310, py311`.
If not specified, image uri will not include a python component.
account (str): AWS account that contains the image. (default:
'520713654638')
accelerator_type (str): SageMaker Elastic Inference accelerator type.
optimized_families (str): Deprecated. A no-op argument.
Returns:
the image uri
"""
from sagemaker.core import image_uris

renamed_warning("The method create_image_uri")
return image_uris.retrieve(
framework=framework,
region=region,
version=framework_version,
py_version=py_version,
instance_type=instance_type,
accelerator_type=accelerator_type,
)
66 changes: 66 additions & 0 deletions sagemaker-core/src/sagemaker/core/git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,69 @@
import warnings
import six
from six.moves import urllib
import re
from pathlib import Path
from urllib.parse import urlparse

def _sanitize_git_url(repo_url):
"""Sanitize Git repository URL to prevent URL injection attacks.

Args:
repo_url (str): The Git repository URL to sanitize

Returns:
str: The sanitized URL

Raises:
ValueError: If the URL contains suspicious patterns that could indicate injection
"""
at_count = repo_url.count("@")

if repo_url.startswith("git@"):
# git@ format requires exactly one @
if at_count != 1:
raise ValueError("Invalid SSH URL format: git@ URLs must have exactly one @ symbol")
elif repo_url.startswith("ssh://"):
# ssh:// format can have 0 or 1 @ symbols
if at_count > 1:
raise ValueError("Invalid SSH URL format: multiple @ symbols detected")
elif repo_url.startswith("https://") or repo_url.startswith("http://"):
# HTTPS format allows 0 or 1 @ symbols
if at_count > 1:
raise ValueError("Invalid HTTPS URL format: multiple @ symbols detected")

# Check for invalid characters in the URL before parsing
# These characters should not appear in legitimate URLs
invalid_chars = ["<", ">", "[", "]", "{", "}", "\\", "^", "`", "|"]
for char in invalid_chars:
if char in repo_url:
raise ValueError("Invalid characters in hostname")

try:
parsed = urlparse(repo_url)

# Check for suspicious characters in hostname that could indicate injection
if parsed.hostname:
# Check for URL-encoded characters that might be used for obfuscation
suspicious_patterns = ["%25", "%40", "%2F", "%3A"] # encoded %, @, /, :
for pattern in suspicious_patterns:
if pattern in parsed.hostname.lower():
raise ValueError(f"Suspicious URL encoding detected in hostname: {pattern}")

# Validate that the hostname looks legitimate
if not re.match(r"^[a-zA-Z0-9.-]+$", parsed.hostname):
raise ValueError("Invalid characters in hostname")

except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Failed to parse URL: {str(e)}")
else:
raise ValueError(
"Unsupported URL scheme: only https://, http://, git@, and ssh:// are allowed"
)

return repo_url

def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None):
"""Git clone repo containing the training code and serving code.
Expand Down Expand Up @@ -87,6 +149,10 @@ def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None):
if entry_point is None:
raise ValueError("Please provide an entry point.")
_validate_git_config(git_config)

# SECURITY: Sanitize the repository URL to prevent injection attacks
git_config["repo"] = _sanitize_git_url(git_config["repo"])

dest_dir = tempfile.mkdtemp()
_generate_and_run_clone_command(git_config, dest_dir)

Expand Down
28 changes: 20 additions & 8 deletions sagemaker-core/src/sagemaker/core/helper/session_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,16 +330,16 @@ def get_caller_identity_arn(self):
user_profile_name = metadata.get("UserProfileName")
execution_role_arn = metadata.get("ExecutionRoleArn")
try:
# find execution role from the metadata file if present
if execution_role_arn is not None:
return execution_role_arn

if domain_id is None:
instance_desc = self.sagemaker_client.describe_notebook_instance(
NotebookInstanceName=instance_name
)
return instance_desc["RoleArn"]

# find execution role from the metadata file if present
if execution_role_arn is not None:
return execution_role_arn

user_profile_desc = self.sagemaker_client.describe_user_profile(
DomainId=domain_id, UserProfileName=user_profile_name
)
Expand Down Expand Up @@ -666,9 +666,16 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket

"""
try:
s3.meta.client.head_bucket(
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
)
if self.default_bucket_prefix:
s3.meta.client.list_objects_v2(
Bucket=bucket_name,
Prefix=self.default_bucket_prefix,
ExpectedBucketOwner=expected_bucket_owner_id,
)
else:
s3.meta.client.head_bucket(
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
Expand Down Expand Up @@ -699,7 +706,12 @@ def general_bucket_check_if_user_has_permission(
bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not
"""
try:
s3.meta.client.head_bucket(Bucket=bucket_name)
if self.default_bucket_prefix:
s3.meta.client.list_objects_v2(
Bucket=bucket_name, Prefix=self.default_bucket_prefix
)
else:
s3.meta.client.head_bucket(Bucket=bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
Expand Down
29 changes: 0 additions & 29 deletions sagemaker-core/src/sagemaker/core/huggingface/__init__.py

This file was deleted.

Loading
Loading