Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 20 additions & 24 deletions checkpoint/orbax/checkpoint/_src/testing/benchmarks/v1/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,31 +89,27 @@ def is_valid(self) -> bool:

@property
def context(self) -> ocp.Context:
return ocp.Context(
array_options=ocp.options.ArrayOptions(
saving=ocp.options.ArrayOptions.Saving(
storage_options=ocp.options.ArrayOptions.Saving.StorageOptions(
chunk_byte_size=self.chunk_byte_size,
),
use_ocdbt=self.use_ocdbt,
use_zarr3=self.use_zarr3,
use_replica_parallel=self.use_replica_parallel,
use_compression=self.use_compression,
enable_replica_parallel_separate_folder=self.enable_replica_parallel_separate_folder,
),
loading=ocp.options.ArrayOptions.Loading(
use_load_and_broadcast=self.use_load_and_broadcast,
),
),
memory_options=ocp.options.MemoryOptions(
write_concurrent_bytes=self.save_concurrent_gb * 1024**3
if self.save_concurrent_gb is not None
else None,
read_concurrent_bytes=self.restore_concurrent_gb * 1024**3
if self.restore_concurrent_gb is not None
else None,
),
ctx = ocp.Context()
ctx.array.saving.storage_options.chunk_byte_size = self.chunk_byte_size
ctx.array.saving.use_ocdbt = self.use_ocdbt
ctx.array.saving.use_zarr3 = self.use_zarr3
ctx.array.saving.use_replica_parallel = self.use_replica_parallel
ctx.array.saving.use_compression = self.use_compression
ctx.array.saving.enable_replica_parallel_separate_folder = (
self.enable_replica_parallel_separate_folder
)
ctx.array.loading.use_load_and_broadcast = self.use_load_and_broadcast
ctx.memory.write_concurrent_bytes = (
self.save_concurrent_gb * 1024**3
if self.save_concurrent_gb is not None
else None
)
ctx.memory.read_concurrent_bytes = (
self.restore_concurrent_gb * 1024**3
if self.restore_concurrent_gb is not None
else None
)
return ctx


def clear_pytree(pytree: Any) -> Any:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,26 @@


def get_context(default: Context | None = None) -> Context:
"""Returns the current `Context` or `default` or `Context()` if not set."""
"""Returns the currently active `Context`, or a default if no context is active.

If called within a `with ocp.Context(...)` block, this function returns the
`Context` object associated with that block (the active context).

If called outside of any `with` block, this function returns `default`
if it is provided. If `default` is not provided or `None`, it returns a
new `Context` instance initialized with default options.

Note: If a context is active, the `default` parameter is ignored, and the
active context is always returned. To ensure that an explicitly provided
context takes precedence over any active context, use the pattern:
`ctx = explicit_context if explicit_context is not None else get_context()`.

Args:
default: A `Context` object to return if no context is active.

Returns:
The active `Context` or a default `Context`.
"""
default = default or Context()
return _CONTEXT.get(default)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import contextvars
import dataclasses
import enum
from typing import Any, Callable, Protocol, Type
from typing import Any, Callable, Protocol

from etils import epath
import numpy as np
Expand All @@ -29,7 +29,6 @@
from orbax.checkpoint._src.path import atomicity_types
from orbax.checkpoint._src.serialization import pathways_types
from orbax.checkpoint.experimental.v1._src.handlers import registration
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
from orbax.checkpoint.experimental.v1._src.path import types as path_types
from orbax.checkpoint.experimental.v1._src.serialization import types as serialization_types
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
Expand Down Expand Up @@ -495,19 +494,6 @@ class CheckpointablesOptions(_ActiveContextGuard):
)
)

@classmethod
def create_with_handlers(
cls,
*handlers: Type[handler_types.CheckpointableHandler],
**named_handlers: Type[handler_types.CheckpointableHandler],
) -> CheckpointablesOptions:
registry = registration.local_registry(include_global_registry=True)
for handler in handlers:
registry.add(handler, checkpointable_name=None)
for name, handler in named_handlers.items():
registry.add(handler, checkpointable_name=name)
return cls(registry=registry)


@dataclasses.dataclass(kw_only=True)
class PathwaysOptions(_ActiveContextGuard):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,14 @@ class JsonHandler(CheckpointableHandler[JsonType, None]):

config = {'learning_rate': 0.01, 'batch_size': 32}

checkpointables_options = (
ocp.options.CheckpointablesOptions.create_with_handlers(
experiment_config=ocp.handlers.JsonHandler(
filename='experiment_config.json'
)
)
registry = ocp.handlers.local_registry()
registry.add(
ocp.handlers.JsonHandler,
checkpointable_name='experiment_config',
)
with ocp.Context(checkpointables_options=checkpointables_options):
ctx = ocp.Context()
ctx.checkpointables.registry = registry
with ctx:
ocp.save_checkpointables(path, dict(experiment_config=config))

Attributes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,13 @@ class ProtoHandler(
# Assuming MyProtoMessage is your compiled protobuf class
my_proto_msg = MyProtoMessage(config_field="value")

checkpointables_options = (
ocp.options.CheckpointablesOptions.create_with_handlers(
proto_config=ocp.handlers.ProtoHandler(
filename="model_config.pbtxt"
)
)
registry = ocp.handlers.local_registry()
registry.add(
ocp.handlers.ProtoHandler, checkpointable_name="proto_config"
)
with ocp.Context(checkpointables_options=checkpointables_options):
ctx = ocp.Context()
ctx.checkpointables.registry = registry
with ctx:
ocp.save_checkpointables(path, dict(proto_config=my_proto_msg))

Attributes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,13 @@ class PyTreeHandler(CheckpointableHandler[PyTree, PyTree]):

state_pytree = {'weights': [1.0, 2.0], 'bias': 0.0}

checkpointables_options = (
ocp.options.CheckpointablesOptions.create_with_handlers(
model_state=ocp.handlers.PyTreeHandler()
)
registry = ocp.handlers.local_registry()
registry.add(
ocp.handlers.PyTreeHandler, checkpointable_name='model_state'
)
with ocp.Context(checkpointables_options=checkpointables_options):
ctx = ocp.Context()
ctx.checkpointables.registry = registry
with ctx:
ocp.save_checkpointables(path, dict(model_state=state_pytree))

Attributes:
Expand All @@ -299,7 +300,7 @@ def __init__(
) = None,
partial_save_mode: bool = False,
):
context = context_lib.get_context(context)
context = context if context is not None else context_lib.get_context()
self._context = context
self._multiprocessing_options = context.multiprocessing_options
self._partial_save_mode = partial_save_mode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ def handler_with_options(
options_lib.ArrayOptions.Saving.ScopedStorageOptionsCreator | None
) = None,
array_storage_options: (
options_lib.ArrayOptions.Saving.StorageOptions
) = options_lib.ArrayOptions.Saving.StorageOptions(),
options_lib.ArrayOptions.Saving.StorageOptions | None
) = None,
save_concurrent_bytes: int | None = None,
restore_concurrent_bytes: int | None = None,
use_ocdbt: bool = True,
Expand All @@ -271,48 +271,45 @@ def handler_with_options(
) = None,
):
"""Registers handlers with OCDBT support and resets when done."""
context = context_lib.Context(
array_options=options_lib.ArrayOptions(
saving=options_lib.ArrayOptions.Saving(
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
use_compression=use_compression,
ocdbt_target_data_file_size=ocdbt_target_data_file_size,
enable_pinned_host_transfer=enable_pinned_host_transfer,
array_metadata_store=array_metadata_store,
enable_write_sharding_file=enable_write_sharding_file,
use_replica_parallel=not utils.is_pathways_backend(),
storage_options=array_storage_options,
scoped_storage_options_creator=scoped_storage_options_creator,
),
loading=options_lib.ArrayOptions.Loading(
enable_padding_and_truncation=enable_padding_and_truncation,
),
),
memory_options=options_lib.MemoryOptions(
write_concurrent_bytes=save_concurrent_bytes,
read_concurrent_bytes=restore_concurrent_bytes,
),
pytree_options=options_lib.PyTreeOptions(
saving=options_lib.PyTreeOptions.Saving(
pytree_metadata_options=pytree_metadata_options,
),
loading=options_lib.PyTreeOptions.Loading(
partial_load=partial_load,
),
),
context = context_lib.Context()

context.array.saving.use_ocdbt = use_ocdbt
context.array.saving.use_zarr3 = use_zarr3
context.array.saving.use_compression = use_compression
context.array.saving.ocdbt_target_data_file_size = ocdbt_target_data_file_size
context.array.saving.enable_pinned_host_transfer = enable_pinned_host_transfer
context.array.saving.array_metadata_store = array_metadata_store
context.array.saving.enable_write_sharding_file = enable_write_sharding_file
context.array.saving.use_replica_parallel = not utils.is_pathways_backend()
if array_storage_options is not None:
context.array.saving.storage_options.dtype = array_storage_options.dtype
context.array.saving.storage_options.chunk_byte_size = (
array_storage_options.chunk_byte_size
)
context.array.saving.storage_options.shard_axes = (
array_storage_options.shard_axes
)
context.array.saving.scoped_storage_options_creator = (
scoped_storage_options_creator
)
context.array.loading.enable_padding_and_truncation = (
enable_padding_and_truncation
)

context.memory.write_concurrent_bytes = save_concurrent_bytes
context.memory.read_concurrent_bytes = restore_concurrent_bytes

context.pytree.saving.pytree_metadata_options = pytree_metadata_options
context.pytree.loading.partial_load = partial_load

handler = handler_test_utils.create_test_handler(
pytree_handler.PyTreeHandler,
context=context,
leaf_handler_registry=leaf_handler_registry,
)

try:
with context:
yield handler
finally:
pass


class PyTreeHandlerTest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,9 @@
# to a new v1 handler class.
registry.add(BazHandler, secondary_typestrs=['OldBazHandlerTypestr'])

checkpointables_options = ocp.options.CheckpointablesOptions(
registry=registry
)
with ocp.Context(checkpointables_options=checkpointables_options):
ctx = ocp.Context()
ctx.checkpointables.registry = registry
with ctx:
ocp.save_checkpointables(...)

Handler resolution for saving/loading follows this logic:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ async def get_checkpoint_layout(
f"Could not recognize the checkpoint at {path} as a valid"
f" {layout_enum.value} checkpoint. If you are trying to load a"
" checkpoint that does not conform to the standard Orbax format, use"
" `ocp.Context(layout=...)` to specify the expected checkpoint layout."
" `ctx.checkpoint_layout = ...` to specify the expected checkpoint"
" layout."
) from e


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from orbax.checkpoint import test_utils
from orbax.checkpoint._src.testing import multiprocess_test
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.context import options as options_lib
from orbax.checkpoint.experimental.v1._src.layout import safetensors_layout
import safetensors.numpy

Expand Down Expand Up @@ -183,12 +182,12 @@ async def test_load_without_global_reshard_single_tensor(self):
}

layout = SafetensorsLayout()
with context_lib.Context(
safetensors_options=options_lib.SafetensorsOptions(
ignore_load_sharding=True
)
):
restore_fn = await layout.load(st_path, abstract_state=abstract_state)
ctx = context_lib.Context()
ctx.safetensors.ignore_load_sharding = True
with ctx:
restore_fn = await layout.load(
st_path, abstract_state=abstract_state
)
restored_pytree = await restore_fn
restored_tensor = restored_pytree["params.tensor"]

Expand Down Expand Up @@ -225,11 +224,9 @@ async def test_load_without_global_reshard_multi_tensor(self):
}

layout = SafetensorsLayout()
with context_lib.Context(
safetensors_options=options_lib.SafetensorsOptions(
ignore_load_sharding=True
)
):
ctx = context_lib.Context()
ctx.safetensors.ignore_load_sharding = True
with ctx:
restore_fn = await layout.load(st_path, abstract_state=abstract_state)
restored_pytree = await restore_fn

Expand Down Expand Up @@ -337,11 +334,9 @@ async def test_load_without_global_reshard_memory_efficiency(self):

tracemalloc.start()

with context_lib.Context(
safetensors_options=options_lib.SafetensorsOptions(
ignore_load_sharding=True
)
):
ctx = context_lib.Context()
ctx.safetensors.ignore_load_sharding = True
with ctx:
restore_fn = await layout.load(
file_path,
abstract_state=abstract_pytree,
Expand Down
Loading
Loading