Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@

import asyncio
from collections.abc import Set
import contextlib
import dataclasses
import functools
import json
import sys
import threading
import time
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, List, Mapping, Optional, Tuple, Union
import uuid

from absl import logging
Expand All @@ -53,6 +51,7 @@
from orbax.checkpoint._src.path import format_utils
from orbax.checkpoint._src.path import types as path_types
from orbax.checkpoint._src.path.snapshot import snapshot
from orbax.checkpoint._src.serialization import async_io_engine
from orbax.checkpoint._src.serialization import limits
from orbax.checkpoint._src.serialization import memory_regulator
from orbax.checkpoint._src.serialization import ocdbt_utils
Expand All @@ -68,26 +67,23 @@


# Module-level shorthand aliases. Each name below simply rebinds a type,
# function or constant that is defined in another module.
PyTree = Any
TupleKey = Tuple[str, ...]
RestoreArgs = type_handlers.RestoreArgs
ArrayRestoreArgs = type_handlers.ArrayRestoreArgs
SaveArgs = type_handlers.SaveArgs
ParamInfo = types.ParamInfo
TypeHandler = types.TypeHandler
TypeHandlerRegistry = types.TypeHandlerRegistry
# Batched request / commit-future types provided by `async_io_engine`.
BatchRequest = async_io_engine.BatchRequest
BatchRequests = async_io_engine.BatchRequests
CommitFutures = async_io_engine.CommitFutures

# TODO(b/298487158) Clean up protected access.
LimitInFlightBytes = limits.LimitInFlightBytes
CheckpointArgs = checkpoint_args.CheckpointArgs
register_with_handler = checkpoint_args.register_with_handler
get_param_names = tree_utils.get_param_names
PYTREE_METADATA_FILE = format_utils.PYTREE_METADATA_FILE
PLACEHOLDER = type_handlers.PLACEHOLDER
PLACEHOLDER_TYPESTR = type_handlers.PLACEHOLDER_TYPESTR
TMP_DIR_SUFFIX = atomicity_types.TMP_DIR_SUFFIX
PENDING_DIR_SUFFIX = snapshot.PENDING_DIR_SUFFIX

# NOTE(review): presumably the default cap, in GB, on bytes processed
# concurrently during save/restore -- confirm against where it is consumed.
DEFAULT_CONCURRENT_GB = 96



Expand All @@ -99,110 +95,12 @@ class PartialSaveReplacementError(PartialSaveError):
"""Raised when a replacement is attempted during partial saving."""


def _default_sizeof_values(values: Sequence[Any]) -> Sequence[int]:
  """Returns `sys.getsizeof` of every value, in the order given.

  Note that `sys.getsizeof` is shallow: it does not follow references into
  contained objects.

  Args:
    values: Values to measure.

  Returns:
    A list holding one size (in bytes) per input value.
  """
  return list(map(sys.getsizeof, values))


def _get_batch_memory_size(
    handler: TypeHandler, values: Sequence[Any]
) -> Tuple[int, int]:
  """Gets memory size for a batch of leaf values.

  Args:
    handler: Handler whose `memory_size(values)` yields one
      `(write_size, read_size)` pair per value. If it raises
      `NotImplementedError`, a `sys.getsizeof`-based fallback is used for both
      sizes.
    values: Leaf values making up the batch. May be empty.

  Returns:
    A `(total_write_size, total_read_size)` tuple summed over `values`. An
    empty batch yields `(0, 0)`.

  Raises:
    AssertionError: If the number of reported sizes differs from the number
      of values.
  """
  try:
    # Materialize inside the `try` so that a lazily evaluated result still has
    # its `NotImplementedError` handled below.
    size_pairs = list(handler.memory_size(values))
  except NotImplementedError:
    logging.warning(
        '`memory_size` is not implemented for `TypeHandler` of type: %s. Using'
        ' the a default implementation to measure value memory consumption that'
        ' may result in inaccurate estimation.',
        type(handler),
    )
    write_sizes = read_sizes = _default_sizeof_values(values)
  else:
    # Transpose explicitly rather than via `zip(*size_pairs)`: for an empty
    # batch `zip()` produces nothing, and unpacking that into two names raises
    # `ValueError`.
    write_sizes = [write_size for write_size, _ in size_pairs]
    read_sizes = [read_size for _, read_size in size_pairs]
  assert len(write_sizes) == len(values)
  assert len(read_sizes) == len(values)
  return sum(write_sizes), sum(read_sizes)


def _log_io_metrics(
    size: int,
    start_time: float,
    gbytes_per_sec_metric: str,
    gbytes_metric: Optional[str] = None,
):
  """Logs and records I/O throughput measured since `start_time`.

  Args:
    size: Number of bytes transferred.
    start_time: `time.time()` timestamp taken when the transfer began.
    gbytes_per_sec_metric: Name under which the throughput scalar is recorded;
      also used as the label of the log line.
    gbytes_metric: Optional name under which the total size scalar is
      recorded. Nothing is recorded for the total when this is None.
  """
  bytes_per_gib = 1024**3
  elapsed = time.time() - start_time
  if elapsed == 0:
    # No measurable time has passed, so a rate cannot be computed.
    rate = float('nan')
  else:
    rate = float(size) / elapsed

  logging.info(
      '[process=%d] %s: %s/s (total gbytes: %s) (time elapsed: %s s) (%s)',
      multihost.process_index(),
      gbytes_per_sec_metric,
      humanize.naturalsize(rate, binary=True, format='%.3f'),
      humanize.naturalsize(size, binary=True),
      elapsed,
      'per-host',
  )

  jax.monitoring.record_scalar(
      gbytes_per_sec_metric, value=rate / bytes_per_gib
  )
  if gbytes_metric is None:
    return
  jax.monitoring.record_scalar(gbytes_metric, value=size / bytes_per_gib)


async def _logging_serialize(
    handler: TypeHandler,
    serialize: asyncio.Coroutine[Any, Any, Sequence[future.Future]],
) -> Sequence[future.Future]:
  """Awaits a serialize coroutine and logs how long the await took.

  Args:
    handler: Handler associated with `serialize`; only its type is used, to
      label the log line.
    serialize: Coroutine to await.

  Returns:
    The result of awaiting `serialize`, unchanged.
  """
  began = time.time()
  result = await serialize
  handler_cls = type(handler)
  quoted_name = '"{}.{}"'.format(
      handler_cls.__module__, handler_cls.__qualname__
  )
  logging.info(
      '[process=%s][thread=%s] Initiated %s.serialize. Time taken: %fs',
      multihost.process_index(),
      threading.current_thread().name,
      quoted_name,
      time.time() - began,
  )
  return result


@dataclasses.dataclass
class _BatchRequest:
  """A request for batched serialization or deserialization.

  The four list attributes are parallel: entry `i` of each describes the same
  parameter, so all of them must have the same length.

  Attributes:
    handler: Used to serialize or deserialize the parameters.
    keys: Identify the original tree keys so that the PyTree can be
      reconstructed.
    values: Values to serialize.
    infos: ParamInfos.
    args: List of SaveArgs or RestoreArgs.
  """

  handler: TypeHandler
  keys: List[str]
  values: List[Any]
  infos: List[ParamInfo]
  args: List[Union[SaveArgs, RestoreArgs]]

  def __post_init__(self):
    # Collapse the four lengths into a set; more than one distinct length
    # means the parallel lists are out of step.
    distinct_lengths = {
        len(self.values),
        len(self.infos),
        len(self.args),
        len(self.keys),
    }
    if len(distinct_lengths) > 1:
      raise AssertionError('Found `_BatchRequest` with mismatched parameters.')


def batched_serialization_requests(
tree: PyTree,
param_infos: PyTree,
args: PyTree,
registry: TypeHandlerRegistry,
) -> List[_BatchRequest]:
) -> BatchRequests:
"""Gets a list of batched serialization or deserialization requests."""
grouped = {}

Expand Down Expand Up @@ -251,7 +149,7 @@ def _group_value(
) from e

if handler not in grouped:
grouped[handler] = _BatchRequest(handler, [], [], [], [])
grouped[handler] = BatchRequest(handler, [], [], [], [])
request = grouped[handler]
grouped[handler] = dataclasses.replace(
request,
Expand Down Expand Up @@ -330,17 +228,6 @@ def _maybe_set_default_save_restore_args(v, leaf_args):
)


@contextlib.contextmanager
def _memory_profiler_context():
  """Runs the wrapped block between `memory_regulator` profiler start and end.

  `profiler_start()` is called on entry and `profiler_end()` is called on exit,
  including when the wrapped block raises.
  """
  memory_regulator.profiler_start()
  with contextlib.ExitStack() as cleanup:
    # Registered callbacks run when the stack unwinds, whether the body
    # finished normally or raised, so the profiler is always stopped.
    cleanup.callback(memory_regulator.profiler_end)
    yield




def _format_bytes(bytes_value: Optional[int]) -> str:
Expand Down Expand Up @@ -453,9 +340,9 @@ def _validate_key(key, merged_tuples_set=merged_tuples_set):


def _filter_batch_requests(
batch_requests: Sequence[_BatchRequest],
batch_requests: BatchRequests,
additions: Set[Any],
) -> list[_BatchRequest]:
) -> BatchRequests:
"""Filters batch requests to include only items matching the additions."""
filtered_requests = []
for request in batch_requests:
Expand Down Expand Up @@ -624,6 +511,7 @@ def __init__(
_format_bytes(self._save_concurrent_bytes),
_format_bytes(self._restore_concurrent_bytes),
)
self._async_io_engine = async_io_engine.AsyncIoEngine()

def get_param_names(self, item: PyTree) -> PyTree:
"""Gets parameter names for PyTree elements."""
Expand Down Expand Up @@ -702,36 +590,13 @@ async def _async_partial_save(
self,
directory: epath.Path,
item: PyTree,
batch_requests: list[_BatchRequest],
param_infos: PyTree,
save_args: BasePyTreeSaveArgs,
) -> Tuple[
List[asyncio.Coroutine[Any, Any, Sequence[future.Future]]],
int,
PyTree,
BasePyTreeSaveArgs,
]:
batch_requests: BatchRequests,
) -> BatchRequests:
flat_item = tree_utils.to_flat_dict(item)
additions = await _get_partial_save_additions(
directory, flat_item, self._pytree_metadata_options
)
filtered_requests = _filter_batch_requests(batch_requests, additions)

serialize_ops = []
tree_memory_size = 0
for request in filtered_requests:
serialize_ops += [
_logging_serialize(
request.handler,
request.handler.serialize(
request.values, request.infos, request.args
),
)
]
write_size, _ = _get_batch_memory_size(request.handler, request.values)
tree_memory_size += write_size

return serialize_ops, tree_memory_size, param_infos, save_args
return _filter_batch_requests(batch_requests, additions)

async def async_save(
self,
Expand Down Expand Up @@ -815,7 +680,6 @@ async def async_save(
leaf.parent_dir is directory for leaf in jax.tree.leaves(param_infos)
)

serialize_ops = [] # List of (coros -> List of futures)
batch_requests = batched_serialization_requests(
item,
param_infos,
Expand All @@ -824,34 +688,17 @@ async def async_save(
)

batch_requests_ready_time = time.time()
with _memory_profiler_context():
if args.partial_save_mode:
serialize_ops, tree_memory_size, param_infos, save_args = (
await self._async_partial_save(
directory, item, batch_requests, param_infos, save_args
)
)
else:
tree_memory_size = 0
for request in batch_requests:
serialize_ops += [
_logging_serialize(
request.handler,
request.handler.serialize(
request.values, request.infos, request.args
),
)
]
write_size, _ = _get_batch_memory_size(
request.handler, request.values
)
tree_memory_size += write_size
# Await copy futures. Returns List[List[future.Future]].
commit_futures = await asyncio.gather(*serialize_ops)
logging.info(
'MemoryRegulated: Peak usage: %f GiB',
memory_regulator.profiler_peak_usage_gib(),
if args.partial_save_mode:
requests_to_save = await self._async_partial_save(
directory, item, batch_requests
)
else:
requests_to_save = batch_requests

tree_memory_size = async_io_engine.compute_save_memory_size(
requests_to_save
)
commit_futures = await self._async_io_engine.execute_save(requests_to_save)
# Flatten to List[future.Future].
commit_futures, _ = jax.tree.flatten(commit_futures)

Expand Down Expand Up @@ -879,7 +726,7 @@ async def async_save(
save_futures += commit_futures


_log_io_metrics(
async_io_engine.log_io_metrics(
tree_memory_size,
start_time,
'/jax/orbax/write/blocking_gbytes_per_sec',
Expand All @@ -888,7 +735,7 @@ async def async_save(
future.ChainedFuture(
save_futures,
functools.partial(
_log_io_metrics,
async_io_engine.log_io_metrics,
tree_memory_size,
start_time,
'/jax/orbax/write/gbytes_per_sec',
Expand Down Expand Up @@ -953,21 +800,19 @@ async def _maybe_deserialize(
restore_args,
self._type_handler_registry,
)
deserialized_batches = []
deserialized_batches_ops = []
for request in batch_requests:
deserialized_batches_ops.append(
request.handler.deserialize(request.infos, request.args)
)
deserialized_batches += await asyncio.gather(*deserialized_batches_ops)

tree_memory_size = 0
deserialized_batches = await self._async_io_engine.execute_restore(
batch_requests
)
tree_memory_size = async_io_engine.compute_restore_memory_size(
batch_requests, deserialized_batches
)

flat_restored = {}
for request, deserialized in zip(batch_requests, deserialized_batches):
_, read_size = _get_batch_memory_size(request.handler, deserialized)
tree_memory_size += read_size
for key, value in zip(request.keys, deserialized):
flat_restored[key] = value

# Add in empty nodes from the metadata tree.
for key in flat_metadata.keys():
if key not in flat_restored:
Expand All @@ -977,6 +822,7 @@ async def _maybe_deserialize(
flat_restored[key] = empty_values.get_empty_value_from_typestr(
flat_metadata[key].value_type, self._pytree_metadata_options
)

# Restore using `item` as the target structure. If there are any custom
# nodes (e.g. optax.EmptyState), these will replace None values in
# flat_restored.
Expand Down Expand Up @@ -1260,7 +1106,7 @@ class TrainState:
)


_log_io_metrics(
async_io_engine.log_io_metrics(
tree_memory_size,
start_time,
'/jax/checkpoint/read/gbytes_per_sec',
Expand Down
Loading
Loading