diff --git a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index 43860f8f0..924bbec98 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -23,14 +23,12 @@ import asyncio from collections.abc import Set -import contextlib import dataclasses import functools import json -import sys import threading import time -from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, List, Mapping, Optional, Tuple, Union import uuid from absl import logging @@ -53,6 +51,7 @@ from orbax.checkpoint._src.path import format_utils from orbax.checkpoint._src.path import types as path_types from orbax.checkpoint._src.path.snapshot import snapshot +from orbax.checkpoint._src.serialization import async_io_engine from orbax.checkpoint._src.serialization import limits from orbax.checkpoint._src.serialization import memory_regulator from orbax.checkpoint._src.serialization import ocdbt_utils @@ -68,26 +67,23 @@ PyTree = Any -TupleKey = Tuple[str, ...] RestoreArgs = type_handlers.RestoreArgs ArrayRestoreArgs = type_handlers.ArrayRestoreArgs SaveArgs = type_handlers.SaveArgs ParamInfo = types.ParamInfo -TypeHandler = types.TypeHandler TypeHandlerRegistry = types.TypeHandlerRegistry +BatchRequest = async_io_engine.BatchRequest +BatchRequests = async_io_engine.BatchRequests +CommitFutures = async_io_engine.CommitFutures # TODO(b/298487158) Clean up protected access. 
-LimitInFlightBytes = limits.LimitInFlightBytes CheckpointArgs = checkpoint_args.CheckpointArgs register_with_handler = checkpoint_args.register_with_handler get_param_names = tree_utils.get_param_names PYTREE_METADATA_FILE = format_utils.PYTREE_METADATA_FILE PLACEHOLDER = type_handlers.PLACEHOLDER -PLACEHOLDER_TYPESTR = type_handlers.PLACEHOLDER_TYPESTR TMP_DIR_SUFFIX = atomicity_types.TMP_DIR_SUFFIX -PENDING_DIR_SUFFIX = snapshot.PENDING_DIR_SUFFIX -DEFAULT_CONCURRENT_GB = 96 @@ -99,110 +95,12 @@ class PartialSaveReplacementError(PartialSaveError): """Raised when a replacement is attempted during partial saving.""" -def _default_sizeof_values(values: Sequence[Any]) -> Sequence[int]: - return [sys.getsizeof(v) for v in values] - - -def _get_batch_memory_size( - handler: TypeHandler, values: Sequence[Any] -) -> Tuple[int, int]: - """Gets memory size for a batch of leaf values.""" - try: - write_sizes, read_sizes = zip(*handler.memory_size(values)) - except NotImplementedError: - logging.warning( - '`memory_size` is not implemented for `TypeHandler` of type: %s. 
Using' - ' the a default implementation to measure value memory consumption that' - ' may result in inaccurate estimation.', - type(handler), - ) - write_sizes = read_sizes = _default_sizeof_values(values) - assert len(write_sizes) == len(values) - assert len(read_sizes) == len(values) - return sum(write_sizes), sum(read_sizes) - - -def _log_io_metrics( - size: int, - start_time: float, - gbytes_per_sec_metric: str, - gbytes_metric: Optional[str] = None, -): - """Logs the bytes per second metric.""" - time_elapsed = time.time() - start_time - bytes_per_sec = ( - float('nan') if time_elapsed == 0 else float(size) / time_elapsed - ) - note = 'per-host' - logging.info( - '[process=%d] %s: %s/s (total gbytes: %s) (time elapsed: %s s) (%s)', - multihost.process_index(), - gbytes_per_sec_metric, - humanize.naturalsize(bytes_per_sec, binary=True, format='%.3f'), - humanize.naturalsize(size, binary=True), - time_elapsed, - note, - ) - jax.monitoring.record_scalar( - gbytes_per_sec_metric, value=bytes_per_sec / (1024**3) - ) - if gbytes_metric is not None: - jax.monitoring.record_scalar(gbytes_metric, value=size / (1024**3)) - - -async def _logging_serialize( - handler: TypeHandler, - serialize: asyncio.Coroutine[Any, Any, Sequence[future.Future]], -) -> Sequence[future.Future]: - """Logs the time taken to serialize.""" - start = time.time() - commit_futures = await serialize - handler_name = f'{type(handler).__module__}.{type(handler).__qualname__}' - logging.info( - '[process=%s][thread=%s] Initiated %s.serialize. Time taken: %fs', - multihost.process_index(), - threading.current_thread().name, - f'"{handler_name}"', - time.time() - start, - ) - return commit_futures - - -@dataclasses.dataclass -class _BatchRequest: - """Represents a a request for batched serialization or deserialization. - - Attributes: - handler: Used to serialize or deserialize the parameters. - keys: Used to identify the original tree keys so that the PyTree can be - reconstructed. 
- values: Values to serialize. - infos: ParamInfos. - args: List of SaveArgs or RestoreArgs. - """ - - handler: TypeHandler - keys: List[str] - values: List[Any] - infos: List[ParamInfo] - args: List[Union[SaveArgs, RestoreArgs]] - - def __post_init__(self): - length = len(self.values) - if not all(( - length == len(self.infos), - length == len(self.args), - length == len(self.keys), - )): - raise AssertionError('Found `_BatchRequest` with mismatched parameters.') - - def batched_serialization_requests( tree: PyTree, param_infos: PyTree, args: PyTree, registry: TypeHandlerRegistry, -) -> List[_BatchRequest]: +) -> BatchRequests: """Gets a list of batched serialization or deserialization requests.""" grouped = {} @@ -251,7 +149,7 @@ def _group_value( ) from e if handler not in grouped: - grouped[handler] = _BatchRequest(handler, [], [], [], []) + grouped[handler] = BatchRequest(handler, [], [], [], []) request = grouped[handler] grouped[handler] = dataclasses.replace( request, @@ -330,17 +228,6 @@ def _maybe_set_default_save_restore_args(v, leaf_args): ) -@contextlib.contextmanager -def _memory_profiler_context(): - """Context manager for memory_regulator profiler.""" - memory_regulator.profiler_start() - try: - yield - finally: - # Explicitly stop the bg thread if an exception occurs - memory_regulator.profiler_end() - - def _format_bytes(bytes_value: Optional[int]) -> str: @@ -453,9 +340,9 @@ def _validate_key(key, merged_tuples_set=merged_tuples_set): def _filter_batch_requests( - batch_requests: Sequence[_BatchRequest], + batch_requests: BatchRequests, additions: Set[Any], -) -> list[_BatchRequest]: +) -> BatchRequests: """Filters batch requests to include only items matching the additions.""" filtered_requests = [] for request in batch_requests: @@ -624,6 +511,7 @@ def __init__( _format_bytes(self._save_concurrent_bytes), _format_bytes(self._restore_concurrent_bytes), ) + self._async_io_engine = async_io_engine.AsyncIoEngine() def get_param_names(self, item: 
PyTree) -> PyTree: """Gets parameter names for PyTree elements.""" @@ -702,36 +590,13 @@ async def _async_partial_save( self, directory: epath.Path, item: PyTree, - batch_requests: list[_BatchRequest], - param_infos: PyTree, - save_args: BasePyTreeSaveArgs, - ) -> Tuple[ - List[asyncio.Coroutine[Any, Any, Sequence[future.Future]]], - int, - PyTree, - BasePyTreeSaveArgs, - ]: + batch_requests: BatchRequests, + ) -> BatchRequests: flat_item = tree_utils.to_flat_dict(item) additions = await _get_partial_save_additions( directory, flat_item, self._pytree_metadata_options ) - filtered_requests = _filter_batch_requests(batch_requests, additions) - - serialize_ops = [] - tree_memory_size = 0 - for request in filtered_requests: - serialize_ops += [ - _logging_serialize( - request.handler, - request.handler.serialize( - request.values, request.infos, request.args - ), - ) - ] - write_size, _ = _get_batch_memory_size(request.handler, request.values) - tree_memory_size += write_size - - return serialize_ops, tree_memory_size, param_infos, save_args + return _filter_batch_requests(batch_requests, additions) async def async_save( self, @@ -815,7 +680,6 @@ async def async_save( leaf.parent_dir is directory for leaf in jax.tree.leaves(param_infos) ) - serialize_ops = [] # List of (coros -> List of futures) batch_requests = batched_serialization_requests( item, param_infos, @@ -824,34 +688,17 @@ async def async_save( ) batch_requests_ready_time = time.time() - with _memory_profiler_context(): - if args.partial_save_mode: - serialize_ops, tree_memory_size, param_infos, save_args = ( - await self._async_partial_save( - directory, item, batch_requests, param_infos, save_args - ) - ) - else: - tree_memory_size = 0 - for request in batch_requests: - serialize_ops += [ - _logging_serialize( - request.handler, - request.handler.serialize( - request.values, request.infos, request.args - ), - ) - ] - write_size, _ = _get_batch_memory_size( - request.handler, request.values - ) - 
tree_memory_size += write_size - # Await copy futures. Returns List[List[future.Future]]. - commit_futures = await asyncio.gather(*serialize_ops) - logging.info( - 'MemoryRegulated: Peak usage: %f GiB', - memory_regulator.profiler_peak_usage_gib(), + if args.partial_save_mode: + requests_to_save = await self._async_partial_save( + directory, item, batch_requests + ) + else: + requests_to_save = batch_requests + + tree_memory_size = async_io_engine.compute_save_memory_size( + requests_to_save ) + commit_futures = await self._async_io_engine.execute_save(requests_to_save) # Flatten to List[future.Future]. commit_futures, _ = jax.tree.flatten(commit_futures) @@ -879,7 +726,7 @@ async def async_save( save_futures += commit_futures - _log_io_metrics( + async_io_engine.log_io_metrics( tree_memory_size, start_time, '/jax/orbax/write/blocking_gbytes_per_sec', @@ -888,7 +735,7 @@ async def async_save( future.ChainedFuture( save_futures, functools.partial( - _log_io_metrics, + async_io_engine.log_io_metrics, tree_memory_size, start_time, '/jax/orbax/write/gbytes_per_sec', @@ -953,21 +800,19 @@ async def _maybe_deserialize( restore_args, self._type_handler_registry, ) - deserialized_batches = [] - deserialized_batches_ops = [] - for request in batch_requests: - deserialized_batches_ops.append( - request.handler.deserialize(request.infos, request.args) - ) - deserialized_batches += await asyncio.gather(*deserialized_batches_ops) - tree_memory_size = 0 + deserialized_batches = await self._async_io_engine.execute_restore( + batch_requests + ) + tree_memory_size = async_io_engine.compute_restore_memory_size( + batch_requests, deserialized_batches + ) + flat_restored = {} for request, deserialized in zip(batch_requests, deserialized_batches): - _, read_size = _get_batch_memory_size(request.handler, deserialized) - tree_memory_size += read_size for key, value in zip(request.keys, deserialized): flat_restored[key] = value + # Add in empty nodes from the metadata tree. 
for key in flat_metadata.keys(): if key not in flat_restored: @@ -977,6 +822,7 @@ async def _maybe_deserialize( flat_restored[key] = empty_values.get_empty_value_from_typestr( flat_metadata[key].value_type, self._pytree_metadata_options ) + # Restore using `item` as the target structure. If there are any custom # nodes (e.g. optax.EmptyState), these will replace None values in # flat_restored. @@ -1260,7 +1106,7 @@ class TrainState: ) - _log_io_metrics( + async_io_engine.log_io_metrics( tree_memory_size, start_time, '/jax/checkpoint/read/gbytes_per_sec', diff --git a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py index 0c5de7e4f..e8613e626 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py @@ -21,7 +21,6 @@ from __future__ import annotations -import asyncio import dataclasses import json import re @@ -44,6 +43,7 @@ from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib from orbax.checkpoint._src.metadata import empty_values from orbax.checkpoint._src.metadata import tree as tree_metadata +from orbax.checkpoint._src.path import format_utils from orbax.checkpoint._src.path import types as path_types from orbax.checkpoint._src.serialization import limits from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils @@ -76,13 +76,12 @@ ) BasePyTreeSaveArgs = base_pytree_checkpoint_handler.BasePyTreeSaveArgs BasePyTreeRestoreArgs = base_pytree_checkpoint_handler.BasePyTreeRestoreArgs -LimitInFlightBytes = base_pytree_checkpoint_handler.LimitInFlightBytes -get_param_names = base_pytree_checkpoint_handler.get_param_names +LimitInFlightBytes = limits.LimitInFlightBytes +get_param_names = tree_utils.get_param_names -PYTREE_METADATA_FILE = base_pytree_checkpoint_handler.PYTREE_METADATA_FILE +PYTREE_METADATA_FILE = 
format_utils.PYTREE_METADATA_FILE
 _CHECKPOINT_FILE = 'checkpoint'
-_METADATA_FILE = PYTREE_METADATA_FILE
-DEFAULT_CONCURRENT_GB = base_pytree_checkpoint_handler.DEFAULT_CONCURRENT_GB
+DEFAULT_CONCURRENT_GB = 96
 
 
 def _maybe_set_default_restore_args(args):
@@ -714,13 +713,9 @@ def _process_aggregated_value(meta_or_value, args):
             self._type_handler_registry,
         )
     )
-    deserialized_batches = []
-    deserialized_batches_ops = []
-    for request in batch_requests:
-      deserialized_batches_ops.append(
-          request.handler.deserialize(request.infos, request.args)
-      )
-    deserialized_batches += await asyncio.gather(*deserialized_batches_ops)
+    deserialized_batches = await self._handler_impl._async_io_engine.execute_restore(  # pylint: disable=protected-access
+        batch_requests
+    )
 
     flat_restored = {}
     for request, deserialized in zip(batch_requests, deserialized_batches):
diff --git a/checkpoint/orbax/checkpoint/_src/serialization/async_io_engine.py b/checkpoint/orbax/checkpoint/_src/serialization/async_io_engine.py
new file mode 100644
index 000000000..e28f056f5
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/_src/serialization/async_io_engine.py
@@ -0,0 +1,299 @@
+# Copyright 2026 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""AsyncIoEngine module.
+
+This module encapsulates the concurrency and execution orchestration layers for
+Orbax.
+Its primary responsibility is managing how work is dispatched to the Python
+`asyncio` event loop and thread pools.
+
+Scope:
+* `asyncio.gather` and future management.
+* Concurrency gating (e.g., `ByteLimiter`, `MemoryRegulator`).
+* Top-level I/O telemetry and performance logging (e.g., throughput
+calculation).
+
+Anti-Scope (What does NOT belong here):
+* Storage Backend Logic: Low-level serialization drivers, TensorStore
+bindings.
+* PyTree Math: Structural diffing, tree traversal, and `ParamInfo`
+generation.
+* Metadata Persistence: File-system JSON writes and Descriptor
+management.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import contextlib
+import dataclasses
+import sys
+import threading
+import time
+from typing import Any, Coroutine, Iterator, List, Optional, Sequence, Tuple, Union
+
+from absl import logging
+import humanize
+import jax
+from orbax.checkpoint._src.futures import future
+from orbax.checkpoint._src.multihost import multihost
+from orbax.checkpoint._src.serialization import memory_regulator
+from orbax.checkpoint._src.serialization import type_handlers
+from orbax.checkpoint._src.serialization import types
+
+TypeHandler = types.TypeHandler
+ParamInfo = types.ParamInfo
+SaveArgs = type_handlers.SaveArgs
+RestoreArgs = type_handlers.RestoreArgs
+BatchOfLeaves = Sequence[Any]
+BatchOfInts = Sequence[int]
+Batches = Sequence[BatchOfLeaves]
+CommitFutures = Sequence[future.Future]
+MemorySizes = Tuple[int, int]
+
+
+def _default_sizeof_values(values: BatchOfLeaves) -> BatchOfInts:
+  """Returns the shallow `sys.getsizeof` of each value, in order."""
+  return [sys.getsizeof(v) for v in values]
+
+
+def get_batch_memory_size(
+    handler: TypeHandler, values: BatchOfLeaves
+) -> MemorySizes:
+  """Gets memory size for a batch of leaf values.
+
+  Args:
+    handler: Handler whose `memory_size` is queried for the batch.
+    values: Leaf values to measure.
+
+  Returns:
+    A `(write_size, read_size)` tuple holding the per-value sizes summed over
+    the batch. Falls back to `sys.getsizeof` when the handler does not
+    implement `memory_size`. An empty batch yields `(0, 0)`.
+  """
+  # `zip(*[])` produces nothing to unpack, so an empty batch is answered
+  # directly. `len` (not truthiness) keeps array-like batches working.
+  if len(values) == 0:  # pylint: disable=g-explicit-length-test
+    return 0, 0
+  try:
+    write_sizes, read_sizes = zip(*handler.memory_size(values))
+  except NotImplementedError:
+    logging.warning(
+        '`memory_size` is not implemented for `TypeHandler` of type: %s. Using'
+        ' the a default implementation to measure value memory consumption that'
+        ' may result in inaccurate estimation.',
+        type(handler),
+    )
+    write_sizes = read_sizes = _default_sizeof_values(values)
+  assert len(write_sizes) == len(values)
+  assert len(read_sizes) == len(values)
+  return sum(write_sizes), sum(read_sizes)
+
+
+def log_io_metrics(
+    size: int,
+    start_time: float,
+    gbytes_per_sec_metric: str,
+    gbytes_metric: Optional[str] = None,
+) -> None:
+  """Logs and records I/O throughput since `start_time`.
+
+  Args:
+    size: Number of bytes transferred.
+    start_time: `time.time()` timestamp taken when the operation started.
+    gbytes_per_sec_metric: Name of the scalar recorded for throughput, in
+      GiB/s.
+    gbytes_metric: Optional name of a scalar recorded for the total size, in
+      GiB. Nothing is recorded for the total when None.
+  """
+  time_elapsed = time.time() - start_time
+  # A zero interval would divide by zero; report NaN instead.
+  bytes_per_sec = (
+      float('nan') if time_elapsed == 0 else float(size) / time_elapsed
+  )
+  note = 'per-host'
+  logging.info(
+      '[process=%d] %s: %s/s (total gbytes: %s) (time elapsed: %s s) (%s)',
+      multihost.process_index(),
+      gbytes_per_sec_metric,
+      humanize.naturalsize(bytes_per_sec, binary=True, format='%.3f'),
+      humanize.naturalsize(size, binary=True),
+      time_elapsed,
+      note,
+  )
+  jax.monitoring.record_scalar(
+      gbytes_per_sec_metric, value=bytes_per_sec / (1024**3)
+  )
+  if gbytes_metric is not None:
+    jax.monitoring.record_scalar(gbytes_metric, value=size / (1024**3))
+
+
+async def logging_serialize(
+    handler: TypeHandler,
+    serialize: Coroutine[Any, Any, CommitFutures],
+) -> CommitFutures:
+  """Awaits `serialize` and logs how long it took.
+
+  Args:
+    handler: Handler that produced `serialize`; used only to name the log line.
+    serialize: Awaitable returned by the handler's `serialize` call.
+
+  Returns:
+    The commit futures that `serialize` resolved to.
+  """
+  start = time.time()
+  commit_futures = await serialize
+  handler_name = f'{type(handler).__module__}.{type(handler).__qualname__}'
+  logging.info(
+      '[process=%s][thread=%s] Initiated %s.serialize. Time taken: %fs',
+      multihost.process_index(),
+      threading.current_thread().name,
+      f'"{handler_name}"',
+      time.time() - start,
+  )
+  return commit_futures
+
+
+@dataclasses.dataclass
+class BatchRequest:
+  """Represents a request for batched serialization or deserialization.
+
+  Attributes:
+    handler: Used to serialize or deserialize the parameters.
+    keys: Used to identify the original tree keys so that the PyTree can be
+      reconstructed.
+    values: Values to serialize.
+    infos: ParamInfos.
+    args: List of SaveArgs or RestoreArgs.
+
+  Raises:
+    AssertionError: If `keys`, `values`, `infos` and `args` differ in length.
+  """
+
+  handler: TypeHandler
+  keys: List[str]
+  values: List[Any]
+  infos: List[ParamInfo]
+  args: List[Union[SaveArgs, RestoreArgs]]
+
+  def __post_init__(self):
+    length = len(self.values)
+    if not all((
+        length == len(self.infos),
+        length == len(self.args),
+        length == len(self.keys),
+    )):
+      raise AssertionError('Found `BatchRequest` with mismatched parameters.')
+
+
+BatchRequests = Sequence[BatchRequest]
+
+
+@contextlib.contextmanager
+def memory_profiler_context() -> Iterator[None]:
+  """Runs the `memory_regulator` profiler for the duration of the block."""
+  memory_regulator.profiler_start()
+  try:
+    yield
+  finally:
+    # Stop the profiler (and its bg thread) even if the body raises.
+    memory_regulator.profiler_end()
+
+
+def compute_save_memory_size(batch_requests: BatchRequests) -> int:
+  """Computes the total write memory size for a sequence of batch requests.
+
+  Args:
+    batch_requests: Requests whose `values` are about to be serialized.
+
+  Returns:
+    The sum of the write sizes reported for every request.
+  """
+  tree_memory_size = 0
+  for request in batch_requests:
+    write_size, _ = get_batch_memory_size(request.handler, request.values)
+    tree_memory_size += write_size
+  return tree_memory_size
+
+
+def compute_restore_memory_size(
+    batch_requests: BatchRequests,
+    deserialized_batches: Batches,
+) -> int:
+  """Computes the total read memory size for deserialized batches.
+
+  Args:
+    batch_requests: Requests that were deserialized.
+    deserialized_batches: Restored values, one batch per request, in the same
+      order as `batch_requests`.
+
+  Returns:
+    The sum of the read sizes reported for every deserialized batch.
+  """
+  tree_memory_size = 0
+  for request, deserialized in zip(batch_requests, deserialized_batches):
+    _, read_size = get_batch_memory_size(request.handler, deserialized)
+    tree_memory_size += read_size
+  return tree_memory_size
+
+
+class AsyncIoEngine:
+  """Encapsulates concurrency, thread-pooling, and I/O telemetry logic."""
+
+  async def execute_save(
+      self, batch_requests: BatchRequests
+  ) -> Sequence[CommitFutures]:
+    """Executes save requests asynchronously.
+
+    Args:
+      batch_requests: Requests to serialize; each one is dispatched to its own
+        handler and all of them are awaited together.
+
+    Returns:
+      One sequence of commit futures per request, in request order.
+    """
+    with memory_profiler_context():
+      serialize_ops = [
+          logging_serialize(
+              request.handler,
+              request.handler.serialize(
+                  request.values, request.infos, request.args
+              ),
+          )
+          for request in batch_requests
+      ]
+      commit_futures = await asyncio.gather(*serialize_ops)
+
+    logging.info(
+        'MemoryRegulated: Peak usage: %f GiB',
+        memory_regulator.profiler_peak_usage_gib(),
+    )
+    return commit_futures
+
+  async def execute_restore(self, batch_requests: BatchRequests) -> Batches:
+    """Executes restore requests asynchronously.
+
+    Args:
+      batch_requests: Requests to deserialize; only `infos` and `args` are
+        passed to each handler.
+
+    Returns:
+      The deserialized values, one batch per request, in request order.
+    """
+    deserialize_ops = [
+        request.handler.deserialize(request.infos, request.args)
+        for request in batch_requests
+    ]
+    return await asyncio.gather(*deserialize_ops)
diff --git a/checkpoint/orbax/checkpoint/_src/serialization/async_io_engine_test.py b/checkpoint/orbax/checkpoint/_src/serialization/async_io_engine_test.py
new file mode 100644
index 000000000..29db8ea24
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/_src/serialization/async_io_engine_test.py
@@ -0,0 +1,216 @@
+# Copyright 2026 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+from unittest import mock
+
+from absl.testing import absltest
+from orbax.checkpoint._src.serialization import async_io_engine
+from orbax.checkpoint._src.serialization import types
+
+AsyncIoEngine = async_io_engine.AsyncIoEngine
+BatchRequest = async_io_engine.BatchRequest
+
+
+class AsyncIoEngineTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
+
+  def test_get_batch_memory_size_success(self):
+    handler = mock.create_autospec(types.TypeHandler, instance=True)
+    handler.memory_size.return_value = [(10, 20), (30, 40)]
+
+    write_size, read_size = async_io_engine.get_batch_memory_size(
+        handler, ['a', 'b']
+    )
+    self.assertEqual(write_size, 40)
+    self.assertEqual(read_size, 60)
+
+  def test_get_batch_memory_size_not_implemented(self):
+    handler = mock.create_autospec(types.TypeHandler, instance=True)
+    handler.memory_size.side_effect = NotImplementedError()
+
+    values = ['dummy1', 'dummy2']
+    expected_size = sum(sys.getsizeof(v) for v in values)
+
+    write_size, read_size = async_io_engine.get_batch_memory_size(
+        handler, values
+    )
+    self.assertEqual(write_size, expected_size)
+    self.assertEqual(read_size, expected_size)
+
+  def test_batch_request_validation_success(self):
+    handler = mock.create_autospec(types.TypeHandler, instance=True)
+    req = BatchRequest(
+        handler=handler,
+        keys=['k1', 'k2'],
+        values=['v1', 'v2'],
+        infos=[mock.Mock(), mock.Mock()],
+        args=[mock.Mock(), mock.Mock()],
+    )
+    self.assertLen(req.values, 2)
+
+  def test_batch_request_validation_mismatch(self):
+    handler = mock.create_autospec(types.TypeHandler, instance=True)
+    with self.assertRaises(AssertionError):
+      BatchRequest(
+          handler=handler,
+          keys=['k1'],
+          values=['v1', 'v2'],
+          infos=[mock.Mock(), mock.Mock()],
+          args=[mock.Mock(), mock.Mock()],
+      )
+
+  def test_compute_save_memory_size(self):
+    handler1 = mock.create_autospec(types.TypeHandler, instance=True)
+    handler2 = mock.create_autospec(types.TypeHandler, instance=True)
+
+    # `memory_size` returns one (write_size, read_size) tuple per value.
+    handler1.memory_size.return_value = [(100, 0)]
+    handler2.memory_size.return_value = [(200, 0)]
+
+    req1 = BatchRequest(
+        handler=handler1,
+        keys=['k1'],
+        values=['v1'],
+        infos=[mock.Mock()],
+        args=[mock.Mock()],
+    )
+    req2 = BatchRequest(
+        handler=handler2,
+        keys=['k2'],
+        values=['v2'],
+        infos=[mock.Mock()],
+        args=[mock.Mock()],
+    )
+
+    tree_memory_size = async_io_engine.compute_save_memory_size([req1, req2])
+    self.assertEqual(tree_memory_size, 300)
+
+  def test_compute_restore_memory_size(self):
+    handler1 = mock.create_autospec(types.TypeHandler, instance=True)
+    handler2 = mock.create_autospec(types.TypeHandler, instance=True)
+
+    # `memory_size` returns one (write_size, read_size) tuple per value.
+    handler1.memory_size.return_value = [(0, 50)]
+    handler2.memory_size.return_value = [(0, 150)]
+
+    req1 = BatchRequest(
+        handler=handler1,
+        keys=['k1'],
+        values=['v1'],
+        infos=[mock.Mock()],
+        args=[mock.Mock()],
+    )
+    req2 = BatchRequest(
+        handler=handler2,
+        keys=['k2'],
+        values=['v2'],
+        infos=[mock.Mock()],
+        args=[mock.Mock()],
+    )
+
+    deserialized_batches = [['restored1'], ['restored2']]
+
+    tree_memory_size = async_io_engine.compute_restore_memory_size(
+        [req1, req2], deserialized_batches
+    )
+    self.assertEqual(tree_memory_size, 200)
+
+  async def test_execute_save(self):
+    engine = AsyncIoEngine()
+
+    handler1 = mock.create_autospec(types.TypeHandler, instance=True)
+    handler2 = mock.create_autospec(types.TypeHandler, instance=True)
+
+    async def dummy_serialize1(*args, **kwargs):
+      del args, kwargs
+      return ['fut1', 'fut2']
+
+    async def dummy_serialize2(*args, **kwargs):
+      del args, kwargs
+      return ['fut3']
+
+    handler1.serialize.side_effect = dummy_serialize1
+    handler2.serialize.side_effect = dummy_serialize2
+
+    req1 = BatchRequest(
+        handler=handler1,
+        keys=['k1'],
+        values=['v1'],
+        infos=[mock.Mock()],
+        args=[mock.Mock()],
+    )
+    req2 = BatchRequest(
+        handler=handler2,
+        keys=['k2'],
+        values=['v2'],
+        infos=[mock.Mock()],
+        args=[mock.Mock()],
+    )
+
+    commit_futures = await engine.execute_save([req1, req2])
+    self.assertEqual(commit_futures, [['fut1', 'fut2'], ['fut3']])
+
+    # Also exercise the standalone write-size helper on the same requests.
+    handler1.memory_size.return_value = [(100, 0)]
+    handler2.memory_size.return_value = [(200, 0)]
+    tree_memory_size = async_io_engine.compute_save_memory_size([req1, req2])
+    self.assertEqual(tree_memory_size, 300)
+
+  async def test_execute_restore(self):
+    engine = AsyncIoEngine()
+
+    handler1 = mock.create_autospec(types.TypeHandler, instance=True)
+    handler2 = mock.create_autospec(types.TypeHandler, instance=True)
+
+    async def dummy_deserialize1(*args, **kwargs):
+      del args, kwargs
+      return ['restored1']
+
+    async def dummy_deserialize2(*args, **kwargs):
+      del args, kwargs
+      return ['restored2']
+
+    handler1.deserialize.side_effect = dummy_deserialize1
+    handler2.deserialize.side_effect = dummy_deserialize2
+
+    req1 = BatchRequest(
+        handler=handler1,
+        keys=['k1'],
+        values=['v1'],
+        infos=[mock.Mock()],
+        args=[mock.Mock()],
+    )
+    req2 = BatchRequest(
+        handler=handler2,
+        keys=['k2'],
+        values=['v2'],
+        infos=[mock.Mock()],
+        args=[mock.Mock()],
+    )
+
+    deserialized_batches = await engine.execute_restore([req1, req2])
+    self.assertEqual(deserialized_batches, [['restored1'], ['restored2']])
+
+    # Also exercise the standalone read-size helper on the restored batches.
+    handler1.memory_size.return_value = [(0, 50)]
+    handler2.memory_size.return_value = [(0, 150)]
+    tree_memory_size = async_io_engine.compute_restore_memory_size(
+        [req1, req2], deserialized_batches
+    )
+    self.assertEqual(tree_memory_size, 200)
+
+
+if __name__ == '__main__':
+  absltest.main()