diff --git a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py index f4d95f5412..9cb0c7aee4 100644 --- a/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py +++ b/sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py @@ -94,6 +94,50 @@ def from_dependency_file_path(dependency_file_path): class RuntimeEnvironmentManager: """Runtime Environment Manager class to manage runtime environment.""" + def _validate_path(self, path: str) -> str: + """Validate and sanitize file path to prevent path traversal attacks. + + Args: + path (str): The file path to validate + + Returns: + str: The validated absolute path + + Raises: + ValueError: If the path is invalid or contains suspicious patterns + """ + if not path: + raise ValueError("Path cannot be empty") + + # Get absolute path to prevent path traversal + abs_path = os.path.abspath(path) + + # Check for null bytes (common in path traversal attacks) + if '\x00' in path: + raise ValueError(f"Invalid path contains null byte: {path}") + + return abs_path + + def _validate_env_name(self, env_name: str) -> None: + """Validate conda environment name to prevent command injection. + + Args: + env_name (str): The environment name to validate + + Raises: + ValueError: If the environment name contains invalid characters + """ + if not env_name: + raise ValueError("Environment name cannot be empty") + + # Allow only alphanumeric, underscore, and hyphen + import re + if not re.match(r'^[a-zA-Z0-9_-]+$', env_name): + raise ValueError( + f"Invalid environment name '{env_name}'. " + "Only alphanumeric characters, underscores, and hyphens are allowed." + ) + def snapshot(self, dependencies: str = None) -> str: """Creates snapshot of the user's environment @@ -252,39 +296,50 @@ def _is_file_exists(self, dependencies): def _install_requirements_txt(self, local_path, python_executable): """Install requirements.txt file""" - cmd = f"{python_executable} -m pip install -r {local_path} -U" - logger.info("Running command: '%s' in the dir: '%s' ", cmd, os.getcwd()) + # Validate path to prevent command injection + validated_path = self._validate_path(local_path) + cmd = [python_executable, "-m", "pip", "install", "-r", validated_path, "-U"] + logger.info("Running command: '%s' in the dir: '%s' ", " ".join(cmd), os.getcwd()) _run_shell_cmd(cmd) - logger.info("Command %s ran successfully", cmd) + logger.info("Command %s ran successfully", " ".join(cmd)) def _create_conda_env(self, env_name, local_path): """Create conda env using conda yml file""" + # Validate inputs to prevent command injection + self._validate_env_name(env_name) + validated_path = self._validate_path(local_path) - cmd = f"{self._get_conda_exe()} env create -n {env_name} --file {local_path}" - logger.info("Creating conda environment %s using: %s.", env_name, cmd) + cmd = [self._get_conda_exe(), "env", "create", "-n", env_name, "--file", validated_path] + logger.info("Creating conda environment %s using: %s.", env_name, " ".join(cmd)) _run_shell_cmd(cmd) logger.info("Conda environment %s created successfully.", env_name) def _install_req_txt_in_conda_env(self, env_name, local_path): """Install requirements.txt in the given conda environment""" + # Validate inputs to prevent command injection + self._validate_env_name(env_name) + validated_path = self._validate_path(local_path) - cmd = f"{self._get_conda_exe()} run -n {env_name} pip install -r {local_path} -U" - logger.info("Activating conda env and installing requirements: %s", cmd) + cmd = [self._get_conda_exe(), "run", "-n", env_name, "pip", "install", "-r", validated_path, "-U"] + logger.info("Activating conda env and installing requirements: %s", " ".join(cmd)) _run_shell_cmd(cmd) logger.info("Requirements installed successfully in conda env %s", env_name) def _update_conda_env(self, env_name, local_path): """Update conda env using conda yml file""" + # Validate inputs to prevent command injection + self._validate_env_name(env_name) + validated_path = self._validate_path(local_path) - cmd = f"{self._get_conda_exe()} env update -n {env_name} --file {local_path}" - logger.info("Updating conda env: %s", cmd) + cmd = [self._get_conda_exe(), "env", "update", "-n", env_name, "--file", validated_path] + logger.info("Updating conda env: %s", " ".join(cmd)) _run_shell_cmd(cmd) logger.info("Conda env %s updated succesfully", env_name) def _export_conda_env_from_prefix(self, prefix, local_path): """Export the conda env to a conda yml file""" - cmd = f"{self._get_conda_exe()} env export -p {prefix} --no-builds > {local_path}" + cmd = [self._get_conda_exe(), "env", "export", "-p", prefix, "--no-builds", ">", local_path] logger.info("Exporting conda environment: %s", cmd) _run_shell_cmd(cmd) logger.info("Conda environment %s exported successfully", prefix) @@ -402,19 +457,26 @@ def _run_pre_execution_command_script(script_path: str): return return_code, error_logs -def _run_shell_cmd(cmd: str): +def _run_shell_cmd(cmd: list): """This method runs a given shell command using subprocess - Raises RuntimeEnvironmentError if the command fails + Args: + cmd (list): Command and arguments as a list (e.g., ['pip', 'install', '-r', 'requirements.txt']) + + Raises: + RuntimeEnvironmentError: If the command fails + ValueError: If cmd is not a list """ + if not isinstance(cmd, list): + raise ValueError("Command must be a list of arguments for security reasons") - process = subprocess.Popen((cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False) _log_output(process) error_logs = _log_error(process) return_code = process.wait() if return_code: - error_message = f"Encountered error while running command '{cmd}'. Reason: {error_logs}" + error_message = f"Encountered error while running command '{' '.join(cmd)}'. Reason: {error_logs}" raise RuntimeEnvironmentError(error_message) diff --git a/sagemaker-train/tests/unit/train/remote_function/test_runtime_environment_manager.py b/sagemaker-train/tests/unit/train/remote_function/test_runtime_environment_manager.py index 464e46db6d..78f22671dd 100644 --- a/sagemaker-train/tests/unit/train/remote_function/test_runtime_environment_manager.py +++ b/sagemaker-train/tests/unit/train/remote_function/test_runtime_environment_manager.py @@ -490,7 +490,7 @@ def test_runs_command_successfully(self, mock_popen, mock_log_output, mock_log_e mock_popen.return_value = mock_process mock_log_error.return_value = "" - _run_shell_cmd("echo test") + _run_shell_cmd(["echo", "test"]) mock_popen.assert_called_once() @@ -505,7 +505,7 @@ def test_runs_command_raises_error_on_failure(self, mock_popen, mock_log_output, mock_log_error.return_value = "Error message" with pytest.raises(RuntimeEnvironmentError): - _run_shell_cmd("false") + _run_shell_cmd(["false"]) class TestLogOutput: