From 891883c1c3774b91e84a990c22215e0b044dd855 Mon Sep 17 00:00:00 2001 From: Angel Mau Date: Fri, 29 May 2026 12:21:06 -0700 Subject: [PATCH] #v1 Refactor Context usage to reflect new namespace pattern. PiperOrigin-RevId: 923542628 --- .../_src/testing/benchmarks/v1/benchmark.py | 44 +++---- .../experimental/v1/_src/context/context.py | 21 +++- .../experimental/v1/_src/context/options.py | 16 +-- .../v1/_src/handlers/json_handler.py | 14 +-- .../v1/_src/handlers/proto_handler.py | 13 +- .../v1/_src/handlers/pytree_handler.py | 13 +- .../v1/_src/handlers/pytree_handler_test.py | 67 +++++----- .../v1/_src/handlers/registration.py | 7 +- .../experimental/v1/_src/layout/registry.py | 3 +- .../safetensors_layout_multiprocess_test.py | 29 ++--- .../v1/_src/loading/layout_loading_test.py | 50 ++++---- .../v1/_src/metadata/loading_test.py | 39 +++--- .../_src/serialization/array_leaf_handler.py | 4 +- .../serialization/array_leaf_handler_test.py | 39 +++--- .../_src/serialization/numpy_leaf_handler.py | 4 +- .../serialization/numpy_leaf_handler_test.py | 23 ++-- .../serialization/options_resolution_test.py | 13 +- .../_src/serialization/scalar_leaf_handler.py | 4 +- .../serialization/scalar_leaf_handler_test.py | 23 ++-- .../_src/serialization/string_leaf_handler.py | 4 +- .../serialization/string_leaf_handler_test.py | 5 +- ...tables_metadata_compatibility_test_base.py | 8 +- .../compatibility/generate_v1_checkpoints.py | 8 +- ...checkpointables_compatibility_test_base.py | 8 +- .../load_pytree_compatibility_test_base.py | 8 +- .../manager_compatibility_test_base.py | 28 ++--- ...pytree_metadata_compatibility_test_base.py | 8 +- .../v1/_src/testing/save_load_test_base.py | 118 +++++++----------- .../experimental/v1/_src/training/README.md | 5 +- .../v1/_src/training/checkpointer.py | 2 +- .../_src/training/checkpointer_test_base.py | 85 +++++-------- .../checkpoint/v1/checkpoint_format.ipynb | 10 +- .../checkpoint/v1/checkpointing_pytrees.ipynb | 54 +++----- docs/guides/checkpoint/v1/model_surgery.ipynb | 10 +- .../guides/checkpoint/v1/partial_saving.ipynb | 9 +- 35 files changed, 339 insertions(+), 457 deletions(-) diff --git a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py index c1aa1dd58..d171aacbc 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py +++ b/checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py @@ -89,31 +89,27 @@ def is_valid(self) -> bool: @property def context(self) -> ocp.Context: - return ocp.Context( - array_options=ocp.options.ArrayOptions( - saving=ocp.options.ArrayOptions.Saving( - storage_options=ocp.options.ArrayOptions.Saving.StorageOptions( - chunk_byte_size=self.chunk_byte_size, - ), - use_ocdbt=self.use_ocdbt, - use_zarr3=self.use_zarr3, - use_replica_parallel=self.use_replica_parallel, - use_compression=self.use_compression, - enable_replica_parallel_separate_folder=self.enable_replica_parallel_separate_folder, - ), - loading=ocp.options.ArrayOptions.Loading( - use_load_and_broadcast=self.use_load_and_broadcast, - ), - ), - memory_options=ocp.options.MemoryOptions( - write_concurrent_bytes=self.save_concurrent_gb * 1024**3 - if self.save_concurrent_gb is not None - else None, - read_concurrent_bytes=self.restore_concurrent_gb * 1024**3 - if self.restore_concurrent_gb is not None - else None, - ), + ctx = ocp.Context() + ctx.array.saving.storage_options.chunk_byte_size = self.chunk_byte_size + ctx.array.saving.use_ocdbt = self.use_ocdbt + ctx.array.saving.use_zarr3 = self.use_zarr3 + ctx.array.saving.use_replica_parallel = self.use_replica_parallel + ctx.array.saving.use_compression = self.use_compression + ctx.array.saving.enable_replica_parallel_separate_folder = ( + self.enable_replica_parallel_separate_folder ) + ctx.array.loading.use_load_and_broadcast = self.use_load_and_broadcast + ctx.memory.write_concurrent_bytes = ( + self.save_concurrent_gb * 1024**3 + if self.save_concurrent_gb is not None + else None + ) + ctx.memory.read_concurrent_bytes = ( + self.restore_concurrent_gb * 1024**3 + if self.restore_concurrent_gb is not None + else None + ) + return ctx def clear_pytree(pytree: Any) -> Any: diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py index 62815df98..b9d6c5afd 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py @@ -36,7 +36,26 @@ def get_context(default: Context | None = None) -> Context: - """Returns the current `Context` or `default` or `Context()` if not set.""" + """Returns the currently active `Context`, or a default if no context is active. + + If called within a `with ocp.Context(...)` block, this function returns the + `Context` object associated with that block (the active context). + + If called outside of any `with` block, this function returns `default` + if it is provided. If `default` is not provided or `None`, it returns a + new `Context` instance initialized with default options. + + Note: If a context is active, the `default` parameter is ignored, and the + active context is always returned. To ensure that an explicitly provided + context takes precedence over any active context, use the pattern: + `ctx = explicit_context if explicit_context is not None else get_context()`. + + Args: + default: A `Context` object to return if no context is active. + + Returns: + The active `Context` or a default `Context`. + """ default = default or Context() return _CONTEXT.get(default) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py index ada3b13d4..3d10b6666 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py @@ -19,7 +19,7 @@ import contextvars import dataclasses import enum -from typing import Any, Callable, Protocol, Type +from typing import Any, Callable, Protocol from etils import epath import numpy as np @@ -29,7 +29,6 @@ from orbax.checkpoint._src.path import atomicity_types from orbax.checkpoint._src.serialization import pathways_types from orbax.checkpoint.experimental.v1._src.handlers import registration -from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types from orbax.checkpoint.experimental.v1._src.path import types as path_types from orbax.checkpoint.experimental.v1._src.serialization import types as serialization_types from orbax.checkpoint.experimental.v1._src.tree import types as tree_types @@ -495,19 +494,6 @@ class CheckpointablesOptions(_ActiveContextGuard): ) ) - @classmethod - def create_with_handlers( - cls, - *handlers: Type[handler_types.CheckpointableHandler], - **named_handlers: Type[handler_types.CheckpointableHandler], - ) -> CheckpointablesOptions: - registry = registration.local_registry(include_global_registry=True) - for handler in handlers: - registry.add(handler, checkpointable_name=None) - for name, handler in named_handlers.items(): - registry.add(handler, checkpointable_name=name) - return cls(registry=registry) - @dataclasses.dataclass(kw_only=True) class PathwaysOptions(_ActiveContextGuard): diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/json_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/json_handler.py index 8e3e430a0..b1becb00e 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/json_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/json_handler.py @@ -72,14 +72,14 @@ class JsonHandler(CheckpointableHandler[JsonType, None]): config = {'learning_rate': 0.01, 'batch_size': 32} - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - experiment_config=ocp.handlers.JsonHandler( - filename='experiment_config.json' - ) - ) + registry = ocp.handlers.local_registry() + registry.add( + ocp.handlers.JsonHandler, + checkpointable_name='experiment_config', ) - with ocp.Context(checkpointables_options=checkpointables_options): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: ocp.save_checkpointables(path, dict(experiment_config=config)) Attributes: diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/proto_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/proto_handler.py index 3d8feb845..c178c1ace 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/proto_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/proto_handler.py @@ -63,14 +63,13 @@ class ProtoHandler( # Assuming MyProtoMessage is your compiled protobuf class my_proto_msg = MyProtoMessage(config_field="value") - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - proto_config=ocp.handlers.ProtoHandler( - filename="model_config.pbtxt" - ) - ) + registry = ocp.handlers.local_registry() + registry.add( + ocp.handlers.ProtoHandler, checkpointable_name="proto_config" ) - with ocp.Context(checkpointables_options=checkpointables_options): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: ocp.save_checkpointables(path, dict(proto_config=my_proto_msg)) Attributes: diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py index 703e38922..a8f26f4f0 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py @@ -272,12 +272,13 @@ class PyTreeHandler(CheckpointableHandler[PyTree, PyTree]): state_pytree = {'weights': [1.0, 2.0], 'bias': 0.0} - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - model_state=ocp.handlers.PyTreeHandler() - ) + registry = ocp.handlers.local_registry() + registry.add( + ocp.handlers.PyTreeHandler, checkpointable_name='model_state' ) - with ocp.Context(checkpointables_options=checkpointables_options): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: ocp.save_checkpointables(path, dict(model_state=state_pytree)) Attributes: @@ -299,7 +300,7 @@ def __init__( ) = None, partial_save_mode: bool = False, ): - context = context_lib.get_context(context) + context = context if context is not None else context_lib.get_context() self._context = context self._multiprocessing_options = context.multiprocessing_options self._partial_save_mode = partial_save_mode diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler_test.py index 792bd2b65..bf546f55b 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler_test.py @@ -248,8 +248,8 @@ def handler_with_options( options_lib.ArrayOptions.Saving.ScopedStorageOptionsCreator | None ) = None, array_storage_options: ( - options_lib.ArrayOptions.Saving.StorageOptions - ) = options_lib.ArrayOptions.Saving.StorageOptions(), + options_lib.ArrayOptions.Saving.StorageOptions | None + ) = None, save_concurrent_bytes: int | None = None, restore_concurrent_bytes: int | None = None, use_ocdbt: bool = True, @@ -271,37 +271,36 @@ def handler_with_options( ) = None, ): """Registers handlers with OCDBT support and resets when done.""" - context = context_lib.Context( - array_options=options_lib.ArrayOptions( - saving=options_lib.ArrayOptions.Saving( - use_ocdbt=use_ocdbt, - use_zarr3=use_zarr3, - use_compression=use_compression, - ocdbt_target_data_file_size=ocdbt_target_data_file_size, - enable_pinned_host_transfer=enable_pinned_host_transfer, - array_metadata_store=array_metadata_store, - enable_write_sharding_file=enable_write_sharding_file, - use_replica_parallel=not utils.is_pathways_backend(), - storage_options=array_storage_options, - scoped_storage_options_creator=scoped_storage_options_creator, - ), - loading=options_lib.ArrayOptions.Loading( - enable_padding_and_truncation=enable_padding_and_truncation, - ), - ), - memory_options=options_lib.MemoryOptions( - write_concurrent_bytes=save_concurrent_bytes, - read_concurrent_bytes=restore_concurrent_bytes, - ), - pytree_options=options_lib.PyTreeOptions( - saving=options_lib.PyTreeOptions.Saving( - pytree_metadata_options=pytree_metadata_options, - ), - loading=options_lib.PyTreeOptions.Loading( - partial_load=partial_load, - ), - ), + context = context_lib.Context() + + context.array.saving.use_ocdbt = use_ocdbt + context.array.saving.use_zarr3 = use_zarr3 + context.array.saving.use_compression = use_compression + context.array.saving.ocdbt_target_data_file_size = ocdbt_target_data_file_size + context.array.saving.enable_pinned_host_transfer = enable_pinned_host_transfer + context.array.saving.array_metadata_store = array_metadata_store + context.array.saving.enable_write_sharding_file = enable_write_sharding_file + context.array.saving.use_replica_parallel = not utils.is_pathways_backend() + if array_storage_options is not None: + context.array.saving.storage_options.dtype = array_storage_options.dtype + context.array.saving.storage_options.chunk_byte_size = ( + array_storage_options.chunk_byte_size + ) + context.array.saving.storage_options.shard_axes = ( + array_storage_options.shard_axes + ) + context.array.saving.scoped_storage_options_creator = ( + scoped_storage_options_creator ) + context.array.loading.enable_padding_and_truncation = ( + enable_padding_and_truncation + ) + + context.memory.write_concurrent_bytes = save_concurrent_bytes + context.memory.read_concurrent_bytes = restore_concurrent_bytes + + context.pytree.saving.pytree_metadata_options = pytree_metadata_options + context.pytree.loading.partial_load = partial_load handler = handler_test_utils.create_test_handler( pytree_handler.PyTreeHandler, @@ -309,10 +308,8 @@ def handler_with_options( leaf_handler_registry=leaf_handler_registry, ) - try: + with context: yield handler - finally: - pass class PyTreeHandlerTest( diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py index 321304373..0200c7e5c 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py @@ -39,10 +39,9 @@ # to a new v1 handler class. registry.add(BazHandler, secondary_typestrs=['OldBazHandlerTypestr']) - checkpointables_options = ocp.options.CheckpointablesOptions( - registry=registry - ) - with ocp.Context(checkpointables_options=checkpointables_options): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: ocp.save_checkpointables(...) Handler resolution for saving/loading follows this logic: diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py index 8c83cb258..6d4c4685b 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/registry.py @@ -105,7 +105,8 @@ async def get_checkpoint_layout( f"Could not recognize the checkpoint at {path} as a valid" f" {layout_enum.value} checkpoint. If you are trying to load a" " checkpoint that does not conform to the standard Orbax format, use" - " `ocp.Context(layout=...)` to specify the expected checkpoint layout." + " `ctx.checkpoint_layout = ...` to specify the expected checkpoint" + " layout." ) from e diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_multiprocess_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_multiprocess_test.py index 65cd6ee18..547a7ada4 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_multiprocess_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/safetensors_layout_multiprocess_test.py @@ -26,7 +26,6 @@ from orbax.checkpoint import test_utils from orbax.checkpoint._src.testing import multiprocess_test from orbax.checkpoint.experimental.v1._src.context import context as context_lib -from orbax.checkpoint.experimental.v1._src.context import options as options_lib from orbax.checkpoint.experimental.v1._src.layout import safetensors_layout import safetensors.numpy @@ -183,12 +182,12 @@ async def test_load_without_global_reshard_single_tensor(self): } layout = SafetensorsLayout() - with context_lib.Context( - safetensors_options=options_lib.SafetensorsOptions( - ignore_load_sharding=True - ) - ): - restore_fn = await layout.load(st_path, abstract_state=abstract_state) + ctx = context_lib.Context() + ctx.safetensors.ignore_load_sharding = True + with ctx: + restore_fn = await layout.load( + st_path, abstract_state=abstract_state + ) restored_pytree = await restore_fn restored_tensor = restored_pytree["params.tensor"] @@ -225,11 +224,9 @@ async def test_load_without_global_reshard_multi_tensor(self): } layout = SafetensorsLayout() - with context_lib.Context( - safetensors_options=options_lib.SafetensorsOptions( - ignore_load_sharding=True - ) - ): + ctx = context_lib.Context() + ctx.safetensors.ignore_load_sharding = True + with ctx: restore_fn = await layout.load(st_path, abstract_state=abstract_state) restored_pytree = await restore_fn @@ -337,11 +334,9 @@ async def test_load_without_global_reshard_memory_efficiency(self): tracemalloc.start() - with context_lib.Context( - safetensors_options=options_lib.SafetensorsOptions( - ignore_load_sharding=True - ) - ): + ctx = context_lib.Context() + ctx.safetensors.ignore_load_sharding = True + with ctx: restore_fn = await layout.load( file_path, abstract_state=abstract_pytree, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/layout_loading_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/layout_loading_test.py index 76f516dbe..de59c8f9e 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/layout_loading_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/loading/layout_loading_test.py @@ -76,9 +76,9 @@ def setUp(self): ) def test_load_safetensors_checkpoint(self): - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS - ): + ctx = context_lib.Context() + ctx.checkpoint_layout = options_lib.CheckpointLayout.SAFETENSORS + with ctx: pytree = loading.load(self.safetensors_path) self.assertIsInstance(pytree, dict) np.testing.assert_array_equal(pytree['a'], self.object_to_save['a']) @@ -98,7 +98,9 @@ def test_load_orbax_checkpointables_checkpoint(self): ) def test_load_bad_path_orbax_ckpt(self, layout_enum): # User provides a directory of Orbax checkpoints, not specific one. - with context_lib.Context(checkpoint_layout=layout_enum): + ctx = context_lib.Context() + ctx.checkpoint_layout = layout_enum + with ctx: with self.assertRaises(InvalidLayoutError): loading.load( epath.Path(self.test_dir.full_path), @@ -109,7 +111,9 @@ def test_load_bad_path_orbax_ckpt(self, layout_enum): ) def test_load_bad_path_safetensors_ckpt(self, layout_enum): # User provides a empty directory of SafeTensors checkpoints, not a file. - with context_lib.Context(checkpoint_layout=layout_enum): + ctx = context_lib.Context() + ctx.checkpoint_layout = layout_enum + with ctx: with self.assertRaises(InvalidLayoutError): loading.load( epath.Path(self.test_dir_safetensors.full_path), @@ -119,9 +123,9 @@ def test_load_safetensors_ckpt_from_dir(self): safetensors_dir = epath.Path(self.test_dir_safetensors.full_path) safetensors_path = safetensors_dir / 'model.safetensors' np_save_file(self.object_to_save, safetensors_path) - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS - ): + ctx = context_lib.Context() + ctx.checkpoint_layout = options_lib.CheckpointLayout.SAFETENSORS + with ctx: pytree = loading.load(safetensors_dir) self.assertIsInstance(pytree, dict) np.testing.assert_array_equal(pytree['a'], self.object_to_save['a']) @@ -179,7 +183,9 @@ async def sleep_and_load(*args, **kwargs): else: directory = self.orbax_pytree_path - with context_lib.Context(checkpoint_layout=layout): + ctx = context_lib.Context() + ctx.checkpoint_layout = layout + with ctx: if layout != options_lib.CheckpointLayout.SAFETENSORS: with self.assertRaises(NotImplementedError): loading.load_async(directory) @@ -196,9 +202,9 @@ async def sleep_and_load(*args, **kwargs): # TODO(b/431045454): Add tests for abstract_checkpointables. def test_load_auto_resolution_mode_orbax(self): - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.ORBAX - ): + ctx = context_lib.Context() + ctx.checkpoint_layout = options_lib.CheckpointLayout.ORBAX + with ctx: loaded_orbax = loading.load( self.orbax_pytree_path, checkpointable_name=checkpoint_layout.AUTO_CHECKPOINTABLE_KEY, @@ -206,9 +212,9 @@ def test_load_auto_resolution_mode_orbax(self): test_utils.assert_tree_equal(self, self.object_to_save, loaded_orbax) def test_load_auto_resolution_mode_safetensors(self): - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS - ): + ctx = context_lib.Context() + ctx.checkpoint_layout = options_lib.CheckpointLayout.SAFETENSORS + with ctx: loaded_safe = loading.load( self.safetensors_path, checkpointable_name=checkpoint_layout.AUTO_CHECKPOINTABLE_KEY, @@ -225,10 +231,10 @@ def test_load_auto_multiple_checkpointables_priority(self): multiple_path = epath.Path(self.test_dir.full_path) / 'multi_checkpoint' saving.save_checkpointables(multiple_path, checkpointables) - # Triggering AUTO loading mode should prioritize resolving state. - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.ORBAX - ): + # Triggering AUTO loading mode should prioritize resolving 'pytree'. + ctx = context_lib.Context() + ctx.checkpoint_layout = options_lib.CheckpointLayout.ORBAX + with ctx: loaded = loading.load(multiple_path) test_utils.assert_tree_equal( @@ -243,9 +249,9 @@ def test_load_auto_non_pytree_fallback(self): fallback_path = epath.Path(self.test_dir.full_path) / 'fallback_checkpoint' saving.save_checkpointables(fallback_path, custom_checkpointables) - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.ORBAX - ): + ctx = context_lib.Context() + ctx.checkpoint_layout = options_lib.CheckpointLayout.ORBAX + with ctx: loaded = loading.load( fallback_path, checkpointable_name=checkpoint_layout.AUTO_CHECKPOINTABLE_KEY, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading_test.py index ebb470097..22131f9c1 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/metadata/loading_test.py @@ -144,9 +144,9 @@ def test_metadata_safetensors(self): 'y': jax.ShapeDtypeStruct(shape=(3,), dtype=np.int64), } - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS - ): + ctx = context_lib.Context() + ctx.checkpoint_layout = options_lib.CheckpointLayout.SAFETENSORS + with ctx: ckpt_metadata = ocp.metadata(st_path) self.assertIsInstance(ckpt_metadata, metadata_types.CheckpointMetadata) @@ -163,9 +163,9 @@ def test_metadata_safetensors(self): # Test invalid path with self.assertRaises(ocp.errors.InvalidLayoutError): - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS - ): + ctx = context_lib.Context() + ctx.checkpoint_layout = options_lib.CheckpointLayout.SAFETENSORS + with ctx: ocp.metadata(self.directory) @@ -191,15 +191,12 @@ class CheckpointablesMetadataTest(absltest.TestCase): def setUp(self): super().setUp() self.directory = epath.Path(self.create_tempdir().full_path) / 'ckpt' - checkpointables_options = ( - options_lib.CheckpointablesOptions.create_with_handlers( - handler_utils.FooHandler, - handler_utils.BarHandler, - ) - ) - self.enter_context( - context_lib.Context(checkpointables_options=checkpointables_options) - ) + registry = ocp.handlers.local_registry() + registry.add(handler_utils.FooHandler) + registry.add(handler_utils.BarHandler) + ctx = context_lib.Context() + ctx.checkpointables.registry = registry + self.enter_context(ctx) checkpointables = { 'foo': Foo(1, 'foo'), 'bar': Bar(2, 'bar'), @@ -245,9 +242,9 @@ def test_checkpointables_metadata_safetensors(self): 'item2': jax.ShapeDtypeStruct(shape=(1,), dtype=np.int32), } - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS - ): + ctx = context_lib.Context() + ctx.checkpoint_layout = options_lib.CheckpointLayout.SAFETENSORS + with ctx: ckpt_metadata = ocp.checkpointables_metadata(st_path) self.assertIsInstance(ckpt_metadata, metadata_types.CheckpointMetadata) @@ -266,9 +263,9 @@ def test_checkpointables_metadata_safetensors(self): # Test invalid path with self.assertRaises(ocp.errors.InvalidLayoutError): - with context_lib.Context( - checkpoint_layout=options_lib.CheckpointLayout.SAFETENSORS - ): + ctx = context_lib.Context() + ctx.checkpoint_layout = options_lib.CheckpointLayout.SAFETENSORS + with ctx: ocp.checkpointables_metadata(self.directory) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py index db5633542..d7558ada2 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py @@ -197,7 +197,9 @@ def __init__( *, context: context_lib.Context | None = None, ): - self._context = context_lib.get_context(context) + self._context = ( + context if context is not None else context_lib.get_context() + ) self._handler_impl = _create_v0_array_handler( self._context, ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler_test.py index 9e6369f82..3bcbec3f3 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler_test.py @@ -26,7 +26,6 @@ from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils from orbax.checkpoint._src.tree import utils as tree_utils from orbax.checkpoint.experimental.v1._src.context import context as context_lib -from orbax.checkpoint.experimental.v1._src.context import options as options_lib from orbax.checkpoint.experimental.v1._src.serialization import array_leaf_handler from orbax.checkpoint.experimental.v1._src.serialization import types from orbax.checkpoint.experimental.v1._src.synchronization import multihost @@ -126,27 +125,27 @@ async def _test_simple_checkpoint_impl( self.create_tempdir(f'tmp_{self._testMethodName}').full_path ) - init_context = context_lib.Context( - array_options=options_lib.ArrayOptions( - saving=options_lib.ArrayOptions.Saving( - use_ocdbt=use_ocdbt, - use_zarr3=use_zarr3, - use_replica_parallel=use_replica_parallel, - enable_replica_parallel_separate_folder=enable_replica_parallel_separate_folder, - enable_pinned_host_transfer=enable_pinned_host_transfer, - use_compression=use_compression, - min_slice_bytes_for_replica_parallel=min_slice_bytes_for_replica_parallel, - max_replicas_for_replica_parallel=max_replicas_for_replica_parallel, - ), - loading=options_lib.ArrayOptions.Loading(), - ), - memory_options=options_lib.MemoryOptions( - write_concurrent_bytes=save_concurrent_bytes, - read_concurrent_bytes=load_concurrent_bytes, - ), + context = context_lib.Context() + context.array.saving.use_ocdbt = use_ocdbt + context.array.saving.use_zarr3 = use_zarr3 + context.array.saving.use_replica_parallel = use_replica_parallel + context.array.saving.enable_replica_parallel_separate_folder = ( + enable_replica_parallel_separate_folder + ) + context.array.saving.enable_pinned_host_transfer = ( + enable_pinned_host_transfer + ) + context.array.saving.use_compression = use_compression + context.array.saving.min_slice_bytes_for_replica_parallel = ( + min_slice_bytes_for_replica_parallel + ) + context.array.saving.max_replicas_for_replica_parallel = ( + max_replicas_for_replica_parallel ) + context.memory.write_concurrent_bytes = save_concurrent_bytes + context.memory.read_concurrent_bytes = load_concurrent_bytes - with context_lib.get_context(init_context) as context: + with context: handler = array_leaf_handler.ArrayLeafHandler() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py index d7946d5f4..74552f8e2 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py @@ -163,7 +163,9 @@ def __init__( *, context: context_lib.Context | None = None, ): - self._context = context_lib.get_context(context) + self._context = ( + context if context is not None else context_lib.get_context() + ) self._handler_impl = _create_v0_numpy_handler() logging.vlog(1, 'NumpyLeafHandler created.') diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler_test.py index 6b8d579c6..654dd6450 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler_test.py @@ -25,7 +25,6 @@ from orbax.checkpoint._src.serialization import ocdbt_utils from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils from orbax.checkpoint.experimental.v1._src.context import context as context_lib -from orbax.checkpoint.experimental.v1._src.context import options as options_lib from orbax.checkpoint.experimental.v1._src.serialization import numpy_leaf_handler from orbax.checkpoint.experimental.v1._src.serialization import types from orbax.checkpoint.experimental.v1._src.synchronization import multihost @@ -108,22 +107,14 @@ async def test_simple_checkpoint( self.create_tempdir(f'tmp_{self._testMethodName}').full_path ) - init_context = context_lib.Context( - array_options=options_lib.ArrayOptions( - saving=options_lib.ArrayOptions.Saving( - use_ocdbt=use_ocdbt, - use_zarr3=use_zarr3, - use_compression=use_compression, - ), - loading=options_lib.ArrayOptions.Loading(), - ), - memory_options=options_lib.MemoryOptions( - write_concurrent_bytes=save_concurrent_bytes, - read_concurrent_bytes=load_concurrent_bytes, - ), - ) + context = context_lib.Context() + context.array.saving.use_ocdbt = use_ocdbt + context.array.saving.use_zarr3 = use_zarr3 + context.array.saving.use_compression = use_compression + context.memory.write_concurrent_bytes = save_concurrent_bytes + context.memory.read_concurrent_bytes = load_concurrent_bytes - with context_lib.get_context(init_context) as context: + with context: handler = numpy_leaf_handler.NumpyLeafHandler() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution_test.py index 897b959f9..db1c1970f 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution_test.py @@ -95,21 +95,16 @@ def test_resolve_storage_options( shard_axes=(0,), ) - context = context_lib.Context( - array_options=options_lib.ArrayOptions( - saving=options_lib.ArrayOptions.Saving( - storage_options=global_storage, - scoped_storage_options_creator=callback, - ) - ), - ) + ctx = context_lib.Context() + ctx.array.saving.storage_options = global_storage + ctx.array.saving.scoped_storage_options_creator = callback # Dummy param keypath = (jax.tree_util.DictKey(key='foo'),) value = np.ones((2, 2)) resolved_options = options_resolution.resolve_storage_options( - keypath, value, context.array_options.saving + keypath, value, ctx.array.saving ) self.assertEqual(resolved_options, expected_storage_options) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py index b0fad2593..2e6b7dbe9 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py @@ -143,7 +143,9 @@ def __init__( *, context: context_lib.Context | None = None, ): - self._context = context_lib.get_context(context) + self._context = ( + context if context is not None else context_lib.get_context() + ) self._handler_impl = _create_v0_scalar_handler() logging.vlog(1, "ScalarLeafHandler created.") diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler_test.py index 72e3f0dff..3e7e8db4b 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler_test.py @@ -25,7 +25,6 @@ from orbax.checkpoint._src.serialization import ocdbt_utils from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils from orbax.checkpoint.experimental.v1._src.context import context as context_lib -from orbax.checkpoint.experimental.v1._src.context import options as options_lib from orbax.checkpoint.experimental.v1._src.serialization import scalar_leaf_handler from orbax.checkpoint.experimental.v1._src.serialization import types from orbax.checkpoint.experimental.v1._src.synchronization import multihost @@ -113,22 +112,14 @@ async def test_simple_checkpoint( self.create_tempdir(f'tmp_{self._testMethodName}').full_path ) - init_context = context_lib.Context( - array_options=options_lib.ArrayOptions( - saving=options_lib.ArrayOptions.Saving( - use_ocdbt=use_ocdbt, - use_zarr3=use_zarr3, - use_compression=use_compression, - ), - loading=options_lib.ArrayOptions.Loading(), - ), - memory_options=options_lib.MemoryOptions( - write_concurrent_bytes=save_concurrent_bytes, - read_concurrent_bytes=load_concurrent_bytes, - ), - ) + context = context_lib.Context() + context.array.saving.use_ocdbt = use_ocdbt + context.array.saving.use_zarr3 = use_zarr3 + context.array.saving.use_compression = use_compression + context.memory.write_concurrent_bytes = save_concurrent_bytes + context.memory.read_concurrent_bytes = load_concurrent_bytes - with context_lib.get_context(init_context) as context: + with context: handler = scalar_leaf_handler.ScalarLeafHandler() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler.py index aa28754ed..9b67f3364 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler.py @@ -53,7 +53,9 @@ def __init__( Args: context: Context that will be used for this leaf handler. """ - self._context = context_lib.get_context(context) + self._context = ( + context if context is not None else context_lib.get_context() + ) self._filename = '_strings.json' logging.vlog(1, 'StringLeafHandler created.') diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler_test.py index 78085bcda..e2b622777 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/string_leaf_handler_test.py @@ -94,9 +94,8 @@ async def test_simple_checkpoint( self.create_tempdir(f'tmp_{self._testMethodName}').full_path ) - init_context = context_lib.Context() - - with context_lib.get_context(init_context) as context: + context = context_lib.Context() + with context: handler = string_leaf_handler.StringLeafHandler() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py index b46455688..d8e3a61fb 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py @@ -126,11 +126,9 @@ def test_checkpointables_metadata_compatibility( ) ) - with ocp.Context( - checkpointables_options=ocp.options.CheckpointablesOptions( - registry=registry - ) - ): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: if error_type is None: loaded = ocp.checkpointables_metadata(path) # If the state checpointable is missing pytree metadata, then we expect diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/generate_v1_checkpoints.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/generate_v1_checkpoints.py index 21c3781f9..870a9eb71 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/generate_v1_checkpoints.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/generate_v1_checkpoints.py @@ -85,11 +85,9 @@ def generate_v1_checkpoint(path: epath.Path) -> None: registry = registration.local_registry() registry.add(ocp.handlers.PyTreeHandler, checkpointable_name='state') registry.add(ocp.handlers.JsonHandler, checkpointable_name='metadata') - with ocp.Context( - checkpointables_options=ocp.options.CheckpointablesOptions( - registry=registry - ) - ): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: ocp.save_checkpointables( path, checkpointables, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py index 2ae9c5396..4229e4788 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py @@ -253,11 +253,9 @@ def test_load_checkpointables_compatibility( else: abstract_checkpointables = None - with ocp.Context( - checkpointables_options=ocp.options.CheckpointablesOptions( - registry=registry - ) - ): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: if error_type is None: loaded = ocp.load_checkpointables( path, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py index e7ae95296..be9bdf084 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py @@ -262,11 +262,9 @@ def test_load_compatibility( self.abstract_state if abstract_pytree_provided else None ) - with ocp.Context( - checkpointables_options=ocp.options.CheckpointablesOptions( - registry=registry - ) - ): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: if error_type is None: loaded = ocp.load( path, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/manager_compatibility_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/manager_compatibility_test_base.py index 2238b3aaf..ebe07ff14 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/manager_compatibility_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/manager_compatibility_test_base.py @@ -283,11 +283,8 @@ def test_metadata( name_format=name_format, ) registry = self.setup_registry() - context = ocp.Context( - checkpointables_options=ocp.options.CheckpointablesOptions( - registry=registry - ) - ) + context = ocp.Context() + context.checkpointables.registry = registry self.enter_context(context) checkpointer = Checkpointer(path, step_name_format=name_format) self.enter_context(checkpointer) @@ -337,11 +334,8 @@ def test_checkpointables_metadata( name_format=name_format, ) registry = self.setup_registry() - context = ocp.Context( - checkpointables_options=ocp.options.CheckpointablesOptions( - registry=registry - ) - ) + context = ocp.Context() + context.checkpointables.registry = registry self.enter_context(context) checkpointer = Checkpointer(path, step_name_format=name_format) self.enter_context(checkpointer) @@ -394,11 +388,8 @@ def test_load_checkpointables( name_format=name_format, ) registry = self.setup_registry() - context = ocp.Context( - checkpointables_options=ocp.options.CheckpointablesOptions( - registry=registry - ) - ) + context = ocp.Context() + context.checkpointables.registry = registry self.enter_context(context) checkpointer = Checkpointer(path, step_name_format=name_format) self.enter_context(checkpointer) @@ -438,11 +429,8 @@ def test_load( name_format=name_format, ) registry = self.setup_registry() - context = ocp.Context( - checkpointables_options=ocp.options.CheckpointablesOptions( - registry=registry - ) - ) + context = ocp.Context() + context.checkpointables.registry = registry self.enter_context(context) checkpointer = Checkpointer(path, step_name_format=name_format) self.enter_context(checkpointer) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py index 51427199c..f2f5f5348 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py @@ -193,11 +193,9 @@ def test_metadata_compatibility( is_pytree, ) - with ocp.Context( - checkpointables_options=ocp.options.CheckpointablesOptions( - registry=registry - ) - ): + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: if error_type is None: loaded = ocp.metadata( path, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py index 0f4bc46e8..b67aa677d 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py @@ -208,9 +208,8 @@ async def mock_finalize(self_handler, directory): ) ) - context = ocp.Context( - async_options=ocp.options.AsyncOptions(timeout_secs=timeout_secs) - ) + context = ocp.Context() + context.asynchronous.timeout_secs = timeout_secs self.enter_context(context) start = time.time() @@ -538,13 +537,11 @@ def test_casting(self, original_dtype, save_dtype, load_dtype): dtype=save_dtype ) ) - with ocp.Context( - array_options=ocp.options.ArrayOptions( - saving=ocp.options.ArrayOptions.Saving( - scoped_storage_options_creator=scoped_storage_options_creator - ) - ) - ): + ctx = ocp.Context() + ctx.array.saving.scoped_storage_options_creator = ( + scoped_storage_options_creator + ) + with ctx: ocp.save(self.directory, tree) with self.subTest('with_abstract_tree'): @@ -673,15 +670,12 @@ def test_missing_keys(self): ) def test_custom_checkpointables(self): - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - handler_utils.FooHandler, - handler_utils.BarHandler, - ) - ) - self.enter_context( - ocp.Context(checkpointables_options=checkpointables_options) - ) + registry = ocp.handlers.local_registry() + registry.add(handler_utils.FooHandler) + registry.add(handler_utils.BarHandler) + ctx = ocp.Context() + ctx.checkpointables.registry = registry + self.enter_context(ctx) checkpointables = { 'pytree': self.numpy_pytree, 'foo': Foo(123, 'hi'), @@ -751,14 +745,11 @@ def test_save_checkpointables_ambiguous_resolution(self): 'two': {'c': 3, 'd': 4}, } directory = self.directory - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - one=handler_utils.DictHandler, - ) - ) - self.enter_context( - ocp.Context(checkpointables_options=checkpointables_options) - ) + registry = ocp.handlers.local_registry() + registry.add(handler_utils.DictHandler, checkpointable_name='one') + ctx = ocp.Context() + ctx.checkpointables.registry = registry + self.enter_context(ctx) ocp.save_checkpointables(directory, checkpointables) self.assertTrue((directory / 'one' / 'data.txt').exists()) self.assertFalse((directory / 'two' / 'data.txt').exists()) @@ -802,15 +793,12 @@ def test_abstract_pytree_types(self): test_utils.assert_tree_equal(self, self.pytree, loaded) def test_abstract_checkpointables_types(self): - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - handler_utils.FooHandler, - handler_utils.BarHandler, - ) - ) - self.enter_context( - ocp.Context(checkpointables_options=checkpointables_options) - ) + registry = ocp.handlers.local_registry() + registry.add(handler_utils.FooHandler) + registry.add(handler_utils.BarHandler) + ctx = ocp.Context() + ctx.checkpointables.registry = registry + self.enter_context(ctx) checkpointables = { 'foo': Foo(123, 'hi'), 'bar': Bar(456, 'bye'), @@ -837,14 +825,11 @@ def test_abstract_checkpointables_types(self): self.assertEqual(checkpointables, loaded) def test_async_directory_creation(self): - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - handler_utils.FooHandler, - ) - ) - self.enter_context( - ocp.Context(checkpointables_options=checkpointables_options) - ) + registry = ocp.handlers.local_registry() + registry.add(handler_utils.FooHandler) + ctx = ocp.Context() + ctx.checkpointables.registry = registry + self.enter_context(ctx) self.enter_context( mock.patch.object( async_utils, '_create_paths', _sleep_and_create_paths @@ -1074,13 +1059,9 @@ def test_partial_restore_omission(self): 'y': self.pytree['y'], } - with ocp.Context( - pytree_options=ocp.options.PyTreeOptions( - loading=ocp.options.PyTreeOptions.Loading( - partial_load=True, - ) - ) - ): + ctx = ocp.Context() + ctx.pytree.loading.partial_load = True + with ctx: loaded = ocp.load(self.directory, reference_pytree) test_utils.assert_tree_equal(self, expected, loaded) @@ -1161,13 +1142,9 @@ def test_load_and_broadcast(self): self.assertEqual( sharding.shard_shape((4, 32)), (4, 32 // partition_count) ) - with ocp.Context( - array_options=ocp.options.ArrayOptions( - loading=ocp.options.ArrayOptions.Loading( - use_load_and_broadcast=True, - ) - ) - ): + ctx = ocp.Context() + ctx.array.loading.use_load_and_broadcast = True + with ctx: ocp.save(self.directory, [arr]) with self.subTest('with_abstract_pytree'): loaded = ocp.load( @@ -1199,15 +1176,9 @@ def test_subchunking(self): } with self.subTest('global_setting'): - with ocp.Context( - array_options=ocp.options.ArrayOptions( - saving=ocp.options.ArrayOptions.Saving( - storage_options=ocp.options.ArrayOptions.Saving.StorageOptions( - chunk_byte_size=8, # force divide in two subchunks - ) - ) - ) - ): + ctx = ocp.Context() + ctx.array.saving.storage_options.chunk_byte_size = 8 + with ctx: ocp.save(self.directory / 'global_setting', pytree) metadata = ocp.metadata(self.directory / 'global_setting').metadata for k in pytree: @@ -1225,13 +1196,12 @@ def scoped_storage_options_creator(key, value): return ocp.options.ArrayOptions.Saving.StorageOptions( chunk_byte_size=8, # force divide in 2 subchunks ) - with ocp.Context( - array_options=ocp.options.ArrayOptions( - saving=ocp.options.ArrayOptions.Saving( - scoped_storage_options_creator=scoped_storage_options_creator - ) - ), - ): + + ctx = ocp.Context() + ctx.array.saving.scoped_storage_options_creator = ( + scoped_storage_options_creator + ) + with ctx: ocp.save(self.directory / 'per_key_setting', pytree) metadata = ocp.metadata(self.directory / 'per_key_setting').metadata self.assertEqual(metadata['a'].shape, (32,)) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/README.md b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/README.md index 39fc57fab..5858f8e40 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/README.md +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/README.md @@ -47,9 +47,8 @@ params, state = restored['params'], restored['state'] import orbax.checkpoint.v1 as ocp # 1. Environment/IO settings go in Context. -context = ocp.Context( - async_options=ocp.options.AsyncOptions(timeout_secs=60), -) +context = ocp.Context() +context.asynchronous.timeout_secs = 60 # 2. Logic & Lifecycle settings passed to Checkpointer constructor. with context: diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py index 2d8ad0e07..4857e74ed 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py @@ -172,7 +172,7 @@ def __init__( checkpoint steps present and checkpoint info properties like `time` and `metrics` are not needed. """ - self._context = context or context_lib.get_context() + self._context = context_lib.Context(context or context_lib.get_context()) default_save_decision_policy = save_decision_policies.AnySavePolicy([ save_decision_policies.InitialSavePolicy(), diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py index 3176e520c..d97420c77 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py @@ -519,15 +519,12 @@ def test_checkpointables_metadata(self, reinitialize_checkpointer): def test_custom_checkpointables(self): """Test custom checkpointables are saved and loaded.""" - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - handler_utils.FooHandler, - handler_utils.BarHandler, - ) - ) - self.enter_context( - ocp.Context(checkpointables_options=checkpointables_options) - ) + registry = ocp.handlers.local_registry() + registry.add(handler_utils.FooHandler) + registry.add(handler_utils.BarHandler) + ctx = ocp.Context() + ctx.checkpointables.registry = registry + self.enter_context(ctx) checkpointables = { STATE_CHECKPOINTABLE_KEY: self.pytree, 'foo': Foo(123, 'hi'), @@ -554,13 +551,12 @@ def test_custom_checkpointables(self): with self.subTest('load_with_free_function'): if multihost.is_pathways_backend(): self.skipTest('Sharding metadata not present in Pathways.') - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - foo=handler_utils.FooHandler, - bar=handler_utils.BarHandler, - ) - ) - with ocp.Context(checkpointables_options=checkpointables_options): + registry = ocp.handlers.local_registry() + registry.add(handler_utils.FooHandler, checkpointable_name='foo') + registry.add(handler_utils.BarHandler, checkpointable_name='bar') + ctx = ocp.Context() + ctx.checkpointables.registry = registry + with ctx: loaded = ocp.load_checkpointables(self.directory / '0') self.assertSameElements( loaded.keys(), [STATE_CHECKPOINTABLE_KEY, 'foo', 'bar'] @@ -616,15 +612,12 @@ def test_custom_checkpointables(self): def test_load_with_switched_abstract_checkpointables(self): """Test load with switched abstract checkpointables.""" - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - handler_utils.FooHandler, - handler_utils.BarHandler, - ) - ) - self.enter_context( - ocp.Context(checkpointables_options=checkpointables_options) - ) + registry = ocp.handlers.local_registry() + registry.add(handler_utils.FooHandler) + registry.add(handler_utils.BarHandler) + ctx = ocp.Context() + ctx.checkpointables.registry = registry + self.enter_context(ctx) checkpointables = { STATE_CHECKPOINTABLE_KEY: self.pytree, 'foo': Foo(123, 'hi'), @@ -645,15 +638,12 @@ def test_load_with_switched_abstract_checkpointables(self): def test_different_custom_checkpointables(self): """Test different custom checkpointables are saved and loaded.""" - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - handler_utils.FooHandler, - handler_utils.BarHandler, - ) - ) - self.enter_context( - ocp.Context(checkpointables_options=checkpointables_options) - ) + registry = ocp.handlers.local_registry() + registry.add(handler_utils.FooHandler) + registry.add(handler_utils.BarHandler) + ctx = ocp.Context() + ctx.checkpointables.registry = registry + self.enter_context(ctx) checkpointer = Checkpointer(self.directory) self.enter_context(checkpointer) self.save_checkpointables(checkpointer, 0, {'foo': Foo(123, 'hi')}) @@ -815,12 +805,9 @@ def test_preservation_metrics(self, policy, expected_steps): checkpointer.close() def test_gcs_deletion_options(self): - deletion_options = ocp.options.DeletionOptions( - gcs_deletion_options=ocp.options.DeletionOptions.GcsDeletionOptions( - todelete_full_path='trash' - ) - ) - with ocp.Context(deletion_options=deletion_options): + ctx = ocp.Context() + ctx.deletion.gcs_deletion_options.todelete_full_path = 'trash' + with ctx: checkpointer = Checkpointer(self.directory) self.assertEqual( checkpointer._manager._options.todelete_full_path, 'trash' @@ -828,14 +815,9 @@ def test_gcs_deletion_options(self): def test_context_constructor_override(self): - ctx1 = ocp.Context( - array_options=ocp.options.ArrayOptions( - saving=ocp.options.ArrayOptions.Saving(use_ocdbt=False) - ), - pytree_options=ocp.options.PyTreeOptions( - loading=ocp.options.PyTreeOptions.Loading(partial_load=True) - ), - ) + ctx1 = ocp.Context() + ctx1.array.saving.use_ocdbt = False + ctx1.pytree.loading.partial_load = True checkpointer = Checkpointer(self.directory, context=ctx1) self.enter_context(checkpointer) self.save(checkpointer, 0, self.pytree) @@ -861,11 +843,8 @@ def test_context_constructor_override(self): with self.subTest('local_context_override'): # Override with local context setting use_ocdbt=True - ctx2 = ocp.Context( - array_options=ocp.options.ArrayOptions( - saving=ocp.options.ArrayOptions.Saving(use_ocdbt=True) - ) - ) + ctx2 = ocp.Context() + ctx2.array.saving.use_ocdbt = True with ctx2: self.save(checkpointer, 1, self.pytree) diff --git a/docs/guides/checkpoint/v1/checkpoint_format.ipynb b/docs/guides/checkpoint/v1/checkpoint_format.ipynb index 96003a724..e5851d0ed 100644 --- a/docs/guides/checkpoint/v1/checkpoint_format.ipynb +++ b/docs/guides/checkpoint/v1/checkpoint_format.ipynb @@ -301,11 +301,11 @@ "# Note that the example would work even without the extra step of forcing\n", "# `extra_properties` to be handled by `JsonHandler`. We just want to ensure it\n", "# gets JSON-encoded for demonstration purposes.\n", - "with ocp.Context(\n", - " checkpointables_options=ocp.options.CheckpointablesOptions.create_with_handlers(\n", - " extra_properties=ocp.handlers.JsonHandler\n", - " )\n", - "):\n", + "registry = ocp.handlers.local_registry()\n", + "registry.add(ocp.handlers.JsonHandler, checkpointable_name='extra_properties')\n", + "ctx = ocp.Context()\n", + "ctx.checkpointables.registry = registry\n", + "with ctx:\n", " ocp.save_checkpointables(\n", " directory / 'ckpt-0',\n", " dict(pytree=pytree, extra_properties={'foo': 'bar'}),\n", diff --git a/docs/guides/checkpoint/v1/checkpointing_pytrees.ipynb b/docs/guides/checkpoint/v1/checkpointing_pytrees.ipynb index 9a429295a..165de5b4c 100644 --- a/docs/guides/checkpoint/v1/checkpointing_pytrees.ipynb +++ b/docs/guides/checkpoint/v1/checkpointing_pytrees.ipynb @@ -831,13 +831,10 @@ " # Return None to use global default storage_options for other leaves\n", " return None\n", "\n", - "with ocp.Context(\n", - " array_options=ocp.options.ArrayOptions(\n", - " saving=ocp.options.ArrayOptions.Saving(\n", - " scoped_storage_options_creator=scoped_storage_options_creator,\n", - " )\n", - " )\n", - "):\n", + "ctx = ocp.Context()\n", + "ctx.array.saving.scoped_storage_options_creator = scoped_storage_options_creator\n", + "\n", + "with ctx:\n", " ocp.save(path / '2', pytree, overwrite=True)" ] }, @@ -885,13 +882,11 @@ " dtype=np.dtype(np.int16)\n", " )\n", ")\n", - "with ocp.Context(\n", - " array_options=ocp.options.ArrayOptions(\n", - " saving=ocp.options.ArrayOptions.Saving(\n", - " scoped_storage_options_creator=scoped_storage_options_creator\n", - " )\n", - " )\n", - "):\n", + "\n", + "ctx = ocp.Context()\n", + "ctx.array.saving.scoped_storage_options_creator = scoped_storage_options_creator\n", + "\n", + "with ctx:\n", " ocp.save(path / '3', pytree, overwrite=True)" ] }, @@ -943,13 +938,9 @@ }, "outputs": [], "source": [ - "with ocp.Context(\n", - " array_options=ocp.options.ArrayOptions(\n", - " saving=ocp.options.ArrayOptions.Saving(\n", - " use_ocdbt=True,\n", - " )\n", - " )\n", - "):\n", + "ctx = ocp.Context()\n", + "ctx.array.saving.use_ocdbt = True\n", + "with ctx:\n", " ocp.save(path / '4', pytree, overwrite=True)" ] }, @@ -990,13 +981,9 @@ }, "outputs": [], "source": [ - "with ocp.Context(\n", - " array_options=ocp.options.ArrayOptions(\n", - " saving=ocp.options.ArrayOptions.Saving(\n", - " use_ocdbt=False,\n", - " )\n", - " )\n", - "):\n", + "ctx = ocp.Context()\n", + "ctx.array.saving.use_ocdbt = False\n", + "with ctx:\n", " ocp.save(path / '5', pytree, overwrite=True)\n", "\n", "!ls /tmp/checkpointing-pytrees/advanced/5/pytree" @@ -1097,13 +1084,10 @@ }, "outputs": [], "source": [ - "with ocp.Context(\n", - " array_options=ocp.options.ArrayOptions(\n", - " loading=ocp.options.ArrayOptions.Loading(\n", - " enable_padding_and_truncation=True\n", - " )\n", - " )\n", - "):\n", + "\n", + "ctx = ocp.Context()\n", + "ctx.array.loading.enable_padding_and_truncation = True\n", + "with ctx:\n", " loaded = ocp.load(path / '1', different_shape_abstract_state)" ] }, diff --git a/docs/guides/checkpoint/v1/model_surgery.ipynb b/docs/guides/checkpoint/v1/model_surgery.ipynb index 2c0379bb1..b92d228b9 100644 --- a/docs/guides/checkpoint/v1/model_surgery.ipynb +++ b/docs/guides/checkpoint/v1/model_surgery.ipynb @@ -267,13 +267,9 @@ " print(e)\n", "\n", "# So partial_load must be opted-into\n", - "with ocp.Context(\n", - " pytree_options=ocp.options.PyTreeOptions(\n", - " loading=ocp.options.PyTreeOptions.Loading(\n", - " partial_load=True,\n", - " ),\n", - " ),\n", - "):\n", + "ctx = ocp.Context()\n", + "ctx.pytree.loading.partial_load = True\n", + "with ctx:\n", " ocp.load(path, abstract_tree)" ], "outputs": [], diff --git a/docs/guides/checkpoint/v1/partial_saving.ipynb b/docs/guides/checkpoint/v1/partial_saving.ipynb index 885c08b5c..0dfead0ab 100644 --- a/docs/guides/checkpoint/v1/partial_saving.ipynb +++ b/docs/guides/checkpoint/v1/partial_saving.ipynb @@ -318,12 +318,9 @@ "abstract_params = jax.tree.map(\n", " ocp.arrays.to_shape_dtype_struct, {'params': base_model_state['params']}\n", ")\n", - "\n", - "with ocp.Context(\n", - " pytree_options=ocp.options.PyTreeOptions(\n", - " loading=ocp.options.PyTreeOptions.Loading(partial_load=True)\n", - " )\n", - "):\n", + "ctx = ocp.Context()\n", + "ctx.pytree.loading.partial_load = True\n", + "with ctx:\n", " loaded_params = ocp.load(base_path, abstract_params)" ] },