Skip to content

Commit 708c7b2

Browse files
authored
Bug fix for hmac key (#5348)
* Bug fix for hmac key * Use correct sagemaker-core version * Fixing syntax error with unit test file * Fixing codestyle issues * More codestyle fixes * Additional codestyle fixes * codestyle fixes * Removing unused imports
1 parent c9966d2 commit 708c7b2

File tree

19 files changed

+83
-281
lines changed

19 files changed

+83
-281
lines changed

src/sagemaker/feature_store/feature_processor/_config_uploader.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,6 @@ def _prepare_and_upload_callable(
120120
stored_function = StoredFunction(
121121
sagemaker_session=sagemaker_session,
122122
s3_base_uri=s3_base_uri,
123-
hmac_key=self.remote_decorator_config.environment_variables[
124-
"REMOTE_FUNCTION_SECRET_KEY"
125-
],
126123
s3_kms_key=self.remote_decorator_config.s3_kms_key,
127124
)
128125
stored_function.save(func)

src/sagemaker/remote_function/client.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,6 @@ def wrapper(*args, **kwargs):
362362
s3_uri=s3_path_join(
363363
job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER
364364
),
365-
hmac_key=job.hmac_key,
366365
)
367366
except ServiceError as serr:
368367
chained_e = serr.__cause__
@@ -399,7 +398,6 @@ def wrapper(*args, **kwargs):
399398
return serialization.deserialize_obj_from_s3(
400399
sagemaker_session=job_settings.sagemaker_session,
401400
s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER),
402-
hmac_key=job.hmac_key,
403401
)
404402

405403
if job.describe()["TrainingJobStatus"] == "Stopped":
@@ -979,7 +977,6 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
979977
job_return = serialization.deserialize_obj_from_s3(
980978
sagemaker_session=sagemaker_session,
981979
s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER),
982-
hmac_key=job.hmac_key,
983980
)
984981
except DeserializationError as e:
985982
client_exception = e
@@ -991,7 +988,6 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
991988
job_exception = serialization.deserialize_exception_from_s3(
992989
sagemaker_session=sagemaker_session,
993990
s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER),
994-
hmac_key=job.hmac_key,
995991
)
996992
except ServiceError as serr:
997993
chained_e = serr.__cause__
@@ -1081,7 +1077,6 @@ def result(self, timeout: float = None) -> Any:
10811077
self._return = serialization.deserialize_obj_from_s3(
10821078
sagemaker_session=self._job.sagemaker_session,
10831079
s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER),
1084-
hmac_key=self._job.hmac_key,
10851080
)
10861081
self._state = _FINISHED
10871082
return self._return
@@ -1090,7 +1085,6 @@ def result(self, timeout: float = None) -> Any:
10901085
self._exception = serialization.deserialize_exception_from_s3(
10911086
sagemaker_session=self._job.sagemaker_session,
10921087
s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER),
1093-
hmac_key=self._job.hmac_key,
10941088
)
10951089
except ServiceError as serr:
10961090
chained_e = serr.__cause__

src/sagemaker/remote_function/core/pipeline_variables.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,6 @@ class _DelayedReturnResolver:
164164
def __init__(
165165
self,
166166
delayed_returns: List[_DelayedReturn],
167-
hmac_key: str,
168167
properties_resolver: _PropertiesResolver,
169168
parameter_resolver: _ParameterResolver,
170169
execution_variable_resolver: _ExecutionVariableResolver,
@@ -175,7 +174,6 @@ def __init__(
175174
176175
Args:
177176
delayed_returns: list of delayed returns to resolve.
178-
hmac_key: key used to encrypt serialized and deserialized function and arguments.
179177
properties_resolver: resolver used to resolve step properties.
180178
parameter_resolver: resolver used to pipeline parameters.
181179
execution_variable_resolver: resolver used to resolve execution variables.
@@ -197,7 +195,6 @@ def deserialization_task(uri):
197195
return uri, deserialize_obj_from_s3(
198196
sagemaker_session=settings["sagemaker_session"],
199197
s3_uri=uri,
200-
hmac_key=hmac_key,
201198
)
202199

203200
with ThreadPoolExecutor() as executor:
@@ -247,7 +244,6 @@ def resolve_pipeline_variables(
247244
context: Context,
248245
func_args: Tuple,
249246
func_kwargs: Dict,
250-
hmac_key: str,
251247
s3_base_uri: str,
252248
**settings,
253249
):
@@ -257,7 +253,6 @@ def resolve_pipeline_variables(
257253
context: context for the execution.
258254
func_args: function args.
259255
func_kwargs: function kwargs.
260-
hmac_key: key used to encrypt serialized and deserialized function and arguments.
261256
s3_base_uri: the s3 base uri of the function step that the serialized artifacts
262257
will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name.
263258
**settings: settings to pass to the deserialization function.
@@ -280,7 +275,6 @@ def resolve_pipeline_variables(
280275
properties_resolver = _PropertiesResolver(context)
281276
delayed_return_resolver = _DelayedReturnResolver(
282277
delayed_returns=delayed_returns,
283-
hmac_key=hmac_key,
284278
properties_resolver=properties_resolver,
285279
parameter_resolver=parameter_resolver,
286280
execution_variable_resolver=execution_variable_resolver,

src/sagemaker/remote_function/core/serialization.py

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -152,15 +152,14 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any:
152152

153153
# TODO: use dask serializer in case dask distributed is installed in users' environment.
154154
def serialize_func_to_s3(
155-
func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
155+
func: Callable, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
156156
):
157157
"""Serializes function and uploads it to S3.
158158
159159
Args:
160160
sagemaker_session (sagemaker.session.Session):
161161
The underlying Boto3 session which AWS service calls are delegated to.
162162
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
163-
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
164163
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
165164
func: function to be serialized and persisted
166165
Raises:
@@ -169,14 +168,13 @@ def serialize_func_to_s3(
169168

170169
_upload_payload_and_metadata_to_s3(
171170
bytes_to_upload=CloudpickleSerializer.serialize(func),
172-
hmac_key=hmac_key,
173171
s3_uri=s3_uri,
174172
sagemaker_session=sagemaker_session,
175173
s3_kms_key=s3_kms_key,
176174
)
177175

178176

179-
def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable:
177+
def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callable:
180178
"""Downloads from S3 and then deserializes data objects.
181179
182180
This method downloads the serialized training job outputs to a temporary directory and
@@ -186,7 +184,6 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key:
186184
sagemaker_session (sagemaker.session.Session):
187185
The underlying sagemaker session which AWS service calls are delegated to.
188186
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
189-
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
190187
Returns :
191188
The deserialized function.
192189
Raises:
@@ -198,32 +195,26 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key:
198195

199196
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
200197

201-
_perform_integrity_check(
202-
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
203-
)
198+
_perform_integrity_check(expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize)
204199

205200
return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
206201

207202

208-
def serialize_obj_to_s3(
209-
obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
210-
):
203+
def serialize_obj_to_s3(obj: Any, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None):
211204
"""Serializes data object and uploads it to S3.
212205
213206
Args:
214207
sagemaker_session (sagemaker.session.Session):
215208
The underlying Boto3 session which AWS service calls are delegated to.
216209
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
217210
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
218-
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
219211
obj: object to be serialized and persisted
220212
Raises:
221213
SerializationError: when fail to serialize object to bytes.
222214
"""
223215

224216
_upload_payload_and_metadata_to_s3(
225217
bytes_to_upload=CloudpickleSerializer.serialize(obj),
226-
hmac_key=hmac_key,
227218
s3_uri=s3_uri,
228219
sagemaker_session=sagemaker_session,
229220
s3_kms_key=s3_kms_key,
@@ -270,14 +261,13 @@ def json_serialize_obj_to_s3(
270261
)
271262

272263

273-
def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
264+
def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
274265
"""Downloads from S3 and then deserializes data objects.
275266
276267
Args:
277268
sagemaker_session (sagemaker.session.Session):
278269
The underlying sagemaker session which AWS service calls are delegated to.
279270
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
280-
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
281271
Returns :
282272
Deserialized python objects.
283273
Raises:
@@ -290,15 +280,13 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: s
290280

291281
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
292282

293-
_perform_integrity_check(
294-
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
295-
)
283+
_perform_integrity_check(expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize)
296284

297285
return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
298286

299287

300288
def serialize_exception_to_s3(
301-
exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
289+
exc: Exception, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
302290
):
303291
"""Serializes exception with traceback and uploads it to S3.
304292
@@ -307,7 +295,6 @@ def serialize_exception_to_s3(
307295
The underlying Boto3 session which AWS service calls are delegated to.
308296
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
309297
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
310-
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
311298
exc: Exception to be serialized and persisted
312299
Raises:
313300
SerializationError: when fail to serialize object to bytes.
@@ -316,7 +303,6 @@ def serialize_exception_to_s3(
316303

317304
_upload_payload_and_metadata_to_s3(
318305
bytes_to_upload=CloudpickleSerializer.serialize(exc),
319-
hmac_key=hmac_key,
320306
s3_uri=s3_uri,
321307
sagemaker_session=sagemaker_session,
322308
s3_kms_key=s3_kms_key,
@@ -325,7 +311,6 @@ def serialize_exception_to_s3(
325311

326312
def _upload_payload_and_metadata_to_s3(
327313
bytes_to_upload: Union[bytes, io.BytesIO],
328-
hmac_key: str,
329314
s3_uri: str,
330315
sagemaker_session: Session,
331316
s3_kms_key,
@@ -334,15 +319,14 @@ def _upload_payload_and_metadata_to_s3(
334319
335320
Args:
336321
bytes_to_upload (bytes): Serialized bytes to upload.
337-
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
338322
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
339323
sagemaker_session (sagemaker.session.Session):
340324
The underlying Boto3 session which AWS service calls are delegated to.
341325
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
342326
"""
343327
_upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)
344328

345-
sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
329+
sha256_hash = _compute_hash(bytes_to_upload)
346330

347331
_upload_bytes_to_s3(
348332
_MetaData(sha256_hash).to_json(),
@@ -352,14 +336,13 @@ def _upload_payload_and_metadata_to_s3(
352336
)
353337

354338

355-
def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
339+
def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
356340
"""Downloads from S3 and then deserializes exception.
357341
358342
Args:
359343
sagemaker_session (sagemaker.session.Session):
360344
The underlying sagemaker session which AWS service calls are delegated to.
361345
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
362-
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
363346
Returns :
364347
Deserialized exception with traceback.
365348
Raises:
@@ -372,9 +355,7 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_
372355

373356
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
374357

375-
_perform_integrity_check(
376-
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
377-
)
358+
_perform_integrity_check(expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize)
378359

379360
return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
380361

@@ -399,18 +380,18 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session):
399380
) from e
400381

401382

402-
def _compute_hash(buffer: bytes, secret_key: str) -> str:
403-
"""Compute the hmac-sha256 hash"""
404-
return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
383+
def _compute_hash(buffer: bytes) -> str:
384+
"""Compute the sha256 hash"""
385+
return hashlib.sha256(buffer).hexdigest()
405386

406387

407-
def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes):
388+
def _perform_integrity_check(expected_hash_value: str, buffer: bytes):
408389
"""Performs integrity checks for serialized code/arguments uploaded to s3.
409390
410391
Verifies whether the hash read from s3 matches the hash calculated
411392
during remote function execution.
412393
"""
413-
actual_hash_value = _compute_hash(buffer=buffer, secret_key=secret_key)
394+
actual_hash_value = _compute_hash(buffer=buffer)
414395
if not hmac.compare_digest(expected_hash_value, actual_hash_value):
415396
raise DeserializationError(
416397
"Integrity check for the serialized function or data failed. "

src/sagemaker/remote_function/core/stored_function.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def __init__(
5252
self,
5353
sagemaker_session: Session,
5454
s3_base_uri: str,
55-
hmac_key: str,
5655
s3_kms_key: str = None,
5756
context: Context = Context(),
5857
):
@@ -63,13 +62,11 @@ def __init__(
6362
AWS service calls are delegated to.
6463
s3_base_uri: the base uri to which serialized artifacts will be uploaded.
6564
s3_kms_key: KMS key used to encrypt artifacts uploaded to S3.
66-
hmac_key: Key used to encrypt serialized and deserialized function and arguments.
6765
context: Build or run context of a pipeline step.
6866
"""
6967
self.sagemaker_session = sagemaker_session
7068
self.s3_base_uri = s3_base_uri
7169
self.s3_kms_key = s3_kms_key
72-
self.hmac_key = hmac_key
7370
self.context = context
7471

7572
self.func_upload_path = s3_path_join(
@@ -98,7 +95,6 @@ def save(self, func, *args, **kwargs):
9895
sagemaker_session=self.sagemaker_session,
9996
s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
10097
s3_kms_key=self.s3_kms_key,
101-
hmac_key=self.hmac_key,
10298
)
10399

104100
logger.info(
@@ -110,7 +106,6 @@ def save(self, func, *args, **kwargs):
110106
obj=(args, kwargs),
111107
sagemaker_session=self.sagemaker_session,
112108
s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
113-
hmac_key=self.hmac_key,
114109
s3_kms_key=self.s3_kms_key,
115110
)
116111

@@ -128,7 +123,6 @@ def save_pipeline_step_function(self, serialized_data):
128123
)
129124
serialization._upload_payload_and_metadata_to_s3(
130125
bytes_to_upload=serialized_data.func,
131-
hmac_key=self.hmac_key,
132126
s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
133127
sagemaker_session=self.sagemaker_session,
134128
s3_kms_key=self.s3_kms_key,
@@ -140,7 +134,6 @@ def save_pipeline_step_function(self, serialized_data):
140134
)
141135
serialization._upload_payload_and_metadata_to_s3(
142136
bytes_to_upload=serialized_data.args,
143-
hmac_key=self.hmac_key,
144137
s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
145138
sagemaker_session=self.sagemaker_session,
146139
s3_kms_key=self.s3_kms_key,
@@ -156,7 +149,6 @@ def load_and_invoke(self) -> Any:
156149
func = serialization.deserialize_func_from_s3(
157150
sagemaker_session=self.sagemaker_session,
158151
s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
159-
hmac_key=self.hmac_key,
160152
)
161153

162154
logger.info(
@@ -166,15 +158,13 @@ def load_and_invoke(self) -> Any:
166158
args, kwargs = serialization.deserialize_obj_from_s3(
167159
sagemaker_session=self.sagemaker_session,
168160
s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
169-
hmac_key=self.hmac_key,
170161
)
171162

172163
logger.info("Resolving pipeline variables")
173164
resolved_args, resolved_kwargs = resolve_pipeline_variables(
174165
self.context,
175166
args,
176167
kwargs,
177-
hmac_key=self.hmac_key,
178168
s3_base_uri=self.s3_base_uri,
179169
sagemaker_session=self.sagemaker_session,
180170
)
@@ -190,7 +180,6 @@ def load_and_invoke(self) -> Any:
190180
obj=result,
191181
sagemaker_session=self.sagemaker_session,
192182
s3_uri=s3_path_join(self.results_upload_path, RESULTS_FOLDER),
193-
hmac_key=self.hmac_key,
194183
s3_kms_key=self.s3_kms_key,
195184
)
196185

0 commit comments

Comments
 (0)