diff --git a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index 98aea2be5..68d631a9c 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -368,6 +368,8 @@ def _filter_batch_requests( return filtered_requests + + class BasePyTreeCheckpointHandler( async_checkpoint_handler.DeferredPathAsyncCheckpointHandler ): diff --git a/checkpoint/orbax/checkpoint/_src/logging/event_tracking.py b/checkpoint/orbax/checkpoint/_src/logging/event_tracking.py index 4b1cdc7f2..fb1477e1f 100644 --- a/checkpoint/orbax/checkpoint/_src/logging/event_tracking.py +++ b/checkpoint/orbax/checkpoint/_src/logging/event_tracking.py @@ -229,3 +229,5 @@ def record_completion(self, duration_secs: float): duration_secs, storage_type=self._storage_type, ) + +