diff --git a/deepmd/dpmodel/utils/__init__.py b/deepmd/dpmodel/utils/__init__.py index cd6eb696c9..6c433ba2ab 100644 --- a/deepmd/dpmodel/utils/__init__.py +++ b/deepmd/dpmodel/utils/__init__.py @@ -6,6 +6,14 @@ AtomExcludeMask, PairExcludeMask, ) +from .lmdb_data import ( + DistributedSameNlocBatchSampler, + LmdbDataReader, + LmdbTestData, + SameNlocBatchSampler, + is_lmdb, + make_neighbor_stat_data, +) from .network import ( EmbeddingNet, FittingNet, @@ -44,13 +52,17 @@ __all__ = [ "AtomExcludeMask", + "DistributedSameNlocBatchSampler", "EmbeddingNet", "EnvMat", "FittingNet", + "LmdbDataReader", + "LmdbTestData", "NativeLayer", "NativeNet", "NetworkCollection", "PairExcludeMask", + "SameNlocBatchSampler", "aggregate", "build_multiple_neighbor_list", "build_neighbor_list", @@ -59,10 +71,12 @@ "get_graph_index", "get_multiple_nlist_key", "inter2phys", + "is_lmdb", "load_dp_model", "make_embedding_network", "make_fitting_network", "make_multilayer_network", + "make_neighbor_stat_data", "nlist_distinguish_types", "normalize_coord", "phys2inter", diff --git a/deepmd/dpmodel/utils/lmdb_data.py b/deepmd/dpmodel/utils/lmdb_data.py new file mode 100644 index 0000000000..243d4f525d --- /dev/null +++ b/deepmd/dpmodel/utils/lmdb_data.py @@ -0,0 +1,1525 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Framework-agnostic LMDB data utilities for DeePMD-kit. + +All code here is pure Python/NumPy/lmdb/msgpack — no framework dependency. +Backend-specific wrappers (PyTorch Dataset, JAX, etc.) import from here. 
+""" + +import logging +import math +from collections.abc import ( + Iterator, +) +from pathlib import ( + Path, +) +from typing import ( + Any, +) + +import lmdb +import msgpack +import numpy as np + +from deepmd.env import ( + GLOBAL_ENER_FLOAT_PRECISION, + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.utils.data import ( + DataRequirementItem, +) + +log = logging.getLogger(__name__) + +# LMDB key → DeePMD convention +_KEY_REMAP = { + "coords": "coord", + "cells": "box", + "energies": "energy", + "forces": "force", + "atom_types": "atype", + "virials": "virial", +} + +# Keys whose high_prec is always True in the standard pipeline +# (energy is set by Loss DataRequirementItem; reduce() also sets high_prec=True) +_HIGH_PREC_KEYS = frozenset({"energy"}) + +# Process-level cache: python-lmdb does not allow opening the same path twice +# in one process. We ref-count so the Environment is closed (and freed from +# the cache) once every reader that shares it is garbage-collected. +_ENV_CACHE: dict[str, tuple[lmdb.Environment, int]] = {} + + +def _open_lmdb(path: str) -> lmdb.Environment: + """Open (or reuse) a readonly LMDB environment with reference counting. + + The python-lmdb binding raises ``lmdb.Error`` if the same path is opened + more than once in a single process. We cache by resolved absolute path + and bump a reference count. Call :func:`_close_lmdb` when done to + decrement the count; when it reaches zero the environment is closed and + removed from the cache. 
+ """ + resolved = str(Path(path).resolve()) + entry = _ENV_CACHE.get(resolved) + if entry is not None: + env, refcount = entry + _ENV_CACHE[resolved] = (env, refcount + 1) + return env + env = lmdb.open(path, readonly=True, lock=False, readahead=False, meminit=False) + _ENV_CACHE[resolved] = (env, 1) + return env + + +def _close_lmdb(path: str) -> None: + """Decrement the ref-count for *path* and close the env when it hits zero.""" + resolved = str(Path(path).resolve()) + entry = _ENV_CACHE.get(resolved) + if entry is None: + return + env, refcount = entry + if refcount <= 1: + del _ENV_CACHE[resolved] + try: + env.close() + except Exception: + pass + else: + _ENV_CACHE[resolved] = (env, refcount - 1) + + +def _read_metadata(txn: lmdb.Transaction) -> dict: + """Read and decode __metadata__ from LMDB transaction.""" + raw = txn.get(b"__metadata__") + if raw is None: + raise ValueError("LMDB file missing __metadata__ key") + return msgpack.unpackb(raw, raw=False) + + +def _decode_array(obj: dict) -> np.ndarray: + """Reconstruct ndarray from msgpack-encoded dict with {type, shape, data}. + + Handles both string keys ("type", "data") and byte keys (b"type", b"data"). 
+ """ + dtype_key = "type" if "type" in obj else b"type" + data_key = "data" if "data" in obj else b"data" + shape_key = "shape" if "shape" in obj else b"shape" + dtype = np.dtype(obj[dtype_key]) + data = obj[data_key] + if shape_key in obj: + shape = tuple(obj[shape_key]) + else: + shape = (len(data) // dtype.itemsize,) + return np.frombuffer(data, dtype=dtype).reshape(shape).copy() + + +def _is_encoded_array(val: Any) -> bool: + """Check if a value is a msgpack-encoded ndarray dict.""" + if not isinstance(val, dict): + return False + return ("data" in val and "type" in val) or (b"data" in val and b"type" in val) + + +def _decode_value(val: Any) -> Any: + """Decode a value: encoded array -> ndarray, list of encoded -> list of ndarray, else pass through.""" + if _is_encoded_array(val): + return _decode_array(val) + elif isinstance(val, list) and len(val) > 0 and _is_encoded_array(val[0]): + return [_decode_array(item) for item in val] + return val + + +def _decode_frame(raw_bytes: bytes) -> dict[str, Any]: + """Decode a msgpack-serialized frame into a dict of numpy arrays / scalars.""" + frame = msgpack.unpackb(raw_bytes, raw=False) + result = {} + for key, val in frame.items(): + result[key] = _decode_value(val) + return result + + +def _remap_keys(frame: dict[str, Any]) -> dict[str, Any]: + """Remap LMDB key names to DeePMD convention, pass through unknown keys.""" + out = {} + for k, v in frame.items(): + out[_KEY_REMAP.get(k, k)] = v + return out + + +def is_lmdb(systems: str) -> bool: + """Check if systems points to an LMDB dataset.""" + return systems.endswith(".lmdb") or Path(systems, "data.mdb").is_file() + + +def _parse_metadata(meta: dict) -> tuple[int, str, list[int]]: + """Parse LMDB metadata into (nframes, frame_fmt, natoms_per_type). + + Handles system_info as list or dict, and natoms as plain ints or encoded arrays. 
+ """ + nframes = meta["nframes"] + frame_fmt = meta.get("frame_idx_fmt", "012d") + raw_sys_info = meta.get("system_info", {}) + + if isinstance(raw_sys_info, list): + sys_info = raw_sys_info[0] if raw_sys_info else {} + else: + sys_info = raw_sys_info + + raw_natoms = sys_info.get("natoms", []) + natoms_per_type = [] + for item in raw_natoms: + if _is_encoded_array(item): + natoms_per_type.append(int(_decode_array(item).item())) + else: + natoms_per_type.append(int(item)) + + return nframes, frame_fmt, natoms_per_type + + +def _scan_frame_nlocs( + env: lmdb.Environment, nframes: int, frame_fmt: str, fallback_natoms: int +) -> list[int]: + """Scan all frames to get per-frame atom count. + + Reads only the atom_types shape from msgpack without decoding array data. + """ + nlocs = [] + with env.begin() as txn: + for i in range(nframes): + key = format(i, frame_fmt).encode() + raw = txn.get(key) + if raw is not None: + frame_raw = msgpack.unpackb(raw, raw=False) + atype_raw = frame_raw.get("atom_types") + if isinstance(atype_raw, dict): + shape = atype_raw.get("shape") or atype_raw.get(b"shape") + if shape: + nlocs.append(int(shape[0])) + continue + nlocs.append(fallback_natoms) + return nlocs + + +def _compute_batch_size(nloc: int, rule: int) -> int: + """Compute batch_size for a given nloc using the auto rule.""" + bsi = rule // max(nloc, 1) + if bsi * nloc < rule: + bsi += 1 + return max(bsi, 1) + + +class LmdbDataReader: + """Framework-agnostic LMDB dataset reader. + + Reads LMDB frames and returns dicts of numpy arrays. + Backend-specific Dataset classes (PyTorch, JAX, etc.) wrap this. + + Datasets are typically mixed-nloc (frames with different atom counts). + The ``mixed_batch`` flag controls batching strategy: + + - ``mixed_batch=False`` (default, old format): each batch contains only + frames with the same nloc. A ``SameNlocBatchSampler`` groups frames + by nloc and yields same-nloc batches. Auto batch_size is computed + per-nloc-group. 
+ - ``mixed_batch=True`` (new format): frames with different nloc can + coexist in one batch (requires padding + mask in collate_fn). + Currently raises ``NotImplementedError`` at collation time. + + Parameters + ---------- + lmdb_path : str + Path to the LMDB directory. + type_map : list[str] + Global type map from model config. + batch_size : int or str + Batch size. Supports int, "auto", "auto:N". + mixed_batch : bool + If True, allow different nloc in the same batch (future). + If False (default), enforce same-nloc-per-batch. + """ + + def __init__( + self, + lmdb_path: str, + type_map: list[str], + batch_size: int | str = "auto", + mixed_batch: bool = False, + ) -> None: + self.lmdb_path = str(lmdb_path) + self._type_map = type_map + self._env = _open_lmdb(self.lmdb_path) + self.mixed_batch = mixed_batch + + with self._env.begin() as txn: + meta = _read_metadata(txn) + + self.nframes, self._frame_fmt, self._natoms_per_type = _parse_metadata(meta) + self._natoms = sum(self._natoms_per_type) + self._ntypes = len(type_map) + + # Build type remapping if LMDB's type_map differs from model's type_map + lmdb_type_map = meta.get("type_map") + self._lmdb_type_map = lmdb_type_map + self._type_remap: np.ndarray | None = None + if lmdb_type_map is not None and list(lmdb_type_map) != list(type_map): + # Build remap: lmdb_type_idx -> model_type_idx + remap = np.empty(len(lmdb_type_map), dtype=np.int32) + for i, name in enumerate(lmdb_type_map): + if name not in type_map: + raise ValueError( + f"Element '{name}' in LMDB type_map {lmdb_type_map} " + f"not found in model type_map {type_map}" + ) + remap[i] = type_map.index(name) + self._type_remap = remap + log.info( + f"Type remapping: LMDB {lmdb_type_map} -> model {type_map}, " + f"remap={remap}" + ) + + # Persistent read-only transaction for __getitem__ (avoids per-read overhead). + # Safe because we use num_workers=0 in DataLoader. 
+ self._txn = self._env.begin() + + # Scan per-frame nloc only when needed for same-nloc batching. + # For mixed_batch=True, skip the scan entirely (future: padding handles it). + if not mixed_batch: + # Fast path: use pre-computed frame_nlocs from metadata if available. + # Falls back to scanning each frame's atom_types shape (~10 us/frame). + meta_nlocs = meta.get("frame_nlocs") + if meta_nlocs is not None: + self._frame_nlocs = [int(n) for n in meta_nlocs] + else: + self._frame_nlocs = _scan_frame_nlocs( + self._env, self.nframes, self._frame_fmt, self._natoms + ) + self._nloc_groups: dict[int, list[int]] = {} + for idx, nloc in enumerate(self._frame_nlocs): + self._nloc_groups.setdefault(nloc, []).append(idx) + else: + self._frame_nlocs = [] + self._nloc_groups = {} + + # Parse frame_system_ids for auto_prob support + meta_sys_ids = meta.get("frame_system_ids") + if meta_sys_ids is not None: + self._frame_system_ids: list[int] | None = [int(s) for s in meta_sys_ids] + self._nsystems = max(self._frame_system_ids) + 1 + self._system_groups: dict[int, list[int]] = {} + for idx, sid in enumerate(self._frame_system_ids): + self._system_groups.setdefault(sid, []).append(idx) + self._system_nframes: list[int] = [ + len(self._system_groups.get(i, [])) for i in range(self._nsystems) + ] + else: + self._frame_system_ids = None + self._nsystems = 1 + self._system_groups = {0: list(range(self.nframes))} + self._system_nframes = [self.nframes] + + # Parse batch_size spec + self._auto_rule: int | None = None + if isinstance(batch_size, str): + if batch_size == "auto": + self._auto_rule = 32 + elif batch_size.startswith("auto:"): + self._auto_rule = int(batch_size.split(":")[1]) + else: + self._auto_rule = 32 + # Default batch_size uses first frame's nloc (for total_batch estimate) + self.batch_size = _compute_batch_size(self._natoms, self._auto_rule) + else: + self.batch_size = int(batch_size) + + # Data requirements tracking + self._data_requirements: dict[str, 
DataRequirementItem] = {} + + def _compute_natoms_vec(self, atype: np.ndarray) -> np.ndarray: + """Compute natoms_vec from a frame's atype array. + + Returns [nloc, nloc, count_type0, count_type1, ...] with length ntypes+2. + """ + nloc = len(atype) + counts = np.bincount(atype, minlength=self._ntypes)[: self._ntypes] + vec = np.empty(self._ntypes + 2, dtype=np.int64) + vec[0] = nloc + vec[1] = nloc + vec[2:] = counts + return vec + + def _resolve_dtype(self, key: str) -> np.dtype: + """Resolve the target numpy dtype for a given key. + + Priority: DataRequirementItem.dtype > DataRequirementItem.high_prec > + built-in defaults (energy=high, others=normal). + """ + if key in self._data_requirements: + req = self._data_requirements[key] + # Support both DataRequirementItem objects and plain dicts + if isinstance(req, dict): + dtype = req.get("dtype") + if dtype is not None: + return dtype + if req.get("high_prec", False): + return GLOBAL_ENER_FLOAT_PRECISION + return GLOBAL_NP_FLOAT_PRECISION + else: + # DataRequirementItem object + if hasattr(req, "dtype") and req.dtype is not None: + return req.dtype + if hasattr(req, "high_prec") and req.high_prec: + return GLOBAL_ENER_FLOAT_PRECISION + return GLOBAL_NP_FLOAT_PRECISION + # Fall back to built-in defaults + if key in _HIGH_PREC_KEYS: + return GLOBAL_ENER_FLOAT_PRECISION + return GLOBAL_NP_FLOAT_PRECISION + + def __del__(self) -> None: + """Release the LMDB environment ref-count on garbage collection.""" + path = getattr(self, "lmdb_path", None) + if path is not None: + _close_lmdb(path) + + def get_batch_size_for_nloc(self, nloc: int) -> int: + """Get batch_size for a given nloc. 
Uses auto rule if configured.""" + if self._auto_rule is not None: + return _compute_batch_size(nloc, self._auto_rule) + return self.batch_size + + def __len__(self) -> int: + return self.nframes + + def __getitem__(self, index: int) -> dict[str, Any]: + """Read frame from LMDB, decode, remap keys, return dict of numpy arrays.""" + key = format(index, self._frame_fmt).encode() + raw = self._txn.get(key) + if raw is None: + raise IndexError(f"Frame {index} not found in LMDB") + frame = _decode_frame(raw) + frame = _remap_keys(frame) + + # Remove LMDB-specific metadata keys not needed by trainer + for meta_key in ("atom_numbs", "atom_names", "orig"): + frame.pop(meta_key, None) + + # Flatten arrays to match DeePMD convention + if "coord" in frame and isinstance(frame["coord"], np.ndarray): + frame["coord"] = ( + frame["coord"].reshape(-1, 3).astype(self._resolve_dtype("coord")) + ) + if "box" in frame and isinstance(frame["box"], np.ndarray): + frame["box"] = frame["box"].reshape(9).astype(self._resolve_dtype("box")) + if "energy" in frame: + val = frame["energy"] + if isinstance(val, np.ndarray): + frame["energy"] = val.reshape(1).astype(self._resolve_dtype("energy")) + else: + frame["energy"] = np.array( + [float(val)], dtype=self._resolve_dtype("energy") + ) + if "force" in frame and isinstance(frame["force"], np.ndarray): + frame["force"] = ( + frame["force"].reshape(-1, 3).astype(self._resolve_dtype("force")) + ) + if "atype" in frame and isinstance(frame["atype"], np.ndarray): + frame["atype"] = frame["atype"].reshape(-1).astype(np.int64) + # Remap atom types from LMDB's type_map to model's type_map + if self._type_remap is not None: + frame["atype"] = self._type_remap[frame["atype"]].astype(np.int64) + if "virial" in frame and isinstance(frame["virial"], np.ndarray): + frame["virial"] = ( + frame["virial"].reshape(9).astype(self._resolve_dtype("virial")) + ) + + # Per-frame natoms_vec from atype + atype = frame.get("atype") + if atype is not None: + 
frame_natoms = len(atype) + natoms_vec = self._compute_natoms_vec(atype) + frame["natoms"] = natoms_vec + frame["real_natoms_vec"] = natoms_vec + else: + frame_natoms = self._natoms + fallback = np.array( + [self._natoms, self._natoms] + [0] * self._ntypes, dtype=np.int64 + ) + frame["natoms"] = fallback + frame["real_natoms_vec"] = fallback + + # Add find_* flags for all data keys present in the frame. + # Core structural keys and metadata are excluded — only label-like + # and auxiliary data keys get find_* flags. + _structural_keys = frozenset( + { + "coord", + "box", + "atype", + "natoms", + "real_natoms_vec", + "fid", + } + ) + for fk in list(frame.keys()): + if fk.startswith("find_") or fk in _structural_keys: + continue + # Skip keys handled by data_requirements (processed below) + if fk in self._data_requirements: + continue + if f"find_{fk}" not in frame: + frame[f"find_{fk}"] = np.float32(1.0) + + # Handle registered data requirements: fill defaults for missing keys, + # apply repeat, and cast dtype. 
+ for req_key, req_item in self._data_requirements.items(): + # Extract requirement fields (support both dict and object) + if isinstance(req_item, dict): + ndof = req_item["ndof"] + default = req_item["default"] + atomic = req_item["atomic"] + repeat = req_item.get("repeat", 1) + req_dtype = req_item.get("dtype") + if req_dtype is None: + req_dtype = ( + GLOBAL_ENER_FLOAT_PRECISION + if req_item.get("high_prec", False) + else GLOBAL_NP_FLOAT_PRECISION + ) + else: + ndof = req_item.ndof + default = req_item.default + atomic = req_item.atomic + repeat = getattr(req_item, "repeat", 1) + req_dtype = req_item.dtype + if req_dtype is None: + req_dtype = ( + GLOBAL_ENER_FLOAT_PRECISION + if req_item.high_prec + else GLOBAL_NP_FLOAT_PRECISION + ) + + if req_key not in frame: + frame[f"find_{req_key}"] = np.float32(0.0) + if atomic: + shape = (frame_natoms, ndof) + else: + shape = (ndof,) + data = np.full(shape, default, dtype=req_dtype) + if repeat != 1: + data = np.repeat(data, repeat).reshape(-1) + frame[req_key] = data + else: + if f"find_{req_key}" not in frame: + frame[f"find_{req_key}"] = np.float32(1.0) + # Apply repeat to existing data (e.g. 
atom_pref repeat=3) + if repeat != 1 and isinstance(frame[req_key], np.ndarray): + frame[req_key] = ( + np.repeat(frame[req_key], repeat).reshape(-1).astype(req_dtype) + ) + + # Add find_* for fparam/aparam/spin if not already set + for extra_key in ["fparam", "aparam", "spin"]: + if f"find_{extra_key}" not in frame: + frame[f"find_{extra_key}"] = ( + np.float32(1.0) if extra_key in frame else np.float32(0.0) + ) + + frame["fid"] = index + + return frame + + # --- Data requirement interface --- + + def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None: + """Register expected keys; missing keys get default fill + find_key=0.0.""" + for item in data_requirement: + self._data_requirements[item["key"]] = item + + def print_summary(self, name: str, prob: Any) -> None: + """Print basic dataset info.""" + n_groups = len(self._nloc_groups) + + log.info( + f"LMDB {name}: {self.lmdb_path}, " + f"{self.nframes} frames, {n_groups} nloc groups, " + f"batch_size={'auto' if self._auto_rule else self.batch_size}, " + f"mixed_batch={self.mixed_batch}" + ) + # Print nloc groups in rows of ~10 for readability + items = [ + f"{nloc}({len(idxs)})" for nloc, idxs in sorted(self._nloc_groups.items()) + ] + per_row = 10 + for i in range(0, len(items), per_row): + row = ", ".join(items[i : i + per_row]) + log.info(f" nloc groups: {row}") + + def set_noise(self, noise_settings: dict[str, Any]) -> None: + """No-op for now.""" + + # --- Properties --- + + @property + def index(self) -> list[int]: + """Number of batches per system (single system).""" + return [max(1, self.nframes // self.batch_size)] + + @property + def total_batch(self) -> int: + return self.index[0] + + @property + def batch_sizes(self) -> list[int]: + return [self.batch_size] + + @property + def mixed_type(self) -> bool: + """LMDB datasets are always mixed_type (frames may have different compositions).""" + return True + + @property + def nloc_groups(self) -> dict[int, list[int]]: + """Nloc → 
list of frame indices.""" + return self._nloc_groups + + @property + def frame_nlocs(self) -> list[int]: + """Per-frame atom count.""" + return self._frame_nlocs + + @property + def nsystems(self) -> int: + """Number of original systems merged into this LMDB.""" + return self._nsystems + + @property + def frame_system_ids(self) -> list[int] | None: + """Per-frame system index, or None if not available.""" + return self._frame_system_ids + + @property + def system_groups(self) -> dict[int, list[int]]: + """System index → list of frame indices.""" + return self._system_groups + + @property + def system_nframes(self) -> list[int]: + """Number of frames per system.""" + return self._system_nframes + + +def compute_block_targets( + auto_prob_style: str, + nsystems: int, + system_nframes: list[int], +) -> list[tuple[list[int], int]]: + """Compute target frame count per block from auto_prob config. + + Uses the same ``prob_sys_size_ext`` logic as the npy pipeline to parse + the ``auto_prob`` string, then converts per-system probabilities into + per-block target frame counts using the "max(frames/prob)" strategy. + + Parameters + ---------- + auto_prob_style : str + e.g. ``"prob_sys_size;0:3:0.5;3:10:0.5"`` + nsystems : int + Total number of systems in the LMDB. + system_nframes : list[int] + Number of frames per system. + + Returns + ------- + list[tuple[list[int], int]] + Each element is ``(system_indices_in_block, target_frame_count)``. + Returns empty list if no expansion is needed (all targets == actual). + """ + from deepmd.utils.data_system import ( + prob_sys_size_ext, + ) + + # Parse block definitions from the auto_prob string + # Format: "prob_sys_size;stt:end:weight;stt:end:weight;..." 
def compute_block_targets(
    auto_prob_style: str,
    nsystems: int,
    system_nframes: list[int],
) -> list[tuple[list[int], int]]:
    """Compute target frame count per block from auto_prob config.

    Uses the same ``prob_sys_size_ext`` logic as the npy pipeline to parse
    the ``auto_prob`` string, then converts per-system probabilities into
    per-block target frame counts using the "max(frames/prob)" strategy.

    Parameters
    ----------
    auto_prob_style : str
        e.g. ``"prob_sys_size;0:3:0.5;3:10:0.5"``
    nsystems : int
        Total number of systems in the LMDB.
    system_nframes : list[int]
        Number of frames per system.

    Returns
    -------
    list[tuple[list[int], int]]
        Each element is ``(system_indices_in_block, target_frame_count)``.
        Returns empty list if no expansion is needed (all targets == actual).
    """
    from deepmd.utils.data_system import (
        prob_sys_size_ext,
    )

    # Parse block definitions from the auto_prob string
    # Format: "prob_sys_size;stt:end:weight;stt:end:weight;..."
    block_str = auto_prob_style.split(";")[1:]
    blocks: list[tuple[int, int, float]] = []
    for part in block_str:
        stt, end, weight = part.split(":")
        blocks.append((int(stt), int(end), float(weight)))

    # Compute per-system probabilities using the standard function
    sys_probs = prob_sys_size_ext(auto_prob_style, nsystems, system_nframes)

    # Group systems by block, compute block-level frames and prob
    block_info: list[tuple[list[int], int, float]] = []  # (sys_ids, frames, prob)
    for stt, end, _weight in blocks:
        sys_ids = list(range(stt, end))
        block_frames = sum(system_nframes[i] for i in sys_ids)
        block_prob = sum(sys_probs[i] for i in sys_ids)
        block_info.append((sys_ids, block_frames, block_prob))

    # Step 1-2: total_target = ceil(max(block_frames / block_prob))
    # i.e. the smallest total for which the least-oversampled block needs no
    # shrinking.
    ratios = []
    for sys_ids, block_frames, block_prob in block_info:
        if block_prob > 0:
            ratios.append(block_frames / block_prob)
        else:
            ratios.append(0.0)
    total_target = math.ceil(max(ratios)) if ratios else 0

    # Step 3: per-block target = round(total_target * block_prob)
    result: list[tuple[list[int], int]] = []
    needs_expansion = False
    for sys_ids, block_frames, block_prob in block_info:
        target = round(total_target * block_prob)
        target = max(target, block_frames)  # never shrink
        if target > block_frames:
            needs_expansion = True
        result.append((sys_ids, target))

    if not needs_expansion:
        return []

    return result


def _expand_indices_by_blocks(
    indices: list[int],
    frame_system_ids: np.ndarray,
    block_targets: list[tuple[list[int], int]],
    rng: np.random.Generator,
    _block_total_actual: list[int] | None = None,
    _sid_to_blk_arr: np.ndarray | None = None,
) -> list[int]:
    """Expand frame indices according to block targets.

    For each block, computes the proportional target for the subset of
    indices belonging to that block (within the current nloc group),
    then applies full-copy + remainder sampling.

    Parameters
    ----------
    indices : list[int]
        Frame indices in the current nloc group.
    frame_system_ids : np.ndarray
        Per-frame system id for the entire dataset (int64 array).
    block_targets : list[tuple[list[int], int]]
        Per-block (system_ids, total_target_frames).
    rng : np.random.Generator
        RNG for remainder sampling.
    _block_total_actual : list[int] or None
        Pre-computed total actual frame count per block (across all nloc
        groups). When provided, avoids an O(N) scan of frame_system_ids.
    _sid_to_blk_arr : np.ndarray or None
        Pre-computed system-id to block-index lookup array. When provided,
        avoids rebuilding the mapping for each call.

    Returns
    -------
    list[int]
        Expanded indices.
    """
    n_blocks = len(block_targets)

    # Build sys_id -> block_idx lookup array
    if _sid_to_blk_arr is None:
        sys_to_block: dict[int, int] = {}
        for blk_idx, (sys_ids, _target) in enumerate(block_targets):
            for sid in sys_ids:
                sys_to_block[sid] = blk_idx
        max_sid = max(sys_to_block.keys()) + 1 if sys_to_block else 0
        # -1 marks system ids not covered by any block.
        _sid_to_blk_arr = np.full(max_sid, -1, dtype=np.int32)
        for sid, blk in sys_to_block.items():
            _sid_to_blk_arr[sid] = blk

    # Partition indices by block using numpy for speed
    idx_arr = np.asarray(indices, dtype=np.int64)
    sid_arr = np.asarray(frame_system_ids, dtype=np.int64)
    # Vectorized lookup: get block id for each index
    idx_sids = sid_arr[idx_arr]
    idx_blks = _sid_to_blk_arr[idx_sids]

    # Pre-compute block_total_actual if not provided
    if _block_total_actual is None:
        _block_total_actual = []
        for sys_ids, _ in block_targets:
            total = sum(int(np.sum(sid_arr == sid)) for sid in sys_ids)
            _block_total_actual.append(total)

    expanded_parts: list[np.ndarray] = []

    # Unassigned indices (frames whose system is in no block) pass through.
    unassigned_mask = idx_blks == -1
    if np.any(unassigned_mask):
        expanded_parts.append(idx_arr[unassigned_mask])

    for blk_idx in range(n_blocks):
        blk_mask = idx_blks == blk_idx
        blk_idxs = idx_arr[blk_mask]
        n_actual = len(blk_idxs)
        if n_actual == 0:
            continue

        _, block_total_target = block_targets[blk_idx]
        block_total_act = _block_total_actual[blk_idx]

        # Proportional target for this nloc subset
        if block_total_act > 0:
            target = round(block_total_target * n_actual / block_total_act)
        else:
            target = n_actual
        target = max(target, n_actual)  # never shrink

        # Full copies + remainder
        deficit = target - n_actual
        if deficit <= 0:
            expanded_parts.append(blk_idxs)
        else:
            full_copies = deficit // n_actual
            remainder = deficit % n_actual
            # Original + full copies
            if full_copies > 0:
                expanded_parts.append(np.tile(blk_idxs, 1 + full_copies))
            else:
                expanded_parts.append(blk_idxs)
            # Remainder: sample without replacement
            # (remainder < n_actual by construction, so this cannot fail).
            if remainder > 0:
                sampled = rng.choice(blk_idxs, size=remainder, replace=False)
                expanded_parts.append(sampled)

    if expanded_parts:
        return np.concatenate(expanded_parts).tolist()
    return []
+ """ + # Build per-group batches + group_batches: list[list[list[int]]] = [] + + # Pre-compute expensive objects once (avoids O(N) work per nloc group) + block_total_actual: list[int] | None = None + sid_arr: np.ndarray | None = None + sid_to_blk_arr: np.ndarray | None = None + if block_targets and reader.frame_system_ids is not None: + block_total_actual = [] + for sys_ids, _ in block_targets: + total = sum(reader.system_nframes[s] for s in sys_ids) + block_total_actual.append(total) + # Convert frame_system_ids to numpy once + sid_arr = np.array(reader.frame_system_ids, dtype=np.int64) + # Build sys_id -> block_idx lookup array once + sys_to_block: dict[int, int] = {} + for blk_idx, (sys_ids, _target) in enumerate(block_targets): + for sid in sys_ids: + sys_to_block[sid] = blk_idx + max_sid = max(sys_to_block.keys()) + 1 if sys_to_block else 0 + sid_to_blk_arr = np.full(max_sid, -1, dtype=np.int32) + for sid, blk in sys_to_block.items(): + sid_to_blk_arr[sid] = blk + + for nloc in sorted(reader.nloc_groups.keys()): + indices = list(reader.nloc_groups[nloc]) + # Expand indices by block targets if provided + if block_targets and sid_arr is not None: + indices = _expand_indices_by_blocks( + indices, + sid_arr, + block_targets, + rng, + _block_total_actual=block_total_actual, + _sid_to_blk_arr=sid_to_blk_arr, + ) + if shuffle: + rng.shuffle(indices) + bs = reader.get_batch_size_for_nloc(nloc) + batches = [] + for start in range(0, len(indices), bs): + batches.append(indices[start : start + bs]) + group_batches.append(batches) + + # Interleave groups round-robin + all_batches: list[list[int]] = [] + max_len = max(len(gb) for gb in group_batches) if group_batches else 0 + for i in range(max_len): + for gb in group_batches: + if i < len(gb): + all_batches.append(gb[i]) + + # Optionally shuffle the interleaved order + if shuffle: + rng.shuffle(all_batches) + + return all_batches + + +class SameNlocBatchSampler: + """Batch sampler that groups frames by nloc. 
+ + For mixed-nloc datasets with mixed_batch=False: each batch contains only + frames with the same nloc. Within each nloc group, frames are shuffled. + Groups are interleaved round-robin so training sees diverse nloc values. + + When auto batch_size is used, batch_size is computed per-nloc-group. + + The sampler is deterministic: given the same seed, repeated calls to + ``__iter__`` produce the same batch sequence. + + Parameters + ---------- + reader : LmdbDataReader + The dataset reader (provides nloc_groups, get_batch_size_for_nloc). + shuffle : bool + Whether to shuffle within each nloc group each epoch. + seed : int or None + Random seed for reproducibility. + block_targets : list[tuple[list[int], int]] or None + Per-block expansion targets from compute_block_targets. + """ + + def __init__( + self, + reader: LmdbDataReader, + shuffle: bool = True, + seed: int | None = None, + block_targets: list[tuple[list[int], int]] | None = None, + ) -> None: + self._reader = reader + self._shuffle = shuffle + self._seed = seed + self._block_targets = block_targets + + def __iter__(self) -> Iterator[list[int]]: + """Yield batches of frame indices, all with the same nloc.""" + rng = np.random.default_rng(self._seed) + yield from _build_all_batches( + self._reader, self._shuffle, rng, self._block_targets + ) + + def __len__(self) -> int: + """Total number of batches across all nloc groups (estimated).""" + total = 0 + for nloc, indices in self._reader.nloc_groups.items(): + n = len(indices) + if self._block_targets and self._reader.frame_system_ids is not None: + # Estimate expanded count for this nloc group + n = self._estimate_expanded_count(indices) + bs = self._reader.get_batch_size_for_nloc(nloc) + total += (n + bs - 1) // bs + return total + + def _estimate_expanded_count(self, indices: list[int]) -> int: + """Estimate expanded index count for __len__ without RNG.""" + if not self._block_targets or self._reader.frame_system_ids is None: + return len(indices) + sys_ids 
= self._reader.frame_system_ids + total = 0 + for blk_idx, (blk_sys_ids, block_target) in enumerate(self._block_targets): + blk_sys_set = set(blk_sys_ids) + n_in_nloc = sum(1 for i in indices if sys_ids[i] in blk_sys_set) + if n_in_nloc == 0: + continue + block_total_actual = sum(1 for sid in sys_ids if sid in blk_sys_set) + if block_total_actual > 0: + target = round(block_target * n_in_nloc / block_total_actual) + else: + target = n_in_nloc + total += max(target, n_in_nloc) + # Add unassigned + all_sys = set() + for blk_sys_ids, _ in self._block_targets: + all_sys.update(blk_sys_ids) + total += sum(1 for i in indices if sys_ids[i] not in all_sys) + return total + + +class DistributedSameNlocBatchSampler: + """Distributed wrapper for same-nloc batch sampling. + + All ranks build the same deterministic global batch list (using + ``seed + epoch``), then each rank takes a disjoint subset via + :meth:`_partition_batches`. + + Override :meth:`_partition_batches` for custom load-balancing strategies. + The default uses strided partitioning which gives good nloc diversity per + rank. + + Parameters + ---------- + reader : LmdbDataReader + The dataset reader (provides nloc_groups, get_batch_size_for_nloc, + frame_nlocs). + rank : int + Rank of the current process. + world_size : int + Total number of processes. + shuffle : bool + Whether to shuffle batches. + seed : int or None + Base seed for deterministic RNG. All ranks must use the same seed. + block_targets : list[tuple[list[int], int]] or None + Per-block expansion targets from compute_block_targets. 
+ """ + + def __init__( + self, + reader: LmdbDataReader, + rank: int, + world_size: int, + shuffle: bool = True, + seed: int | None = None, + block_targets: list[tuple[list[int], int]] | None = None, + ) -> None: + self._reader = reader + self._rank = rank + self._world_size = world_size + self._shuffle = shuffle + self._seed = seed if seed is not None else 0 + self._epoch = 0 + self._block_targets = block_targets + + def set_epoch(self, epoch: int) -> None: + """Set epoch for deterministic cross-rank shuffling. + + Call this before each training epoch/cycle to get different but + reproducible batch orderings across epochs. + """ + self._epoch = epoch + + def __iter__(self) -> Iterator[list[int]]: + """Yield this rank's partition of the global batch list.""" + # All ranks build the same global batch list deterministically + rng = np.random.default_rng(self._seed + self._epoch) + all_batches = _build_all_batches( + self._reader, self._shuffle, rng, self._block_targets + ) + # Partition to this rank + yield from self._partition_batches(all_batches) + + def _partition_batches(self, all_batches: list[list[int]]) -> list[list[int]]: + """Partition global batches to this rank. + + Default: strided partition ``all_batches[rank::world_size]``. + This gives good nloc diversity per rank since batches are + interleaved across nloc groups before shuffling. + + Override this method for custom load-balancing. For example, a + greedy algorithm could assign batches to ranks based on estimated + compute cost (``reader.frame_nlocs[batch[0]]`` gives the nloc of + each batch). 
+ """ + return all_batches[self._rank :: self._world_size] + + def __len__(self) -> int: + """Number of batches for this rank.""" + total = 0 + for nloc, indices in self._reader.nloc_groups.items(): + bs = self._reader.get_batch_size_for_nloc(nloc) + total += (len(indices) + bs - 1) // bs + return math.ceil(total / self._world_size) + + @property + def rank(self) -> int: + return self._rank + + @property + def world_size(self) -> int: + return self._world_size + + +def make_neighbor_stat_data( + lmdb_path: str, + type_map: list[str] | None, + max_frames: int = 2000, +) -> Any: + """Create a duck-typed DeepmdDataSystem-like object for neighbor stat from LMDB. + + Samples up to *max_frames* frames, groups them by nloc, and returns an + object whose attributes satisfy the interface expected by + ``NeighborStat.iterator()`` and ``UpdateSel.get_nbor_stat()``. + """ + from types import ( + SimpleNamespace, + ) + + reader = LmdbDataReader(lmdb_path, type_map=type_map) + nframes = len(reader) + rng = np.random.RandomState(42) + if nframes > max_frames: + indices = np.sort(rng.choice(nframes, max_frames, replace=False)) + else: + indices = np.arange(nframes, dtype=np.int64) + + # Read sampled frames, group by nloc + nloc_frames: dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray | None]]] = {} + for idx in indices: + frame = reader[int(idx)] + atype = frame["atype"] + nloc = len(atype) + nloc_frames.setdefault(nloc, []).append( + (frame["coord"], atype, frame.get("box")) + ) + + # Build per-nloc data_system proxies + data_systems = [] + system_dirs: list[str] = [] + for nloc, frames in nloc_frames.items(): + coords = np.stack([c.reshape(nloc * 3) for c, _, _ in frames]) + types = np.stack([a.reshape(nloc) for _, a, _ in frames]) + has_box = frames[0][2] is not None + boxes = np.stack([b.reshape(9) for _, _, b in frames]) if has_box else None + set_data = {"coord": coords, "type": types, "box": boxes} + label = f"lmdb:{nloc}atoms" + proxy = SimpleNamespace( + 
dirs=[label], + pbc=has_box, + mixed_type=True, + get_natoms=lambda _nloc=nloc: _nloc, + _load_set=lambda _d, _sd=set_data: _sd, + ) + data_systems.append(proxy) + system_dirs.append(label) + + ntypes = len(type_map) if type_map else reader._ntypes + return SimpleNamespace( + system_dirs=system_dirs, + data_systems=data_systems, + get_batch=lambda: None, + get_ntypes=lambda: ntypes, + mixed_type=True, + ) + + +class LmdbTestData: + """LMDB-backed data reader for dp test. + + Mimics the DeepmdData interface used by test_ener(): + .add(), .get_test(), .mixed_type, .pbc + + For mixed-nloc datasets, frames are grouped by nloc. + get_test(nloc=...) returns data for a specific group. + """ + + def __init__( + self, + lmdb_path: str, + type_map: list[str] | None = None, + shuffle_test: bool = True, + **kwargs: Any, + ) -> None: + self.lmdb_path = str(lmdb_path) + self._type_map = type_map or [] + self._env = _open_lmdb(self.lmdb_path) + + with self._env.begin() as txn: + meta = _read_metadata(txn) + + self.nframes, self._frame_fmt, self._natoms_per_type = _parse_metadata(meta) + self._natoms = sum(self._natoms_per_type) + + # Build type remapping if LMDB's type_map differs from model's type_map + lmdb_type_map = meta.get("type_map") + self._lmdb_type_map = lmdb_type_map + self._type_remap: np.ndarray | None = None + if ( + lmdb_type_map is not None + and self._type_map + and list(lmdb_type_map) != list(self._type_map) + ): + remap = np.empty(len(lmdb_type_map), dtype=np.int32) + for i, name in enumerate(lmdb_type_map): + if name not in self._type_map: + raise ValueError( + f"Element '{name}' in LMDB type_map {lmdb_type_map} " + f"not found in model type_map {self._type_map}" + ) + remap[i] = self._type_map.index(name) + self._type_remap = remap + log.info( + f"LmdbTestData type remapping: LMDB {lmdb_type_map} -> " + f"model {self._type_map}, remap={remap}" + ) + + # Read all frames + self._frames: list[dict[str, Any]] = [] + with self._env.begin() as txn: + for i in 
range(self.nframes): + key = format(i, self._frame_fmt).encode() + raw = txn.get(key) + if raw is not None: + frame = _remap_keys(_decode_frame(raw)) + # Apply type remapping to atype + if ( + self._type_remap is not None + and "atype" in frame + and isinstance(frame["atype"], np.ndarray) + ): + frame["atype"] = self._type_remap[ + frame["atype"].reshape(-1) + ].astype(np.int64) + self._frames.append(frame) + + # Shuffle if requested + if shuffle_test: + rng = np.random.default_rng() + indices = rng.permutation(len(self._frames)) + self._frames = [self._frames[i] for i in indices] + + # Group frames by nloc + self._nloc_groups: dict[int, list[int]] = {} + for idx, frame in enumerate(self._frames): + atype = frame.get("atype") + nloc = len(atype) if isinstance(atype, np.ndarray) else self._natoms + self._nloc_groups.setdefault(nloc, []).append(idx) + + # Data requirements + self._requirements: dict[str, dict[str, Any]] = {} + + # Detect PBC: if any frame has a non-zero box + self.pbc = True + if len(self._frames) > 0: + f0 = self._frames[0] + if "box" not in f0: + self.pbc = False + elif isinstance(f0["box"], np.ndarray) and np.allclose(f0["box"], 0.0): + self.pbc = False + + self.mixed_type = True + + def __del__(self) -> None: + """Release the LMDB environment ref-count on garbage collection.""" + path = getattr(self, "lmdb_path", None) + if path is not None: + _close_lmdb(path) + + @property + def nloc_groups(self) -> dict[int, list[int]]: + """Nloc → list of frame indices in self._frames.""" + return self._nloc_groups + + def add( + self, + key: str, + ndof: int, + atomic: bool = False, + must: bool = True, + high_prec: bool = False, + repeat: int = 1, + default: float = 0.0, + dtype: np.dtype | None = None, + **kwargs: Any, + ) -> None: + """Register a data requirement (mirrors DeepmdData.add).""" + self._requirements[key] = { + "ndof": ndof, + "atomic": atomic, + "must": must, + "high_prec": high_prec, + "repeat": repeat, + "default": default, + "dtype": 
dtype, + } + + def _resolve_dtype(self, key: str) -> np.dtype: + """Resolve target dtype for a key using registered requirements.""" + if key in self._requirements: + req = self._requirements[key] + dtype = req.get("dtype") + if dtype is not None: + return dtype + if req.get("high_prec", False): + return GLOBAL_ENER_FLOAT_PRECISION + return GLOBAL_NP_FLOAT_PRECISION + if key in _HIGH_PREC_KEYS: + return GLOBAL_ENER_FLOAT_PRECISION + return GLOBAL_NP_FLOAT_PRECISION + + def get_test(self, nloc: int | None = None) -> dict[str, Any]: + """Return frames stacked as numpy arrays. + + Parameters + ---------- + nloc : int or None + If specified, return only frames with this atom count. + If None and all frames have the same nloc, return all. + If None and mixed nloc, return the largest group and log a warning. + Returns dict matching DeepmdData.get_test() format: + """ + if nloc is not None: + if nloc not in self._nloc_groups: + raise ValueError( + f"No frames with nloc={nloc}. Available: {sorted(self._nloc_groups.keys())}" + ) + frame_indices = self._nloc_groups[nloc] + natoms = nloc + elif len(self._nloc_groups) == 1: + # Uniform nloc — use all frames + natoms = next(iter(self._nloc_groups)) + frame_indices = list(range(len(self._frames))) + else: + # Mixed nloc — use the largest group + natoms = max(self._nloc_groups, key=lambda k: len(self._nloc_groups[k])) + frame_indices = self._nloc_groups[natoms] + group_summary = {k: len(v) for k, v in sorted(self._nloc_groups.items())} + log.warning( + f"Mixed-nloc LMDB for dp test: using nloc={natoms} group " + f"({len(frame_indices)} frames). 
" + f"Available groups: {group_summary}" + ) + + frames = [self._frames[i] for i in frame_indices] + return self._stack_frames(frames, natoms) + + def _stack_frames( + self, frames: list[dict[str, Any]], natoms: int + ) -> dict[str, Any]: + """Stack a list of same-nloc frames into numpy arrays.""" + nframes = len(frames) + result: dict[str, Any] = {} + + # Core arrays + coords = [] + boxes = [] + atypes = [] + + for frame in frames: + if "coord" in frame and isinstance(frame["coord"], np.ndarray): + coords.append( + frame["coord"] + .reshape(natoms * 3) + .astype(self._resolve_dtype("coord")) + ) + if "box" in frame and isinstance(frame["box"], np.ndarray): + boxes.append(frame["box"].reshape(9).astype(self._resolve_dtype("box"))) + else: + boxes.append(np.zeros(9, dtype=self._resolve_dtype("box"))) + if "atype" in frame and isinstance(frame["atype"], np.ndarray): + atypes.append(frame["atype"].reshape(natoms).astype(np.int64)) + + result["coord"] = ( + np.stack(coords) + if coords + else np.zeros((0, natoms * 3), dtype=self._resolve_dtype("coord")) + ) + result["box"] = ( + np.stack(boxes) + if boxes + else np.zeros((0, 9), dtype=self._resolve_dtype("box")) + ) + result["type"] = ( + np.stack(atypes) if atypes else np.zeros((0, natoms), dtype=np.int64) + ) + + # Dynamically discover all data keys from the first frame, plus + # any registered requirements. Structural keys (coord, box, type) + # are excluded — they are already handled above. 
+ _structural_keys = frozenset({"coord", "box", "atype"}) + all_keys: dict[str, dict[str, Any]] = {} + if frames: + for fk in frames[0]: + if fk in _structural_keys or fk.startswith("find_"): + continue + if fk not in all_keys: + all_keys[fk] = {"ndof": None, "atomic": False, "default": 0.0} + for key, req in self._requirements.items(): + all_keys[key] = req + + for key, req_info in all_keys.items(): + has_key = any( + key in f and isinstance(f.get(key), np.ndarray) for f in frames + ) + result[f"find_{key}"] = 1.0 if has_key else 0.0 + + # Get repeat factor from registered requirements + repeat = 1 + if key in self._requirements: + repeat = self._requirements[key].get("repeat", 1) + + if has_key: + arrays = [] + for frame in frames: + val = frame.get(key) + if isinstance(val, np.ndarray): + arr = val.astype(self._resolve_dtype(key)).ravel() + if repeat != 1: + arr = np.repeat(arr, repeat) + arrays.append(arr) + elif val is not None: + arrays.append( + np.array([float(val)], dtype=self._resolve_dtype(key)) + ) + else: + ref = next( + ( + f[key] + for f in frames + if isinstance(f.get(key), np.ndarray) + ), + None, + ) + if ref is not None: + size = ref.size * repeat if repeat != 1 else ref.size + arrays.append( + np.zeros(size, dtype=self._resolve_dtype(key)) + ) + else: + arrays.append(np.zeros(1, dtype=self._resolve_dtype(key))) + result[key] = np.stack(arrays) + elif key in self._requirements: + ndof = self._requirements[key]["ndof"] + atomic = self._requirements[key]["atomic"] + default = self._requirements[key]["default"] + if atomic: + shape = (nframes, natoms * ndof * repeat) + else: + shape = (nframes, ndof * repeat) + result[key] = np.full(shape, default, dtype=self._resolve_dtype(key)) + + return result + + +def merge_lmdb( + src_paths: list[str], + dst_path: str, + *, + map_size: int = 1024**4, # 1 TB default +) -> str: + """Merge multiple LMDB datasets into one. + + Frames are concatenated in order. 
The output metadata includes a + ``frame_nlocs`` list for fast init (skips per-frame scan). + + Parameters + ---------- + src_paths : list[str] + Paths to source LMDB directories. + dst_path : str + Path for the merged LMDB output. + map_size : int + Maximum size of the output LMDB (default 1 TB). + + Returns + ------- + str + Path to the created LMDB. + """ + import os + import shutil + + if os.path.exists(dst_path): + shutil.rmtree(dst_path) + + dst_env = lmdb.open(dst_path, map_size=map_size) + frame_idx = 0 + fmt = "012d" + frame_nlocs: list[int] = [] + frame_system_ids: list[int] = [] + first_system_info: dict | None = None + first_type_map: list[str] | None = None + sys_id_offset = 0 + + for src_path in src_paths: + src_env = _open_lmdb(src_path) + with src_env.begin() as txn: + meta = _read_metadata(txn) + nframes, src_fmt, natoms_per_type = _parse_metadata(meta) + fallback_natoms = sum(natoms_per_type) + + if first_system_info is None: + first_system_info = meta.get("system_info", {}) + if first_type_map is None: + first_type_map = meta.get("type_map") + + # Check for pre-computed frame_nlocs in source + src_nlocs = meta.get("frame_nlocs") + # Check for frame_system_ids in source + src_sys_ids = meta.get("frame_system_ids") + + with src_env.begin() as src_txn, dst_env.begin(write=True) as dst_txn: + for i in range(nframes): + src_key = format(i, src_fmt).encode() + raw = src_txn.get(src_key) + if raw is None: + continue + dst_key = format(frame_idx, fmt).encode() + dst_txn.put(dst_key, raw) + + # Get nloc for this frame + if src_nlocs is not None: + frame_nlocs.append(int(src_nlocs[i])) + else: + frame_raw = msgpack.unpackb(raw, raw=False) + atype_raw = frame_raw.get("atom_types") + if isinstance(atype_raw, dict): + shape = atype_raw.get("shape") or atype_raw.get(b"shape") + if shape: + frame_nlocs.append(int(shape[0])) + else: + frame_nlocs.append(fallback_natoms) + else: + frame_nlocs.append(fallback_natoms) + + # Propagate system IDs with offset + if 
src_sys_ids is not None and i < len(src_sys_ids): + frame_system_ids.append(int(src_sys_ids[i]) + sys_id_offset) + else: + frame_system_ids.append(sys_id_offset) + + frame_idx += 1 + + # Update sys_id_offset for next source + if src_sys_ids is not None and len(src_sys_ids) > 0: + sys_id_offset += max(int(s) for s in src_sys_ids) + 1 + else: + sys_id_offset += 1 + + src_env.close() + + # Write merged metadata with frame_nlocs for fast init + merged_meta = { + "nframes": frame_idx, + "frame_idx_fmt": fmt, + "system_info": first_system_info or {}, + "frame_nlocs": frame_nlocs, + "frame_system_ids": frame_system_ids, + } + if first_type_map is not None: + merged_meta["type_map"] = first_type_map + with dst_env.begin(write=True) as txn: + txn.put(b"__metadata__", msgpack.packb(merged_meta, use_bin_type=True)) + dst_env.close() + + nloc_counts: dict[int, int] = {} + for n in frame_nlocs: + nloc_counts[n] = nloc_counts.get(n, 0) + 1 + log.info( + f"Merged {len(src_paths)} LMDBs → {dst_path}: " + f"{frame_idx} frames, nloc groups: {dict(sorted(nloc_counts.items()))}" + ) + return dst_path diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py index 62de61ac6a..9a5f2f3f79 100644 --- a/deepmd/entrypoints/test.py +++ b/deepmd/entrypoints/test.py @@ -13,9 +13,12 @@ import numpy as np from deepmd.common import ( - expand_sys_str, j_loader, ) +from deepmd.dpmodel.utils.lmdb_data import ( + LmdbTestData, + is_lmdb, +) from deepmd.infer.deep_dipole import ( DeepDipole, ) @@ -62,6 +65,24 @@ log = logging.getLogger(__name__) +class _LmdbTestDataNlocView: + """Thin wrapper that makes LmdbTestData.get_test() return a specific nloc group. + + Delegates all attributes to the underlying LmdbTestData, but get_test() + returns only frames with the specified nloc. 
+ """ + + def __init__(self, lmdb_test_data: LmdbTestData, nloc: int) -> None: + self._inner = lmdb_test_data + self._nloc = nloc + + def __getattr__(self, name: str) -> Any: + return getattr(self._inner, name) + + def get_test(self) -> dict: + return self._inner.get_test(nloc=self._nloc) + + def test( *, model: str, @@ -146,7 +167,7 @@ def test( with open(datafile) as datalist: all_sys = datalist.read().splitlines() elif system is not None: - all_sys = expand_sys_str(system) + all_sys = process_systems(system) else: raise RuntimeError("No data source specified for testing") @@ -168,61 +189,92 @@ def test( # create data class tmap = dp.get_type_map() - data = DeepmdData( - system, - set_prefix="set", - shuffle_test=shuffle_test, - type_map=tmap, - sort_atoms=False, - ) - - if isinstance(dp, DeepPot): - err = test_ener( - dp, - data, + if is_lmdb(system): + lmdb_data = LmdbTestData( system, - numb_test, - detail_file, - atomic, - append_detail=(cc != 0), + type_map=tmap, + shuffle_test=shuffle_test, ) - elif isinstance(dp, DeepDOS): - err = test_dos( - dp, - data, - system, - numb_test, - detail_file, - atomic, - append_detail=(cc != 0), - ) - elif isinstance(dp, DeepProperty): - err = test_property( - dp, - data, + # For mixed-nloc LMDB, test each nloc group separately + nloc_keys = sorted(lmdb_data.nloc_groups.keys()) + if len(nloc_keys) > 1: + group_summary = { + k: len(v) for k, v in sorted(lmdb_data.nloc_groups.items()) + } + log.info( + f"# mixed-nloc LMDB: testing {len(nloc_keys)} groups: " + f"{group_summary}" + ) + data_items: list[tuple[Any, str]] = [] + for nloc_val in nloc_keys: + label = f"{system} [nloc={nloc_val}]" if len(nloc_keys) > 1 else system + # Create a thin wrapper that returns only this nloc group + data_items.append((_LmdbTestDataNlocView(lmdb_data, nloc_val), label)) + else: + data = DeepmdData( system, - numb_test, - detail_file, - atomic, - append_detail=(cc != 0), - ) - elif isinstance(dp, DeepDipole): - err = test_dipole(dp, data, 
numb_test, detail_file, atomic) - elif isinstance(dp, DeepPolar): - err = test_polar(dp, data, numb_test, detail_file, atomic=atomic) - elif isinstance(dp, DeepGlobalPolar): # should not appear in this new version - log.warning( - "Global polar model is not currently supported. Please directly use the polar mode and change loss parameters." + set_prefix="set", + shuffle_test=shuffle_test, + type_map=tmap, + sort_atoms=False, ) - err = test_polar( - dp, data, numb_test, detail_file, atomic=False - ) # YWolfeee: downward compatibility - log.info("# ----------------------------------------------- ") - err_coll.append(err) + data_items = [(data, system)] + + for data, sys_label in data_items: + if sys_label != system: + log.info(f"# testing sub-group : {sys_label}") + + if isinstance(dp, DeepPot): + err = test_ener( + dp, + data, + sys_label, + numb_test, + detail_file, + atomic, + append_detail=(cc != 0), + ) + elif isinstance(dp, DeepDOS): + err = test_dos( + dp, + data, + sys_label, + numb_test, + detail_file, + atomic, + append_detail=(cc != 0), + ) + elif isinstance(dp, DeepProperty): + err = test_property( + dp, + data, + sys_label, + numb_test, + detail_file, + atomic, + append_detail=(cc != 0), + ) + elif isinstance(dp, DeepDipole): + err = test_dipole(dp, data, numb_test, detail_file, atomic) + elif isinstance(dp, DeepPolar): + err = test_polar(dp, data, numb_test, detail_file, atomic=atomic) + elif isinstance( + dp, DeepGlobalPolar + ): # should not appear in this new version + log.warning( + "Global polar model is not currently supported. Please directly use the polar mode and change loss parameters." 
+ ) + err = test_polar( + dp, data, numb_test, detail_file, atomic=False + ) # YWolfeee: downward compatibility + log.info("# ----------------------------------------------- ") + err_coll.append(err) avg_err = weighted_average(err_coll) - if len(all_sys) != len(err_coll): + # For mixed-nloc LMDB, err_coll may have more entries than all_sys + # (one per nloc group per system). Only warn if fewer. + if len(err_coll) < len(all_sys): log.warning("Not all systems are tested! Check if the systems are valid") log.info("# ----------weighted average of errors----------- ") diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 46ad8a6cd0..7b45c46333 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -70,6 +70,10 @@ from deepmd.pt.utils.finetune import ( get_finetune_rules, ) +from deepmd.pt.utils.lmdb_dataset import ( + LmdbDataset, + is_lmdb, +) from deepmd.pt.utils.multi_task import ( preprocess_shared_params, ) @@ -114,7 +118,9 @@ def prepare_trainer_input_single( data_dict_single: dict[str, Any], rank: int = 0, seed: int | None = None, - ) -> tuple[DpLoaderSet, DpLoaderSet | None, DPPath | None]: + ) -> tuple[ + DpLoaderSet | LmdbDataset, DpLoaderSet | LmdbDataset | None, DPPath | None + ]: # get data modifier modifier = None modifier_params = model_params_single.get("modifier", None) @@ -127,11 +133,6 @@ def prepare_trainer_input_single( validation_dataset_params["systems"] if validation_dataset_params else None ) training_systems = training_dataset_params["systems"] - trn_patterns = training_dataset_params.get("rglob_patterns", None) - training_systems = process_systems(training_systems, patterns=trn_patterns) - if validation_systems is not None: - val_patterns = validation_dataset_params.get("rglob_patterns", None) - validation_systems = process_systems(validation_systems, val_patterns) # stat files stat_file_path_single = data_dict_single.get("stat_file", None) @@ -146,27 +147,58 @@ def 
prepare_trainer_input_single( Path(stat_file_path_single).mkdir() stat_file_path_single = DPPath(stat_file_path_single, "a") - # validation and training data - # avoid the same batch sequence among devices rank_seed = [rank, seed % (2**32)] if seed is not None else None - validation_data_single = ( - DpLoaderSet( - validation_systems, - validation_dataset_params["batch_size"], + + def _make_dp_loader_set( + systems: str | list[str], + dataset_params: dict[str, Any], + ) -> DpLoaderSet: + """Create a DpLoaderSet from systems with pattern expansion.""" + patterns = dataset_params.get("rglob_patterns", None) + systems = process_systems(systems, patterns=patterns) + return DpLoaderSet( + systems, + dataset_params["batch_size"], model_params_single["type_map"], seed=rank_seed, modifier=modifier, ) - if validation_systems - else None - ) - train_data_single = DpLoaderSet( - training_systems, - training_dataset_params["batch_size"], - model_params_single["type_map"], - seed=rank_seed, - modifier=modifier, - ) + + # LMDB path: single string → LmdbDataset + if isinstance(training_systems, str) and is_lmdb(training_systems): + auto_prob = training_dataset_params.get("auto_prob", None) + train_data_single = LmdbDataset( + training_systems, + model_params_single["type_map"], + training_dataset_params["batch_size"], + auto_prob_style=auto_prob, + ) + if ( + validation_systems is not None + and isinstance(validation_systems, str) + and is_lmdb(validation_systems) + ): + validation_data_single = LmdbDataset( + validation_systems, + model_params_single["type_map"], + validation_dataset_params["batch_size"], + ) + elif validation_systems is not None: + validation_data_single = _make_dp_loader_set( + validation_systems, validation_dataset_params + ) + else: + validation_data_single = None + else: + # Standard npy path + train_data_single = _make_dp_loader_set( + training_systems, training_dataset_params + ) + validation_data_single = ( + _make_dp_loader_set(validation_systems, 
validation_dataset_params) + if validation_systems + else None + ) return ( train_data_single, validation_data_single, @@ -336,9 +368,21 @@ def train( if not multi_task: type_map = config["model"].get("type_map") - train_data = get_data( - config["training"]["training_data"], 0, type_map, None - ) + training_systems = config["training"]["training_data"].get("systems") + if ( + training_systems is not None + and isinstance(training_systems, str) + and is_lmdb(training_systems) + ): + from deepmd.dpmodel.utils.lmdb_data import ( + make_neighbor_stat_data, + ) + + train_data = make_neighbor_stat_data(training_systems, type_map) + else: + train_data = get_data( + config["training"]["training_data"], 0, type_map, None + ) config["model"], min_nbor_dist = BaseModel.update_sel( train_data, type_map, config["model"] ) @@ -346,12 +390,26 @@ def train( min_nbor_dist = {} for model_item in config["model"]["model_dict"]: type_map = config["model"]["model_dict"][model_item].get("type_map") - train_data = get_data( - config["training"]["data_dict"][model_item]["training_data"], - 0, - type_map, - None, - ) + training_systems = config["training"]["data_dict"][model_item][ + "training_data" + ].get("systems") + if ( + training_systems is not None + and isinstance(training_systems, str) + and is_lmdb(training_systems) + ): + from deepmd.dpmodel.utils.lmdb_data import ( + make_neighbor_stat_data, + ) + + train_data = make_neighbor_stat_data(training_systems, type_map) + else: + train_data = get_data( + config["training"]["data_dict"][model_item]["training_data"], + 0, + type_map, + None, + ) config["model"]["model_dict"][model_item], min_nbor_dist[model_item] = ( BaseModel.update_sel( train_data, type_map, config["model"]["model_dict"][model_item] diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 8d16e1c7ea..f4e836983f 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -74,6 +74,11 @@ from deepmd.pt.utils.learning_rate import ( 
BaseLR, ) +from deepmd.pt.utils.lmdb_dataset import ( + LmdbDataset, + _collate_lmdb_batch, + _SameNlocBatchSamplerTorch, +) from deepmd.pt.utils.stat import ( make_stat_input, ) @@ -218,8 +223,8 @@ def cycle_iterator(iterable: Iterable) -> Generator[Any, None, None]: yield from it def get_data_loader( - _training_data: DpLoaderSet, - _validation_data: DpLoaderSet | None, + _training_data: DpLoaderSet | LmdbDataset, + _validation_data: DpLoaderSet | LmdbDataset | None, _training_params: dict[str, Any], ) -> tuple[ DataLoader, @@ -228,6 +233,62 @@ def get_data_loader( Generator[Any, None, None] | None, int, ]: + def get_dataloader_and_iter_lmdb( + _data: LmdbDataset, + ) -> tuple[DataLoader, Generator[Any, None, None]]: + if _data.mixed_batch: + # TODO [mixed_batch=True]: Replace SameNlocBatchSampler with + # RandomSampler(replacement=False) + padding collate_fn. + # Changes needed: + # 1. _collate_lmdb_batch: pad coord/force/atype to max_nloc, + # add "atom_mask" bool tensor (nframes, max_nloc) + # 2. Use RandomSampler(_data, replacement=False) as sampler + # 3. Use fixed batch_size in DataLoader (not batch_sampler) + # 4. Model forward: apply atom_mask to descriptor/fitting + # 5. Loss: mask out padded atoms in force loss + raise NotImplementedError( + "mixed_batch=True training is not yet supported." + ) + # mixed_batch=False: group frames by nloc, each batch same nloc. + # SameNlocBatchSampler yields list[int] per batch, all same nloc. + # Auto batch_size is computed per-nloc-group inside the sampler. 
+ from deepmd.dpmodel.utils.lmdb_data import ( + SameNlocBatchSampler, + ) + + _block_targets = getattr(_data, "_block_targets", None) + + if self.world_size > 1: + from deepmd.dpmodel.utils.lmdb_data import ( + DistributedSameNlocBatchSampler, + ) + + _inner_sampler = DistributedSameNlocBatchSampler( + _data._reader, + rank=self.rank, + world_size=self.world_size, + shuffle=True, + seed=_training_params.get("seed", None), + block_targets=_block_targets, + ) + else: + _inner_sampler = SameNlocBatchSampler( + _data._reader, + shuffle=True, + block_targets=_block_targets, + ) + + _batch_sampler = _SameNlocBatchSamplerTorch(_inner_sampler) + _dataloader = DataLoader( + _data, + batch_sampler=_batch_sampler, + num_workers=0, + collate_fn=_collate_lmdb_batch, + pin_memory=(DEVICE != "cpu"), + ) + _data_iter = cycle_iterator(_dataloader) + return _dataloader, _data_iter + def get_dataloader_and_iter( _data: DpLoaderSet, _params: dict[str, Any] ) -> tuple[DataLoader, Generator[Any, None, None]]: @@ -250,17 +311,28 @@ def get_dataloader_and_iter( _data_iter = cycle_iterator(_dataloader) return _dataloader, _data_iter - training_dataloader, training_data_iter = get_dataloader_and_iter( - _training_data, _training_params["training_data"] - ) + if isinstance(_training_data, LmdbDataset): + training_dataloader, training_data_iter = get_dataloader_and_iter_lmdb( + _training_data + ) + else: + training_dataloader, training_data_iter = get_dataloader_and_iter( + _training_data, _training_params["training_data"] + ) if _validation_data is not None: - ( - validation_dataloader, - validation_data_iter, - ) = get_dataloader_and_iter( - _validation_data, _training_params["validation_data"] - ) + if isinstance(_validation_data, LmdbDataset): + ( + validation_dataloader, + validation_data_iter, + ) = get_dataloader_and_iter_lmdb(_validation_data) + else: + ( + validation_dataloader, + validation_data_iter, + ) = get_dataloader_and_iter( + _validation_data, 
_training_params["validation_data"] + ) valid_numb_batch = _training_params["validation_data"].get( "numb_btch", 1 ) @@ -388,12 +460,17 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: self.valid_numb_batch, ) = get_data_loader(training_data, validation_data, training_params) training_data.print_summary( - "training", to_numpy_array(self.training_dataloader.sampler.weights) + "training", + to_numpy_array(self.training_dataloader.sampler.weights) + if not isinstance(training_data, LmdbDataset) + else [1.0], ) if validation_data is not None: validation_data.print_summary( "validation", - to_numpy_array(self.validation_dataloader.sampler.weights), + to_numpy_array(self.validation_dataloader.sampler.weights) + if not isinstance(validation_data, LmdbDataset) + else [1.0], ) else: ( @@ -459,7 +536,9 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: training_data[model_key].print_summary( f"training in {model_key}", - to_numpy_array(self.training_dataloader[model_key].sampler.weights), + to_numpy_array(self.training_dataloader[model_key].sampler.weights) + if not isinstance(training_data[model_key], LmdbDataset) + else [1.0], ) if ( validation_data is not None @@ -469,7 +548,9 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: f"validation in {model_key}", to_numpy_array( self.validation_dataloader[model_key].sampler.weights - ), + ) + if not isinstance(validation_data[model_key], LmdbDataset) + else [1.0], ) # Resolve training steps @@ -482,13 +563,16 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: ) if self.num_epoch <= 0: raise ValueError("training.num_epoch must be positive.") - sampler_weights = to_numpy_array( - self.training_dataloader.sampler.weights - ) - total_numb_batch = compute_total_numb_batch( - training_data.index, - sampler_weights, - ) + if isinstance(training_data, LmdbDataset): + total_numb_batch = training_data.total_batch + else: + sampler_weights = to_numpy_array( + self.training_dataloader.sampler.weights + ) + total_numb_batch = 
compute_total_numb_batch( + training_data.index, + sampler_weights, + ) if total_numb_batch <= 0: raise ValueError( "Total number of training batches must be positive." @@ -508,15 +592,18 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: "are mutually exclusive." ) for model_key in self.model_keys: - sampler_weights = to_numpy_array( - self.training_dataloader[model_key].sampler.weights - ) - per_task_total.append( - compute_total_numb_batch( - training_data[model_key].index, - sampler_weights, + if isinstance(training_data[model_key], LmdbDataset): + per_task_total.append(training_data[model_key].total_batch) + else: + sampler_weights = to_numpy_array( + self.training_dataloader[model_key].sampler.weights + ) + per_task_total.append( + compute_total_numb_batch( + training_data[model_key].index, + sampler_weights, + ) ) - ) ( self.model_prob, self.num_steps, diff --git a/deepmd/pt/utils/lmdb_dataset.py b/deepmd/pt/utils/lmdb_dataset.py new file mode 100644 index 0000000000..44d67be242 --- /dev/null +++ b/deepmd/pt/utils/lmdb_dataset.py @@ -0,0 +1,330 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""PyTorch LMDB dataset — thin wrapper around framework-agnostic LmdbDataReader.""" + +import logging +from collections.abc import ( + Iterator, +) +from typing import ( + Any, +) + +import torch +from torch.utils.data import ( + DataLoader, + Dataset, + Sampler, +) +from torch.utils.data._utils.collate import ( + collate_tensor_fn, +) + +from deepmd.dpmodel.utils.lmdb_data import ( + LmdbDataReader, + LmdbTestData, + SameNlocBatchSampler, + compute_block_targets, + is_lmdb, +) +from deepmd.utils.data import ( + DataRequirementItem, +) + +log = logging.getLogger(__name__) + +# Re-export for backward compatibility +__all__ = [ + "LmdbDataset", + "LmdbTestData", + "_collate_lmdb_batch", + "is_lmdb", +] + + +def _collate_lmdb_batch(batch: list[dict[str, Any]]) -> dict[str, Any]: + """Collate a list of frame dicts into a batch dict. 
+ + All frames in the batch must have the same nloc (enforced by + SameNlocBatchSampler when mixed_batch=False). + + For mixed_batch=True, this function would need padding + mask. + Currently raises NotImplementedError for that case. + """ + if len(batch) > 1: + atypes = [d.get("atype") for d in batch if d.get("atype") is not None] + if atypes and any(len(a) != len(atypes[0]) for a in atypes): + raise NotImplementedError( + "mixed_batch collation (frames with different atom counts " + "in the same batch) is not yet supported. " + "Padding + mask in collate_fn needed." + ) + + example = batch[0] + result: dict[str, Any] = {} + for key in example: + if "find_" in key: + result[key] = batch[0][key] + elif key == "fid": + result[key] = [d[key] for d in batch] + elif key == "type": + continue + elif batch[0][key] is None: + result[key] = None + else: + with torch.device("cpu"): + result[key] = collate_tensor_fn( + [torch.as_tensor(d[key]) for d in batch] + ) + result["sid"] = torch.tensor([0], dtype=torch.long, device="cpu") + return result + + +class _SameNlocBatchSamplerTorch(Sampler): + """Torch Sampler adapter around the framework-agnostic SameNlocBatchSampler. + + PyTorch DataLoader with batch_sampler expects a Sampler that yields + lists of indices. This wraps SameNlocBatchSampler (or + DistributedSameNlocBatchSampler) to satisfy that. + """ + + def __init__(self, inner: SameNlocBatchSampler) -> None: + self._inner = inner + + def __iter__(self) -> Iterator[list[int]]: + yield from self._inner + + def __len__(self) -> int: + return len(self._inner) + + def set_epoch(self, epoch: int) -> None: + """Forward set_epoch to inner sampler if it supports it.""" + if hasattr(self._inner, "set_epoch"): + self._inner.set_epoch(epoch) + + +class LmdbDataset(Dataset): + """PyTorch Dataset backed by LMDB via LmdbDataReader. + + Parameters + ---------- + lmdb_path : str + Path to the LMDB directory. + type_map : list[str] + Global type map from model config. 
+ batch_size : int or str + Batch size. Supports int, "auto", "auto:N". + mixed_batch : bool + If True, allow different nloc in the same batch (future). + If False (default), use SameNlocBatchSampler. + """ + + def __init__( + self, + lmdb_path: str, + type_map: list[str], + batch_size: int | str = "auto", + mixed_batch: bool = False, + auto_prob_style: str | None = None, + ) -> None: + self._reader = LmdbDataReader( + lmdb_path, type_map, batch_size, mixed_batch=mixed_batch + ) + + if mixed_batch: + # Future: DataLoader with padding collate_fn + raise NotImplementedError( + "mixed_batch=True is not yet supported. " + "Requires padding + mask in collate_fn." + ) + + # Compute block_targets from auto_prob_style if provided + self._block_targets = None + if auto_prob_style is not None and self._reader.frame_system_ids is not None: + self._block_targets = compute_block_targets( + auto_prob_style, + self._reader.nsystems, + self._reader.system_nframes, + ) + if self._block_targets is not None: + log.info( + f"LMDB auto_prob: {len(self._block_targets)} blocks, " + f"nsystems={self._reader.nsystems}" + ) + + # Same-nloc batching: use SameNlocBatchSampler + sampler = SameNlocBatchSampler( + self._reader, + shuffle=True, + block_targets=self._block_targets, + ) + self._batch_sampler = _SameNlocBatchSamplerTorch(sampler) + + with torch.device("cpu"): + self._inner_dataloader = DataLoader( + self, + batch_sampler=self._batch_sampler, + num_workers=0, + collate_fn=_collate_lmdb_batch, + ) + + # Per-nloc-group dataloaders for make_stat_input. + # Each group gets its own DataLoader so torch.cat in stat collection + # only concatenates same-shape tensors. 
+ self._nloc_dataloaders: list[DataLoader] = [] + for nloc in sorted(self._reader.nloc_groups.keys()): + indices = self._reader.nloc_groups[nloc] + subset = torch.utils.data.Subset(self, indices) + bs = self._reader.get_batch_size_for_nloc(nloc) + with torch.device("cpu"): + dl = DataLoader( + subset, + batch_size=bs, + shuffle=False, + num_workers=0, + drop_last=False, + collate_fn=_collate_lmdb_batch, + ) + self._nloc_dataloaders.append(dl) + + def __len__(self) -> int: + return len(self._reader) + + def __getitem__(self, index: int) -> dict[str, Any]: + return self._reader[index] + + # --- Delegated to reader --- + + @property + def lmdb_path(self) -> str: + return self._reader.lmdb_path + + @property + def nframes(self) -> int: + return self._reader.nframes + + @property + def mixed_batch(self) -> bool: + return self._reader.mixed_batch + + @property + def mixed_type(self) -> bool: + """LMDB datasets are always mixed_type.""" + return self._reader.mixed_type + + @property + def batch_size(self) -> int: + return self._reader.batch_size + + def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None: + self._reader.add_data_requirement(data_requirement) + + def preload_and_modify_all_data_torch(self) -> None: + """No-op: LMDB reads on demand.""" + + def print_summary(self, name: str, prob: Any) -> None: + self._reader.print_summary(name, prob) + if self._block_targets: + reader = self._reader + # Per-block summary: original vs target frames + block_lines = [] + total_original = 0 + total_target = 0 + # Pre-compute block_total_actual for proportional scaling + block_total_actual: list[int] = [] + for sys_ids, target in self._block_targets: + actual = sum(reader.system_nframes[s] for s in sys_ids) + block_total_actual.append(actual) + total_original += actual + total_target += target + # Compact range notation: sys[0-146] instead of sys[0,1,2,...,146] + if len(sys_ids) > 3: + sys_str = f"{sys_ids[0]}-{sys_ids[-1]}" + else: + sys_str = 
",".join(str(s) for s in sys_ids) + ratio = target / actual if actual > 0 else 0 + block_lines.append( + f"sys[{sys_str}]({len(sys_ids)}sys): " + f"{actual}->{target} (x{ratio:.2f})" + ) + + # Build sys_id -> block_idx mapping + sys_to_block: dict[int, int] = {} + for blk_idx, (sys_ids, _) in enumerate(self._block_targets): + for sid in sys_ids: + sys_to_block[sid] = blk_idx + + # Compute expanded nloc counts analytically (no actual expansion) + expanded_nloc_info = {} + for nloc, indices in sorted(reader.nloc_groups.items()): + if reader.frame_system_ids is None: + expanded_nloc_info[nloc] = len(indices) + continue + # Count indices per block in this nloc group + blk_counts: dict[int, int] = {} + unassigned = 0 + for idx in indices: + sid = reader.frame_system_ids[idx] + blk = sys_to_block.get(sid) + if blk is not None: + blk_counts[blk] = blk_counts.get(blk, 0) + 1 + else: + unassigned += 1 + expanded = unassigned + for blk_idx, (_, blk_target) in enumerate(self._block_targets): + n_actual = blk_counts.get(blk_idx, 0) + if n_actual == 0: + continue + bta = block_total_actual[blk_idx] + if bta > 0: + t = max(round(blk_target * n_actual / bta), n_actual) + else: + t = n_actual + expanded += t + expanded_nloc_info[nloc] = expanded + + total_expanded = sum(expanded_nloc_info.values()) + n_groups = len(reader.nloc_groups) + ratio_all = total_expanded / total_original if total_original > 0 else 0 + + log.info( + f"LMDB {name} auto_prob: " + f"{total_original}->{total_expanded} frames (x{ratio_all:.2f}), " + f"{n_groups} nloc groups, {len(self._block_targets)} blocks:" + ) + for bl in block_lines: + log.info(f" {bl}") + + def set_noise(self, noise_settings: dict[str, Any]) -> None: + self._reader.set_noise(noise_settings) + + @property + def index(self) -> list[int]: + return self._reader.index + + @property + def total_batch(self) -> int: + return self._reader.total_batch + + @property + def batch_sizes(self) -> list[int]: + return self._reader.batch_sizes + + # --- 
PyTorch-specific trainer compatibility --- + + @property + def systems(self) -> list: + """One 'system' per nloc group for stat collection compatibility.""" + return [self] * len(self._nloc_dataloaders) + + @property + def dataloaders(self) -> list: + """Per-nloc-group dataloaders for make_stat_input. + + Each dataloader yields batches with uniform nloc, so torch.cat + in stat collection only concatenates same-shape tensors. + """ + return self._nloc_dataloaders + + @property + def sampler_list(self) -> list: + return [self._batch_sampler] diff --git a/deepmd/utils/data_system.py b/deepmd/utils/data_system.py index dd60d9a7e0..05e9ae60dc 100644 --- a/deepmd/utils/data_system.py +++ b/deepmd/utils/data_system.py @@ -791,6 +791,7 @@ def process_systems( If it is a single directory, search for all the systems in the directory. If it is a list, each item in the list is treated as a directory to search. + If it is a single LMDB path, return it directly without expansion. Check if the systems are valid. Parameters @@ -805,6 +806,14 @@ def process_systems( result_systems: list of str The valid systems """ + from deepmd.dpmodel.utils.lmdb_data import ( + is_lmdb, + ) + + # LMDB path: return directly without expansion + if isinstance(systems, str) and is_lmdb(systems): + return [systems] + # Normalize input to a list of paths to search if isinstance(systems, str): search_paths = [systems] diff --git a/examples/lmdb_downsample_data/README.md b/examples/lmdb_downsample_data/README.md new file mode 100644 index 0000000000..e38013ee19 --- /dev/null +++ b/examples/lmdb_downsample_data/README.md @@ -0,0 +1,18 @@ +# LMDB Example Data (Downsampled) + +**WARNING: This data is heavily downsampled and intended ONLY for testing +the LMDB data loading pipeline. 
Do NOT use it for accuracy benchmarks or +comparisons with the standard npy data format.** + +## Contents + +- `water_training.lmdb` - 80 frames downsampled from `water/data/data_0` +- `water_validation.lmdb` - 20 frames downsampled from `water/data/data_2` +- `input_lmdb.json` - Example training config using LMDB data + +## Usage + +```bash +cd examples/lmdb_downsample_data +dp --pt train input_lmdb.json +``` diff --git a/examples/lmdb_downsample_data/input_lmdb.json b/examples/lmdb_downsample_data/input_lmdb.json new file mode 100644 index 0000000000..6e7f469415 --- /dev/null +++ b/examples/lmdb_downsample_data/input_lmdb.json @@ -0,0 +1,65 @@ +{ + "model": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "se_e2_a", + "sel": [ + 46, + 92 + ], + "rcut_smth": 0.5, + "rcut": 6.0, + "neuron": [ + 25, + 50, + 100 + ], + "resnet_dt": false, + "axis_neuron": 16, + "seed": 1 + }, + "fitting_net": { + "type": "ener", + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1 + } + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-08 + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0 + }, + "training": { + "training_data": { + "systems": "water_training.lmdb", + "batch_size": "auto" + }, + "validation_data": { + "systems": "water_validation.lmdb", + "batch_size": 1 + }, + "numb_steps": 100, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 10, + "save_freq": 100 + } +} diff --git a/examples/lmdb_downsample_data/water_training.lmdb/data.mdb b/examples/lmdb_downsample_data/water_training.lmdb/data.mdb new file mode 100644 index 0000000000..2b3ecd3bde Binary files /dev/null and b/examples/lmdb_downsample_data/water_training.lmdb/data.mdb differ diff --git a/examples/lmdb_downsample_data/water_training.lmdb/lock.mdb b/examples/lmdb_downsample_data/water_training.lmdb/lock.mdb 
new file mode 100644 index 0000000000..37d3108c8b Binary files /dev/null and b/examples/lmdb_downsample_data/water_training.lmdb/lock.mdb differ diff --git a/examples/lmdb_downsample_data/water_validation.lmdb/data.mdb b/examples/lmdb_downsample_data/water_validation.lmdb/data.mdb new file mode 100644 index 0000000000..abc9f9cd17 Binary files /dev/null and b/examples/lmdb_downsample_data/water_validation.lmdb/data.mdb differ diff --git a/examples/lmdb_downsample_data/water_validation.lmdb/lock.mdb b/examples/lmdb_downsample_data/water_validation.lmdb/lock.mdb new file mode 100644 index 0000000000..adbfe89e9b Binary files /dev/null and b/examples/lmdb_downsample_data/water_validation.lmdb/lock.mdb differ diff --git a/pyproject.toml b/pyproject.toml index 2f8f86a713..848f539722 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,8 @@ dependencies = [ 'ml_dtypes', 'mendeleev', 'array-api-compat', + 'lmdb', + 'msgpack', ] requires-python = ">=3.10" keywords = ["deepmd"] diff --git a/source/tests/common/dpmodel/test_lmdb_data.py b/source/tests/common/dpmodel/test_lmdb_data.py new file mode 100644 index 0000000000..ac096633c2 --- /dev/null +++ b/source/tests/common/dpmodel/test_lmdb_data.py @@ -0,0 +1,866 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Unit tests for LmdbDataReader, LmdbTestData, SameNlocBatchSampler, etc. + +Pure dpmodel (NumPy/lmdb) tests — no PyTorch dependency. +""" + +import tempfile +import unittest + +import lmdb +import msgpack +import numpy as np + +from deepmd.dpmodel.utils.lmdb_data import ( + LmdbDataReader, + LmdbTestData, + SameNlocBatchSampler, + _expand_indices_by_blocks, + compute_block_targets, + is_lmdb, + make_neighbor_stat_data, +) + +# ============================================================ +# LMDB creation helpers +# ============================================================ + + +def _make_frame(natoms: int = 6, seed: int = 0) -> dict: + """Create a synthetic frame dict for testing. 
+ + Generates atom_types with roughly 1/3 type-0 and 2/3 type-1. + """ + rng = np.random.RandomState(seed) + n_type0 = max(1, natoms // 3) + n_type1 = natoms - n_type0 + atype = np.array([0] * n_type0 + [1] * n_type1, dtype=np.int64) + return { + "atom_names": ["O", "H"], + "atom_numbs": [ + { + "type": " str: + """Create a test LMDB database with uniform nloc.""" + n_type0 = max(1, natoms // 3) + n_type1 = natoms - n_type0 + env = lmdb.open(path, map_size=10 * 1024 * 1024) + with env.begin(write=True) as txn: + meta = { + "nframes": nframes, + "frame_idx_fmt": "012d", + "system_info": { + "natoms": [n_type0, n_type1], + "formula": "test", + }, + } + txn.put(b"__metadata__", msgpack.packb(meta, use_bin_type=True)) + for i in range(nframes): + key = format(i, "012d").encode() + frame = _make_frame(natoms=natoms, seed=i) + txn.put(key, msgpack.packb(frame, use_bin_type=True)) + env.close() + return path + + +def _create_mixed_nloc_lmdb(path: str) -> str: + """Create an LMDB with frames of different atom counts. + + Frames 0-3: 6 atoms, Frames 4-7: 9 atoms, Frames 8-9: 12 atoms. 
+ """ + frames_spec = [(6, 4), (9, 4), (12, 2)] # (natoms, count) + total = sum(c for _, c in frames_spec) + env = lmdb.open(path, map_size=10 * 1024 * 1024) + with env.begin(write=True) as txn: + meta = { + "nframes": total, + "frame_idx_fmt": "012d", + "system_info": { + "natoms": [2, 4], # first frame's type counts + "formula": "mixed", + }, + } + txn.put(b"__metadata__", msgpack.packb(meta, use_bin_type=True)) + idx = 0 + for natoms, count in frames_spec: + for j in range(count): + txn.put( + format(idx, "012d").encode(), + msgpack.packb( + _make_frame(natoms=natoms, seed=idx), use_bin_type=True + ), + ) + idx += 1 + env.close() + return path + + +def _create_lmdb_with_type_map( + path: str, + nframes: int = 6, + natoms: int = 6, + lmdb_type_map: list[str] | None = None, +) -> str: + """Create a test LMDB with type_map stored in metadata.""" + n_type0 = max(1, natoms // 3) + n_type1 = natoms - n_type0 + env = lmdb.open(path, map_size=10 * 1024 * 1024) + with env.begin(write=True) as txn: + meta = { + "nframes": nframes, + "frame_idx_fmt": "012d", + "system_info": { + "natoms": [n_type0, n_type1], + }, + } + if lmdb_type_map is not None: + meta["type_map"] = lmdb_type_map + txn.put(b"__metadata__", msgpack.packb(meta, use_bin_type=True)) + for i in range(nframes): + key = format(i, "012d").encode() + frame = _make_frame(natoms=natoms, seed=i) + txn.put(key, msgpack.packb(frame, use_bin_type=True)) + env.close() + return path + + +def _create_lmdb_with_system_ids( + path: str, + system_frames: list[int], + natoms: int = 6, + type_map: list[str] | None = None, +) -> str: + """Create a test LMDB with frame_system_ids in metadata.""" + total = sum(system_frames) + n_type0 = max(1, natoms // 3) + n_type1 = natoms - n_type0 + frame_system_ids = [] + for sid, nf in enumerate(system_frames): + frame_system_ids.extend([sid] * nf) + + env = lmdb.open(path, map_size=50 * 1024 * 1024) + with env.begin(write=True) as txn: + meta = { + "nframes": total, + "frame_idx_fmt": 
"012d", + "system_info": {"natoms": [n_type0, n_type1]}, + "frame_system_ids": frame_system_ids, + "frame_nlocs": [natoms] * total, + } + if type_map is not None: + meta["type_map"] = type_map + txn.put(b"__metadata__", msgpack.packb(meta, use_bin_type=True)) + for i in range(total): + key = format(i, "012d").encode() + frame = _make_frame(natoms=natoms, seed=i % 100) + txn.put(key, msgpack.packb(frame, use_bin_type=True)) + env.close() + return path + + +def _create_grid_lmdb(path: str, nframes: int = 3) -> str: + """Create a test LMDB with 3x3x3 grid of atoms (27 atoms, cell=3A). + + Same geometry as test_neighbor_stat.py: positions at integer coords + (0,1,2)^3, so min_nbor_dist = 1.0. + """ + X, Y, Z = np.mgrid[0:2:3j, 0:2:3j, 0:2:3j] + positions = np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T # (27, 3) + natoms = 27 + cell = np.array([3.0, 0, 0, 0, 3.0, 0, 0, 0, 3.0], dtype=np.float64) + atype = np.zeros(natoms, dtype=np.int64) + + env = lmdb.open(path, map_size=10 * 1024 * 1024) + with env.begin(write=True) as txn: + meta = { + "nframes": nframes, + "frame_idx_fmt": "012d", + "type_map": ["TYPE"], + "system_info": {"natoms": [natoms], "formula": "grid"}, + } + txn.put(b"__metadata__", msgpack.packb(meta, use_bin_type=True)) + for i in range(nframes): + frame = { + "atom_types": { + "type": " 1 + for i in range(n0, 6): + self.assertEqual(atype[i], 0) # H -> 0 + + def test_reader_remap_superset(self): + reader = LmdbDataReader(self._lmdb_path, ["C", "O", "H"]) + np.testing.assert_array_equal(reader._type_remap, [1, 2]) + + def test_reader_natoms_vec_after_remap(self): + reader = LmdbDataReader(self._lmdb_path, ["H", "O"]) + natoms = reader[0]["natoms"] + self.assertEqual(natoms[0], 6) + self.assertEqual(natoms[2], 4) # H count + self.assertEqual(natoms[3], 2) # O count + + def test_reader_missing_element_raises(self): + with self.assertRaises(ValueError): + LmdbDataReader(self._lmdb_path, ["O"]) + + def test_reader_no_type_map_in_metadata(self): + tmpdir = 
tempfile.TemporaryDirectory() + path = _create_lmdb_with_type_map( + f"{tmpdir.name}/old.lmdb", nframes=3, natoms=6, lmdb_type_map=None + ) + reader = LmdbDataReader(path, ["H", "O"]) + self.assertIsNone(reader._type_remap) + tmpdir.cleanup() + + def test_testdata_no_remap_when_match(self): + td = LmdbTestData(self._lmdb_path, type_map=["O", "H"], shuffle_test=False) + self.assertIsNone(td._type_remap) + + def test_testdata_remap_when_reversed(self): + td = LmdbTestData(self._lmdb_path, type_map=["H", "O"], shuffle_test=False) + self.assertIsNotNone(td._type_remap) + data = td.get_test() + n0 = max(1, 6 // 3) + for i in range(n0): + self.assertEqual(data["type"][0, i], 1) + for i in range(n0, 6): + self.assertEqual(data["type"][0, i], 0) + + def test_testdata_remap_superset(self): + td = LmdbTestData(self._lmdb_path, type_map=["C", "O", "H"], shuffle_test=False) + self.assertIsNotNone(td._type_remap) + + def test_testdata_missing_element_raises(self): + with self.assertRaises(ValueError): + LmdbTestData(self._lmdb_path, type_map=["O"], shuffle_test=False) + + def test_testdata_no_type_map_in_metadata(self): + tmpdir = tempfile.TemporaryDirectory() + path = _create_lmdb_with_type_map( + f"{tmpdir.name}/old.lmdb", nframes=3, natoms=6, lmdb_type_map=None + ) + td = LmdbTestData(path, type_map=["H", "O"], shuffle_test=False) + self.assertIsNone(td._type_remap) + tmpdir.cleanup() + + +# ============================================================ +# auto_prob / frame_system_ids tests +# ============================================================ + + +class TestAutoProb(unittest.TestCase): + """Test auto_prob support: frame_system_ids, compute_block_targets, + _expand_indices_by_blocks, and SameNlocBatchSampler with block_targets. 
+ """ + + @classmethod + def setUpClass(cls): + cls._tmpdir = tempfile.TemporaryDirectory() + cls._lmdb_path = _create_lmdb_with_system_ids( + f"{cls._tmpdir.name}/auto_prob.lmdb", + system_frames=[100, 200, 300], + natoms=6, + type_map=["O", "H"], + ) + + @classmethod + def tearDownClass(cls): + cls._tmpdir.cleanup() + + def test_reader_system_groups(self): + reader = LmdbDataReader(self._lmdb_path, ["O", "H"]) + self.assertEqual(reader.nsystems, 3) + self.assertEqual(reader.system_nframes, [100, 200, 300]) + self.assertEqual(len(reader.system_groups[0]), 100) + self.assertEqual(len(reader.system_groups[1]), 200) + self.assertEqual(len(reader.system_groups[2]), 300) + + def test_reader_no_system_ids_backward_compat(self): + tmpdir = tempfile.TemporaryDirectory() + path = _create_lmdb(f"{tmpdir.name}/old.lmdb", nframes=10, natoms=6) + reader = LmdbDataReader(path, ["O", "H"]) + self.assertEqual(reader.nsystems, 1) + self.assertIsNone(reader.frame_system_ids) + tmpdir.cleanup() + + def test_compute_block_targets_equal_weight(self): + result = compute_block_targets( + "prob_sys_size;0:1:0.5;1:2:0.5", nsystems=2, system_nframes=[100, 100] + ) + self.assertEqual(result, []) + + def test_compute_block_targets_unequal(self): + result = compute_block_targets( + "prob_sys_size;0:1:0.5;1:2:0.5", nsystems=2, system_nframes=[100, 500] + ) + self.assertEqual(len(result), 2) + self.assertEqual(result[0], ([0], 500)) + self.assertEqual(result[1], ([1], 500)) + + def test_compute_block_targets_multi_sys_block(self): + result = compute_block_targets( + "prob_sys_size;0:2:0.5;2:3:0.5", + nsystems=3, + system_nframes=[100, 200, 300], + ) + self.assertEqual(result, []) + + def test_compute_block_targets_asymmetric(self): + result = compute_block_targets( + "prob_sys_size;0:2:0.5;2:3:0.5", + nsystems=3, + system_nframes=[50, 50, 400], + ) + self.assertEqual(len(result), 2) + self.assertEqual(result[0][0], [0, 1]) + self.assertEqual(result[0][1], 400) + + def 
test_expand_indices_basic(self): + frame_system_ids = [0] * 5 + [1] * 5 + block_targets = [([0], 25), ([1], 25)] + rng = np.random.default_rng(42) + expanded = _expand_indices_by_blocks( + list(range(10)), frame_system_ids, block_targets, rng + ) + sys0 = [i for i in expanded if frame_system_ids[i] == 0] + sys1 = [i for i in expanded if frame_system_ids[i] == 1] + self.assertEqual(len(sys0), 25) + self.assertEqual(len(sys1), 25) + + def test_expand_indices_no_expansion(self): + frame_system_ids = [0] * 5 + [1] * 5 + block_targets = [([0], 5), ([1], 5)] + rng = np.random.default_rng(42) + expanded = _expand_indices_by_blocks( + list(range(10)), frame_system_ids, block_targets, rng + ) + self.assertEqual(sorted(expanded), list(range(10))) + + def test_expand_indices_remainder_sampling(self): + from collections import ( + Counter, + ) + + frame_system_ids = [0] * 10 + block_targets = [([0], 23)] + rng = np.random.default_rng(42) + expanded = _expand_indices_by_blocks( + list(range(10)), frame_system_ids, block_targets, rng + ) + self.assertEqual(len(expanded), 23) + counts = Counter(expanded) + n_three = sum(1 for c in counts.values() if c == 3) + self.assertEqual(n_three, 3) + + def test_expand_epoch_diversity(self): + frame_system_ids = [0] * 10 + block_targets = [([0], 15)] + results = [] + for seed in range(5): + rng = np.random.default_rng(seed) + expanded = _expand_indices_by_blocks( + list(range(10)), frame_system_ids, block_targets, rng + ) + results.append(sorted(expanded[10:])) + unique = {tuple(r) for r in results} + self.assertGreater(len(unique), 1) + + def test_sampler_with_block_targets(self): + reader = LmdbDataReader(self._lmdb_path, ["O", "H"]) + block_targets = compute_block_targets( + "prob_sys_size;0:1:0.5;1:3:0.5", + nsystems=3, + system_nframes=[100, 200, 300], + ) + sampler = SameNlocBatchSampler( + reader, shuffle=True, block_targets=block_targets + ) + all_indices = [i for batch in sampler for i in batch] + 
self.assertGreater(len(all_indices), 600) + self.assertEqual(len(set(all_indices)), 600) + + def test_sampler_without_block_targets(self): + reader = LmdbDataReader(self._lmdb_path, ["O", "H"]) + sampler = SameNlocBatchSampler(reader, shuffle=False) + all_indices = [i for batch in sampler for i in batch] + self.assertEqual(sorted(all_indices), list(range(600))) + + +# ============================================================ +# Neighbor stat from LMDB tests +# ============================================================ + + +class TestLmdbNeighborStat(unittest.TestCase): + """Test make_neighbor_stat_data interface and sampling.""" + + @classmethod + def setUpClass(cls): + cls._tmpdir = tempfile.TemporaryDirectory() + cls._lmdb_path = _create_grid_lmdb(f"{cls._tmpdir.name}/grid.lmdb", nframes=3) + + @classmethod + def tearDownClass(cls): + cls._tmpdir.cleanup() + + def test_make_neighbor_stat_data_interface(self): + data = make_neighbor_stat_data(self._lmdb_path, ["TYPE", "NO_TYPE"]) + self.assertIsInstance(data.system_dirs, list) + self.assertGreater(len(data.system_dirs), 0) + self.assertEqual(data.get_ntypes(), 2) + data.get_batch() # no-op + sys0 = data.data_systems[0] + self.assertIsInstance(sys0.pbc, bool) + set_data = sys0._load_set(sys0.dirs[0]) + self.assertEqual(set_data["coord"].ndim, 2) + self.assertEqual(set_data["coord"].shape[1], sys0.get_natoms() * 3) + + def test_sampling_large_dataset(self): + tmpdir = tempfile.TemporaryDirectory() + path = _create_grid_lmdb(f"{tmpdir.name}/large.lmdb", nframes=50) + data = make_neighbor_stat_data(path, ["TYPE"], max_frames=10) + total = sum(s._load_set(s.dirs[0])["coord"].shape[0] for s in data.data_systems) + self.assertEqual(total, 10) + tmpdir.cleanup() + + +def _create_lmdb_with_extra_keys( + path: str, nframes: int = 5, natoms: int = 6, extra_keys: dict | None = None +) -> str: + """Create a test LMDB with extra per-frame keys (e.g. atom_pref, fparam). 
+ + Parameters + ---------- + extra_keys : dict + key -> (shape_fn, dtype) where shape_fn(natoms) returns the array shape. + Example: {"atom_pref": (lambda n: (n,), np.float64)} + """ + n_type0 = max(1, natoms // 3) + n_type1 = natoms - n_type0 + extra_keys = extra_keys or {} + env = lmdb.open(path, map_size=10 * 1024 * 1024) + with env.begin(write=True) as txn: + meta = { + "nframes": nframes, + "frame_idx_fmt": "012d", + "type_map": ["O", "H"], + "system_info": {"natoms": [n_type0, n_type1]}, + } + txn.put(b"__metadata__", msgpack.packb(meta, use_bin_type=True)) + rng = np.random.RandomState(0) + for i in range(nframes): + frame = _make_frame(natoms=natoms, seed=i) + for ek, (shape_fn, dtype) in extra_keys.items(): + arr = rng.rand(*shape_fn(natoms)).astype(dtype) + frame[ek] = { + "type": str(arr.dtype), + "shape": list(arr.shape), + "data": arr.tobytes(), + } + txn.put( + format(i, "012d").encode(), + msgpack.packb(frame, use_bin_type=True), + ) + env.close() + return path + + +# ============================================================ +# Dynamic find_* and repeat tests +# ============================================================ + + +class TestDynamicKeysAndRepeat(unittest.TestCase): + """Test auto-discovery of find_* flags and repeat handling.""" + + @classmethod + def setUpClass(cls): + cls._tmpdir = tempfile.TemporaryDirectory() + cls._natoms = 6 + cls._nframes = 5 + cls._lmdb_path = _create_lmdb_with_extra_keys( + f"{cls._tmpdir.name}/extra.lmdb", + nframes=cls._nframes, + natoms=cls._natoms, + extra_keys={ + "atom_pref": (lambda n: (n,), np.float64), + "fparam": (lambda n: (3,), np.float64), + }, + ) + cls._type_map = ["O", "H"] + + @classmethod + def tearDownClass(cls): + cls._tmpdir.cleanup() + + # --- LmdbDataReader --- + + def test_reader_find_flags_auto_detected(self): + """Extra keys in frame get find_*=1.0 automatically.""" + reader = LmdbDataReader(self._lmdb_path, self._type_map) + frame = reader[0] + 
self.assertEqual(frame["find_atom_pref"], np.float32(1.0)) + self.assertEqual(frame["find_fparam"], np.float32(1.0)) + self.assertEqual(frame["find_energy"], np.float32(1.0)) + # Keys not in frame get find_*=0.0 + self.assertEqual(frame["find_aparam"], np.float32(0.0)) + self.assertEqual(frame["find_spin"], np.float32(0.0)) + + def test_reader_repeat_applied(self): + """DataRequirementItem with repeat=3 expands atom_pref from (natoms,) to (natoms*3,).""" + from deepmd.utils.data import ( + DataRequirementItem, + ) + + reader = LmdbDataReader(self._lmdb_path, self._type_map) + reader.add_data_requirement( + [ + DataRequirementItem( + "atom_pref", + ndof=1, + atomic=True, + must=False, + high_prec=False, + repeat=3, + ), + ] + ) + frame = reader[0] + self.assertEqual(frame["atom_pref"].shape, (self._natoms * 3,)) + + def test_reader_repeat_default_fill(self): + """Missing key with repeat fills correct shape.""" + from deepmd.utils.data import ( + DataRequirementItem, + ) + + reader = LmdbDataReader(self._lmdb_path, self._type_map) + reader.add_data_requirement( + [ + DataRequirementItem( + "drdq", ndof=6, atomic=True, must=False, high_prec=False, repeat=2 + ), + ] + ) + frame = reader[0] + self.assertEqual(frame["find_drdq"], np.float32(0.0)) + self.assertEqual(frame["drdq"].shape, (self._natoms * 6 * 2,)) + + # --- LmdbTestData --- + + def test_testdata_find_flags_auto_detected(self): + """LmdbTestData.get_test() discovers extra keys dynamically.""" + td = LmdbTestData(self._lmdb_path, type_map=self._type_map, shuffle_test=False) + result = td.get_test() + self.assertEqual(result["find_atom_pref"], 1.0) + self.assertEqual(result["find_fparam"], 1.0) + self.assertIn("atom_pref", result) + self.assertIn("fparam", result) + + def test_testdata_repeat_applied(self): + """LmdbTestData respects repeat=3 for atom_pref.""" + td = LmdbTestData(self._lmdb_path, type_map=self._type_map, shuffle_test=False) + td.add("atom_pref", 1, atomic=True, must=False, high_prec=False, 
repeat=3) + result = td.get_test() + self.assertEqual( + result["atom_pref"].shape, + (self._nframes, self._natoms * 3), + ) + + def test_testdata_missing_key_not_found(self): + """Keys absent from LMDB frames get find_*=0.0 in get_test().""" + tmpdir = tempfile.TemporaryDirectory() + path = _create_lmdb(f"{tmpdir.name}/plain.lmdb", nframes=3, natoms=6) + td = LmdbTestData(path, type_map=["O", "H"], shuffle_test=False) + result = td.get_test() + # atom_pref is not in the plain LMDB + self.assertEqual(result.get("find_atom_pref", 0.0), 0.0) + tmpdir.cleanup() + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/common/test_examples.py b/source/tests/common/test_examples.py index d463e9cd2b..5f7cd323ee 100644 --- a/source/tests/common/test_examples.py +++ b/source/tests/common/test_examples.py @@ -64,6 +64,7 @@ p_examples / "water" / "se_e3_tebd" / "input_torch.json", p_examples / "hessian" / "single_task" / "input.json", p_examples / "water" / "se_e2_a" / "input_torch_num_epoch.json", + p_examples / "lmdb_downsample_data" / "input_lmdb.json", ) input_files_multi = ( diff --git a/source/tests/consistent/test_lmdb_data.py b/source/tests/consistent/test_lmdb_data.py new file mode 100644 index 0000000000..6e9cecee52 --- /dev/null +++ b/source/tests/consistent/test_lmdb_data.py @@ -0,0 +1,227 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Consistency tests: LmdbDataReader (dpmodel) vs LmdbDataset (pt). + +Verifies that the framework-agnostic reader and the PyTorch wrapper +produce identical outputs for the same LMDB data. 
+""" + +import tempfile +import unittest + +import lmdb +import msgpack +import numpy as np + +from deepmd.dpmodel.utils.lmdb_data import ( + LmdbDataReader, +) + +try: + from deepmd.pt.utils.lmdb_dataset import ( + LmdbDataset, + ) + + INSTALLED_PT = True +except ImportError: + INSTALLED_PT = False + + +def _make_frame(natoms: int = 6, seed: int = 0) -> dict: + """Create a synthetic frame dict for testing.""" + rng = np.random.RandomState(seed) + n_type0 = max(1, natoms // 3) + n_type1 = natoms - n_type0 + atype = np.array([0] * n_type0 + [1] * n_type1, dtype=np.int64) + return { + "atom_names": ["O", "H"], + "atom_numbs": [ + { + "type": " str: + """Create a test LMDB database with uniform nloc.""" + n_type0 = max(1, natoms // 3) + n_type1 = natoms - n_type0 + env = lmdb.open(path, map_size=10 * 1024 * 1024) + with env.begin(write=True) as txn: + meta = { + "nframes": nframes, + "frame_idx_fmt": "012d", + "system_info": { + "natoms": [n_type0, n_type1], + "formula": "test", + }, + } + txn.put(b"__metadata__", msgpack.packb(meta, use_bin_type=True)) + for i in range(nframes): + key = format(i, "012d").encode() + frame = _make_frame(natoms=natoms, seed=i) + txn.put(key, msgpack.packb(frame, use_bin_type=True)) + env.close() + return path + + +def _create_mixed_nloc_lmdb(path: str) -> str: + """Create an LMDB with frames of different atom counts.""" + frames_spec = [(6, 4), (9, 4), (12, 2)] + total = sum(c for _, c in frames_spec) + env = lmdb.open(path, map_size=10 * 1024 * 1024) + with env.begin(write=True) as txn: + meta = { + "nframes": total, + "frame_idx_fmt": "012d", + "system_info": {"natoms": [2, 4], "formula": "mixed"}, + } + txn.put(b"__metadata__", msgpack.packb(meta, use_bin_type=True)) + idx = 0 + for natoms, count in frames_spec: + for j in range(count): + txn.put( + format(idx, "012d").encode(), + msgpack.packb( + _make_frame(natoms=natoms, seed=idx), use_bin_type=True + ), + ) + idx += 1 + env.close() + return path + + +def 
_assert_frames_equal(test_case, frame_dp, frame_pt, frame_idx): + """Assert two frames (from reader and dataset) are identical.""" + test_case.assertEqual( + set(frame_dp.keys()), + set(frame_pt.keys()), + msg=f"frame={frame_idx}", + ) + for key in frame_dp: + dp_val = frame_dp[key] + pt_val = frame_pt[key] + if isinstance(dp_val, np.ndarray): + np.testing.assert_array_equal( + dp_val, pt_val, err_msg=f"key={key}, frame={frame_idx}" + ) + else: + test_case.assertEqual(dp_val, pt_val, msg=f"key={key}, frame={frame_idx}") + + +@unittest.skipUnless(INSTALLED_PT, "PyTorch not available") +class TestLmdbDataConsistency(unittest.TestCase): + """Verify LmdbDataReader (dpmodel) and LmdbDataset (pt) produce identical outputs.""" + + @classmethod + def setUpClass(cls): + cls._tmpdir = tempfile.TemporaryDirectory() + cls._lmdb_path = _create_lmdb( + f"{cls._tmpdir.name}/test.lmdb", nframes=10, natoms=6 + ) + cls._type_map = ["O", "H"] + cls._reader = LmdbDataReader(cls._lmdb_path, cls._type_map, batch_size=2) + cls._ds = LmdbDataset(cls._lmdb_path, cls._type_map, batch_size=2) + + @classmethod + def tearDownClass(cls): + del cls._ds, cls._reader + cls._tmpdir.cleanup() + + def test_same_len(self): + self.assertEqual(len(self._reader), len(self._ds)) + + def test_same_frame_data(self): + for i in range(len(self._reader)): + _assert_frames_equal(self, self._reader[i], self._ds[i], i) + + def test_same_batch_size(self): + reader = LmdbDataReader(self._lmdb_path, self._type_map, batch_size="auto") + ds = LmdbDataset(self._lmdb_path, self._type_map, batch_size="auto") + self.assertEqual(reader.batch_size, ds.batch_size) + + def test_same_properties(self): + self.assertEqual(self._reader.index, self._ds.index) + self.assertEqual(self._reader.total_batch, self._ds.total_batch) + self.assertEqual(self._reader.batch_sizes, self._ds.batch_sizes) + self.assertEqual(self._reader.nframes, self._ds.nframes) + self.assertEqual(self._reader.mixed_type, self._ds.mixed_type) + 
self.assertEqual(self._reader.mixed_batch, self._ds.mixed_batch) + + def test_data_requirement(self): + req = [ + { + "key": "virial", + "ndof": 9, + "atomic": False, + "must": False, + "high_prec": False, + "repeat": 1, + "default": 0.0, + } + ] + reader = LmdbDataReader(self._lmdb_path, self._type_map, batch_size=2) + ds = LmdbDataset(self._lmdb_path, self._type_map, batch_size=2) + reader.add_data_requirement(req) + ds.add_data_requirement(req) + frame_dp = reader[0] + frame_pt = ds[0] + np.testing.assert_array_equal(frame_dp["virial"], frame_pt["virial"]) + self.assertEqual(frame_dp["find_virial"], frame_pt["find_virial"]) + + def test_mixed_nloc_same_frame_data(self): + """Reader and dataset produce identical frames for mixed atom counts.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = _create_mixed_nloc_lmdb(f"{tmpdir}/mixed.lmdb") + reader = LmdbDataReader(path, self._type_map, batch_size=2) + ds = LmdbDataset(path, self._type_map, batch_size=2) + self.assertEqual(len(reader), len(ds)) + for i in range(len(reader)): + _assert_frames_equal(self, reader[i], ds[i], i) + + def test_mixed_nloc_same_properties(self): + """Reader and dataset agree on properties for mixed-nloc LMDB.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = _create_mixed_nloc_lmdb(f"{tmpdir}/mixed.lmdb") + reader = LmdbDataReader(path, self._type_map, batch_size=2) + ds = LmdbDataset(path, self._type_map, batch_size=2) + self.assertEqual(reader.nframes, ds.nframes) + self.assertEqual(reader.batch_sizes, ds.batch_sizes) + self.assertEqual(reader.mixed_batch, ds.mixed_batch) + self.assertFalse(reader.mixed_batch) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_lmdb_dataloader.py b/source/tests/pt/test_lmdb_dataloader.py new file mode 100644 index 0000000000..ebb505706d --- /dev/null +++ b/source/tests/pt/test_lmdb_dataloader.py @@ -0,0 +1,847 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Unit tests for LmdbDataset (PyTorch 
wrapper) and related PT-specific features. + +Pure dpmodel tests (LmdbDataReader, LmdbTestData, SameNlocBatchSampler, type_map +remapping, auto_prob) live in source/tests/common/dpmodel/test_lmdb_data.py. +Consistency tests (dpmodel vs pt) live in source/tests/consistent/test_lmdb_data.py. +""" + +import lmdb +import msgpack +import numpy as np +import pytest +import torch + +from deepmd.dpmodel.utils.lmdb_data import ( + DistributedSameNlocBatchSampler, + LmdbDataReader, + SameNlocBatchSampler, + _decode_frame, + _read_metadata, + _remap_keys, + merge_lmdb, +) +from deepmd.pt.utils.lmdb_dataset import ( + LmdbDataset, + _collate_lmdb_batch, +) +from deepmd.utils.data import ( + DataRequirementItem, +) + + +def _make_frame(natoms: int = 6, seed: int = 0) -> dict: + """Create a synthetic frame dict as stored in LMDB.""" + rng = np.random.RandomState(seed) + + def _encode_array(arr: np.ndarray) -> dict: + return { + "nd": None, + "type": str(arr.dtype), + "kind": "", + "shape": list(arr.shape), + "data": arr.tobytes(), + } + + return { + "atom_numbs": [natoms // 2, natoms // 2], + "atom_names": ["O", "H"], + "atom_types": _encode_array( + np.array([0] * (natoms // 2) + [1] * (natoms // 2), dtype=np.int64) + ), + "orig": _encode_array(np.zeros(3, dtype=np.float64)), + "cells": _encode_array((np.eye(3) * 10.0).astype(np.float64)), + "coords": _encode_array((rng.rand(natoms, 3) * 10.0).astype(np.float64)), + "energies": _encode_array(np.array(rng.randn(), dtype=np.float64)), + "forces": _encode_array(rng.randn(natoms, 3).astype(np.float64)), + } + + +def _create_test_lmdb(path: str, nframes: int = 10, natoms: int = 6) -> None: + """Create a minimal LMDB dataset for testing.""" + env = lmdb.open(path, map_size=10 * 1024 * 1024) + fmt = "012d" + metadata = { + "nframes": nframes, + "frame_idx_fmt": fmt, + "system_info": { + "formula": f"O{natoms // 2}H{natoms // 2}", + "natoms": [natoms // 2, natoms // 2], + "nframes": nframes, + }, + } + with env.begin(write=True) as 
txn: + txn.put(b"__metadata__", msgpack.packb(metadata, use_bin_type=True)) + for i in range(nframes): + key = format(i, fmt).encode() + frame = _make_frame(natoms=natoms, seed=i) + txn.put(key, msgpack.packb(frame, use_bin_type=True)) + env.close() + + +@pytest.fixture +def lmdb_dir(tmp_path): + """Create a temporary LMDB dataset.""" + lmdb_path = str(tmp_path / "test.lmdb") + _create_test_lmdb(lmdb_path, nframes=10, natoms=6) + return lmdb_path + + +# ============================================================ +# Internal helper functions +# ============================================================ + + +class TestHelpers: + """Test internal helper functions (dpmodel, but only tested here).""" + + def test_read_metadata(self, lmdb_dir): + env = lmdb.open(lmdb_dir, readonly=True, lock=False) + with env.begin() as txn: + meta = _read_metadata(txn) + assert meta["nframes"] == 10 + env.close() + + def test_read_metadata_missing(self, tmp_path): + empty_path = str(tmp_path / "empty.lmdb") + env = lmdb.open(empty_path, map_size=1024 * 1024) + env.close() + env = lmdb.open(empty_path, readonly=True, lock=False) + with env.begin() as txn: + with pytest.raises(ValueError, match="missing __metadata__"): + _read_metadata(txn) + env.close() + + def test_decode_frame(self, lmdb_dir): + env = lmdb.open(lmdb_dir, readonly=True, lock=False) + with env.begin() as txn: + raw = txn.get(format(0, "012d").encode()) + frame = _decode_frame(raw) + assert "coords" in frame + assert isinstance(frame["coords"], np.ndarray) + assert frame["coords"].shape == (6, 3) + env.close() + + def test_remap_keys(self): + frame = { + "coords": np.zeros((3, 3)), + "cells": np.zeros((3, 3)), + "energies": np.array(1.0), + "forces": np.zeros((3, 3)), + "atom_types": np.array([0, 1, 0]), + "custom_key": np.array([42.0]), + } + remapped = _remap_keys(frame) + assert "coord" in remapped + assert "box" in remapped + assert "energy" in remapped + assert "force" in remapped + assert "atype" in remapped + 
assert "custom_key" in remapped + assert "coords" not in remapped + + +# ============================================================ +# LmdbDataset (PT wrapper) +# ============================================================ + + +class TestLmdbDataset: + """Test LmdbDataset class.""" + + def test_len(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + assert len(ds) == 10 + + def test_getitem_keys(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + frame = ds[0] + for key in ("coord", "box", "energy", "force", "atype", "natoms", "fid"): + assert key in frame + assert frame["find_energy"] == 1.0 + assert frame["find_force"] == 1.0 + # Metadata keys removed + for key in ("atom_numbs", "atom_names", "orig"): + assert key not in frame + + def test_getitem_shapes(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + frame = ds[0] + assert frame["coord"].shape == (6, 3) + assert frame["box"].shape == (9,) + assert frame["energy"].shape == (1,) + assert frame["force"].shape == (6, 3) + assert frame["atype"].shape == (6,) + assert frame["natoms"].shape == (4,) + + def test_getitem_dtypes(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + frame = ds[0] + assert frame["coord"].dtype == np.float64 + assert frame["atype"].dtype == np.int64 + + def test_getitem_out_of_range(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + with pytest.raises(IndexError): + ds[999] + + def test_natoms_vec(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + natoms = ds[0]["natoms"] + assert natoms[0] == 6 + assert natoms[2] == 3 # O count + assert natoms[3] == 3 # H count + + def test_auto_batch_size(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size="auto") + assert ds.batch_size == 6 + + def test_auto_batch_size_with_rule(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", 
"H"], batch_size="auto:12") + assert ds.batch_size == 2 + + def test_int_batch_size(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=3) + assert ds.batch_size == 3 + + def test_mixed_type(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + assert ds.mixed_type is True + + +# ============================================================ +# Trainer compatibility interface +# ============================================================ + + +class TestTrainerInterface: + """Test Trainer compatibility interface.""" + + def test_systems(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + assert len(ds.systems) == 1 + assert ds.systems[0] is ds + + def test_dataloaders(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + assert len(ds.dataloaders) == 1 + + def test_index(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + assert ds.index == [5] + + def test_total_batch(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + assert ds.total_batch == 5 + + def test_batch_sizes(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + assert ds.batch_sizes == [2] + + def test_sampler_list(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + assert len(ds.sampler_list) == 1 + + def test_add_data_requirement(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + req = [DataRequirementItem("virial", 9, atomic=False, must=False, default=0.0)] + ds.add_data_requirement(req) + frame = ds[0] + assert frame["find_virial"] == 0.0 + assert frame["virial"].shape == (9,) + + def test_add_data_requirement_existing_key(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + req = [DataRequirementItem("energy", 1, atomic=False, must=True)] + ds.add_data_requirement(req) + assert ds[0]["find_energy"] == 1.0 
+ + def test_preload_noop(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + ds.preload_and_modify_all_data_torch() + + def test_set_noise_noop(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + ds.set_noise({}) + + +# ============================================================ +# DataLoader iteration +# ============================================================ + + +class TestDataLoaderIteration: + """Test DataLoader iteration with LmdbDataset.""" + + def test_batch_iteration(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + from torch.utils.data import ( + DataLoader, + ) + + with torch.device("cpu"): + dl = DataLoader( + ds, batch_size=2, shuffle=False, collate_fn=_collate_lmdb_batch + ) + batch = next(iter(dl)) + assert batch["coord"].shape == (2, 6, 3) + assert batch["energy"].shape == (2, 1) + assert batch["atype"].shape == (2, 6) + assert isinstance(batch["fid"], list) + assert batch["sid"] == 0 + + def test_inner_dataloader(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=2) + with torch.device("cpu"): + batch = next(iter(ds.dataloaders[0])) + assert batch["coord"].shape[0] == 2 + + def test_full_epoch(self, lmdb_dir): + ds = LmdbDataset(lmdb_dir, type_map=["O", "H"], batch_size=3) + from torch.utils.data import ( + DataLoader, + ) + + with torch.device("cpu"): + dl = DataLoader( + ds, batch_size=3, shuffle=False, collate_fn=_collate_lmdb_batch + ) + total_frames = sum(batch["coord"].shape[0] for batch in dl) + assert total_frames == 10 + + +# ============================================================ +# Collate function +# ============================================================ + + +class TestCollate: + """Test collate function.""" + + def test_collate_basic(self): + rng = np.random.default_rng(42) + frames = [ + { + "coord": rng.standard_normal((4, 3)), + "energy": np.array([1.0]), + "find_energy": 1.0, + "fid": 0, + }, + 
{ + "coord": rng.standard_normal((4, 3)), + "energy": np.array([2.0]), + "find_energy": 1.0, + "fid": 1, + }, + ] + batch = _collate_lmdb_batch(frames) + assert batch["coord"].shape == (2, 4, 3) + assert batch["fid"] == [0, 1] + assert batch["sid"] == 0 + + def test_collate_skips_type(self): + frames = [ + {"coord": np.zeros((2, 3)), "type": np.array([0, 1])}, + {"coord": np.zeros((2, 3)), "type": np.array([0, 1])}, + ] + assert "type" not in _collate_lmdb_batch(frames) + + def test_collate_none_values(self): + frames = [ + {"coord": np.zeros((2, 3)), "box": None}, + {"coord": np.zeros((2, 3)), "box": None}, + ] + assert _collate_lmdb_batch(frames)["box"] is None + + +# ============================================================ +# Type map remapping (PT-specific: LmdbDataset) +# ============================================================ + + +def _create_test_lmdb_with_type_map( + path: str, + nframes: int = 10, + natoms: int = 6, + lmdb_type_map: list[str] | None = None, +) -> None: + """Create a minimal LMDB dataset with type_map in metadata.""" + env = lmdb.open(path, map_size=10 * 1024 * 1024) + fmt = "012d" + metadata = { + "nframes": nframes, + "frame_idx_fmt": fmt, + "system_info": {"natoms": [natoms // 2, natoms // 2]}, + } + if lmdb_type_map is not None: + metadata["type_map"] = lmdb_type_map + with env.begin(write=True) as txn: + txn.put(b"__metadata__", msgpack.packb(metadata, use_bin_type=True)) + for i in range(nframes): + txn.put( + format(i, fmt).encode(), + msgpack.packb(_make_frame(natoms=natoms, seed=i), use_bin_type=True), + ) + env.close() + + +@pytest.fixture +def lmdb_with_type_map(tmp_path): + lmdb_path = str(tmp_path / "typed.lmdb") + _create_test_lmdb_with_type_map( + lmdb_path, nframes=10, natoms=6, lmdb_type_map=["O", "H"] + ) + return lmdb_path + + +class TestTypeMapRemappingDataset: + """Test type_map remapping in LmdbDataset (PT-specific).""" + + def test_dataset_remap_reversed(self, lmdb_with_type_map): + ds = 
LmdbDataset(lmdb_with_type_map, type_map=["H", "O"], batch_size=2) + frame = ds[0] + np.testing.assert_array_equal(frame["atype"][:3], [1, 1, 1]) + np.testing.assert_array_equal(frame["atype"][3:], [0, 0, 0]) + + def test_dataset_remap_batch(self, lmdb_with_type_map): + ds = LmdbDataset(lmdb_with_type_map, type_map=["H", "O"], batch_size=2) + with torch.device("cpu"): + batch = next(iter(ds.dataloaders[0])) + for i in range(batch["atype"].shape[0]): + np.testing.assert_array_equal(batch["atype"][i, :3].numpy(), [1, 1, 1]) + np.testing.assert_array_equal(batch["atype"][i, 3:].numpy(), [0, 0, 0]) + + def test_dataset_no_remap_when_match(self, lmdb_with_type_map): + ds = LmdbDataset(lmdb_with_type_map, type_map=["O", "H"], batch_size=2) + np.testing.assert_array_equal(ds[0]["atype"][:3], [0, 0, 0]) + + +# ============================================================ +# Distributed sampler +# ============================================================ + + +def _create_multi_nloc_lmdb(path: str) -> None: + """Create an LMDB with frames of varying nloc for distributed tests.""" + env = lmdb.open(path, map_size=10 * 1024 * 1024) + fmt = "012d" + nframes = 30 + frame_nlocs = [] + with env.begin(write=True) as txn: + idx = 0 + for natoms in [4, 6, 8]: + for i in range(10): + txn.put( + format(idx, fmt).encode(), + msgpack.packb( + _make_frame(natoms=natoms, seed=idx * 100), use_bin_type=True + ), + ) + frame_nlocs.append(natoms) + idx += 1 + txn.put( + b"__metadata__", + msgpack.packb( + {"nframes": nframes, "frame_idx_fmt": fmt, "frame_nlocs": frame_nlocs}, + use_bin_type=True, + ), + ) + env.close() + + +@pytest.fixture +def multi_nloc_lmdb(tmp_path): + lmdb_path = str(tmp_path / "multi_nloc.lmdb") + _create_multi_nloc_lmdb(lmdb_path) + return lmdb_path + + +class TestDistributedSameNlocBatchSampler: + """Test DistributedSameNlocBatchSampler (pure logic, no torch.distributed).""" + + def test_disjoint_batches(self, multi_nloc_lmdb): + reader = 
LmdbDataReader(multi_nloc_lmdb, type_map=["O", "H"], batch_size=2) + s0 = DistributedSameNlocBatchSampler( + reader, rank=0, world_size=2, shuffle=True, seed=42 + ) + s1 = DistributedSameNlocBatchSampler( + reader, rank=1, world_size=2, shuffle=True, seed=42 + ) + frames0 = {i for batch in s0 for i in batch} + frames1 = {i for batch in s1 for i in batch} + assert frames0 & frames1 == set() + + def test_covers_all_frames(self, multi_nloc_lmdb): + reader = LmdbDataReader(multi_nloc_lmdb, type_map=["O", "H"], batch_size=2) + s0 = DistributedSameNlocBatchSampler( + reader, rank=0, world_size=2, shuffle=True, seed=42 + ) + s1 = DistributedSameNlocBatchSampler( + reader, rank=1, world_size=2, shuffle=True, seed=42 + ) + all_frames = {i for batch in s0 for i in batch} | { + i for batch in s1 for i in batch + } + assert all_frames == set(range(30)) + + def test_len(self, multi_nloc_lmdb): + import math + + reader = LmdbDataReader(multi_nloc_lmdb, type_map=["O", "H"], batch_size=2) + total = len(SameNlocBatchSampler(reader, shuffle=False)) + dist_s = DistributedSameNlocBatchSampler( + reader, rank=0, world_size=2, shuffle=False, seed=0 + ) + assert len(dist_s) == math.ceil(total / 2) + + def test_deterministic(self, multi_nloc_lmdb): + reader = LmdbDataReader(multi_nloc_lmdb, type_map=["O", "H"], batch_size=2) + s1 = DistributedSameNlocBatchSampler( + reader, rank=0, world_size=2, shuffle=True, seed=42 + ) + s2 = DistributedSameNlocBatchSampler( + reader, rank=0, world_size=2, shuffle=True, seed=42 + ) + assert list(s1) == list(s2) + + def test_set_epoch_changes_order(self, multi_nloc_lmdb): + reader = LmdbDataReader(multi_nloc_lmdb, type_map=["O", "H"], batch_size=2) + s = DistributedSameNlocBatchSampler( + reader, rank=0, world_size=2, shuffle=True, seed=42 + ) + s.set_epoch(0) + e0 = list(s) + s.set_epoch(1) + e1 = list(s) + assert e0 != e1 + + def test_single_gpu_fallback(self, multi_nloc_lmdb): + reader = LmdbDataReader(multi_nloc_lmdb, type_map=["O", "H"], 
batch_size=2) + single = { + i + for batch in SameNlocBatchSampler(reader, shuffle=True, seed=42) + for i in batch + } + dist = { + i + for batch in DistributedSameNlocBatchSampler( + reader, rank=0, world_size=1, shuffle=True, seed=42 + ) + for i in batch + } + assert single == dist == set(range(30)) + + def test_same_nloc_per_batch(self, multi_nloc_lmdb): + reader = LmdbDataReader(multi_nloc_lmdb, type_map=["O", "H"], batch_size=2) + s = DistributedSameNlocBatchSampler( + reader, rank=0, world_size=2, shuffle=True, seed=42 + ) + for batch in s: + nlocs = {reader.frame_nlocs[idx] for idx in batch} + assert len(nlocs) == 1 + + +# ============================================================ +# auto_prob / merge_lmdb (PT-specific: LmdbDataset integration) +# ============================================================ + + +def _create_lmdb_with_system_ids( + path: str, + system_frames: list[int], + natoms: int = 6, + type_map: list[str] | None = None, +) -> str: + total = sum(system_frames) + frame_system_ids = [] + for sid, nf in enumerate(system_frames): + frame_system_ids.extend([sid] * nf) + env = lmdb.open(path, map_size=50 * 1024 * 1024) + fmt = "012d" + with env.begin(write=True) as txn: + meta = { + "nframes": total, + "frame_idx_fmt": fmt, + "system_info": {"natoms": [natoms // 2, natoms // 2]}, + "frame_system_ids": frame_system_ids, + "frame_nlocs": [natoms] * total, + } + if type_map is not None: + meta["type_map"] = type_map + txn.put(b"__metadata__", msgpack.packb(meta, use_bin_type=True)) + for i in range(total): + txn.put( + format(i, fmt).encode(), + msgpack.packb( + _make_frame(natoms=natoms, seed=i % 100), use_bin_type=True + ), + ) + env.close() + return path + + +@pytest.fixture +def auto_prob_lmdb(tmp_path): + path = str(tmp_path / "auto_prob.lmdb") + _create_lmdb_with_system_ids( + path, system_frames=[50, 100, 150], natoms=6, type_map=["O", "H"] + ) + return path + + +class TestAutoProbDataset: + """Test LmdbDataset with auto_prob_style.""" + 
+ def test_dataset_auto_prob_passthrough(self, auto_prob_lmdb): + ds = LmdbDataset( + auto_prob_lmdb, + type_map=["O", "H"], + batch_size=4, + auto_prob_style="prob_sys_size;0:1:0.5;1:3:0.5", + ) + assert ds._block_targets is not None + + def test_dataset_auto_prob_none(self, auto_prob_lmdb): + ds = LmdbDataset(auto_prob_lmdb, type_map=["O", "H"], batch_size=4) + assert ds._block_targets is None + + def test_dataset_auto_prob_no_system_ids(self, lmdb_dir): + ds = LmdbDataset( + lmdb_dir, + type_map=["O", "H"], + batch_size=4, + auto_prob_style="prob_sys_size;0:1:1.0", + ) + assert ds._block_targets is None + + def test_dataset_auto_prob_iteration(self, auto_prob_lmdb): + ds = LmdbDataset( + auto_prob_lmdb, + type_map=["O", "H"], + batch_size=4, + auto_prob_style="prob_sys_size;0:1:0.5;1:3:0.5", + ) + count = sum(len(batch) for batch in ds._batch_sampler) + assert count > 300 # expanded + + +class TestMergeLmdbSystemIds: + """Test merge_lmdb propagates frame_system_ids.""" + + def test_merge_propagates_system_ids(self, tmp_path): + src1, src2 = str(tmp_path / "src1.lmdb"), str(tmp_path / "src2.lmdb") + _create_lmdb_with_system_ids( + src1, system_frames=[5, 10], natoms=6, type_map=["O", "H"] + ) + _create_lmdb_with_system_ids( + src2, system_frames=[3, 7], natoms=6, type_map=["O", "H"] + ) + dst = str(tmp_path / "merged.lmdb") + merge_lmdb([src1, src2], dst) + reader = LmdbDataReader(dst, ["O", "H"]) + assert reader.nframes == 25 + assert reader.nsystems == 4 + sids = list(reader.frame_system_ids) + assert sids[:5] == [0] * 5 + assert sids[5:15] == [1] * 10 + assert sids[15:18] == [2] * 3 + assert sids[18:25] == [3] * 7 + + def test_merge_old_lmdb_no_system_ids(self, tmp_path): + src1, src2 = str(tmp_path / "old1.lmdb"), str(tmp_path / "old2.lmdb") + _create_test_lmdb(src1, nframes=5, natoms=6) + _create_test_lmdb(src2, nframes=3, natoms=6) + dst = str(tmp_path / "merged_old.lmdb") + merge_lmdb([src1, src2], dst) + reader = LmdbDataReader(dst, ["O", "H"]) + assert 
reader.nsystems == 2 + assert list(reader.frame_system_ids[:5]) == [0] * 5 + assert list(reader.frame_system_ids[5:8]) == [1] * 3 + + def test_merge_preserves_type_map(self, tmp_path): + src1, src2 = str(tmp_path / "tm1.lmdb"), str(tmp_path / "tm2.lmdb") + _create_lmdb_with_system_ids( + src1, system_frames=[5], natoms=6, type_map=["O", "H"] + ) + _create_lmdb_with_system_ids( + src2, system_frames=[5], natoms=6, type_map=["O", "H"] + ) + dst = str(tmp_path / "merged_tm.lmdb") + merge_lmdb([src1, src2], dst) + env = lmdb.open(dst, readonly=True, lock=False) + with env.begin() as txn: + meta = _read_metadata(txn) + env.close() + assert meta.get("type_map") == ["O", "H"] + + +# ============================================================ +# Multitask LMDB training +# ============================================================ + + +@pytest.fixture +def multitask_lmdb_setup(tmp_path): + """Create two LMDB datasets and a multitask training config.""" + for name in ("task1_train", "task2_train", "task1_val", "task2_val"): + nf = 20 if "train" in name else 10 + _create_test_lmdb_with_type_map( + str(tmp_path / f"{name}.lmdb"), + nframes=nf, + natoms=6, + lmdb_type_map=["O", "H"], + ) + + config = { + "model": { + "shared_dict": { + "type_map_all": ["O", "H"], + "my_descriptor": { + "type": "se_e2_a", + "sel": [4, 4], + "rcut_smth": 0.5, + "rcut": 4.0, + "neuron": [4, 8], + "axis_neuron": 4, + "precision": "float64", + }, + "my_fitting": {"neuron": [8, 8], "precision": "float64", "seed": 1}, + }, + "model_dict": { + "model_1": { + "type_map": "type_map_all", + "descriptor": "my_descriptor", + "fitting_net": "my_fitting", + "data_stat_nbatch": 1, + }, + "model_2": { + "type_map": "type_map_all", + "descriptor": "my_descriptor", + "fitting_net": "my_fitting", + "data_stat_nbatch": 1, + }, + }, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 50, + "start_lr": 1e-3, + "stop_lr": 1e-8, + }, + "loss_dict": { + "model_1": { + "type": "ener", + "start_pref_e": 0.2, + 
"limit_pref_e": 1, + "start_pref_f": 100, + "limit_pref_f": 1, + "start_pref_v": 0.0, + "limit_pref_v": 0.0, + }, + "model_2": { + "type": "ener", + "start_pref_e": 0.2, + "limit_pref_e": 1, + "start_pref_f": 100, + "limit_pref_f": 1, + "start_pref_v": 0.0, + "limit_pref_v": 0.0, + }, + }, + "training": { + "model_prob": {"model_1": 0.5, "model_2": 0.5}, + "data_dict": { + "model_1": { + "stat_file": str(tmp_path / "stat_model_1.hdf5"), + "training_data": { + "systems": str(tmp_path / "task1_train.lmdb"), + "batch_size": 4, + }, + "validation_data": { + "systems": str(tmp_path / "task1_val.lmdb"), + "batch_size": 2, + }, + }, + "model_2": { + "stat_file": str(tmp_path / "stat_model_2.hdf5"), + "training_data": { + "systems": str(tmp_path / "task2_train.lmdb"), + "batch_size": 4, + }, + "validation_data": { + "systems": str(tmp_path / "task2_val.lmdb"), + "batch_size": 2, + }, + }, + }, + "numb_steps": 5, + "seed": 10, + "disp_file": str(tmp_path / "lcurve.out"), + "disp_freq": 2, + "save_freq": 5, + }, + } + return config, tmp_path + + +class TestMultitaskLmdbTraining: + """Test multitask training with LMDB datasets. + + Uses se_e2_a (not se_atten) to keep memory usage low on CI runners (~7 GB). + All assertions are in a single test to avoid creating multiple trainers. 
+ """ + + def test_multitask_lmdb_end_to_end(self, multitask_lmdb_setup, monkeypatch): + from copy import ( + deepcopy, + ) + + from deepmd.pt.entrypoints.main import ( + get_trainer, + ) + from deepmd.pt.utils.multi_task import ( + preprocess_shared_params, + ) + from deepmd.utils.argcheck import ( + normalize, + ) + from deepmd.utils.compat import ( + update_deepmd_input, + ) + + config, tmp_path = multitask_lmdb_setup + monkeypatch.chdir(tmp_path) + config = update_deepmd_input(deepcopy(config), warning=True) + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = normalize(config, multi_task=True) + trainer = get_trainer(config, shared_links=shared_links) + + # -- trainer init assertions -- + assert trainer.multi_task + assert set(trainer.model_keys) == {"model_1", "model_2"} + + # -- shared params assertions -- + state_dict = trainer.wrapper.model.state_dict() + for key in state_dict: + if "model_1.atomic_model.descriptor" in key: + key2 = key.replace("model_1", "model_2") + assert key2 in state_dict + torch.testing.assert_close(state_dict[key], state_dict[key2]) + + # -- get_data assertions -- + for task_key in ["model_1", "model_2"]: + input_dict, label_dict, log_dict = trainer.get_data( + is_train=True, task_key=task_key + ) + assert "coord" in input_dict + assert "sid" in log_dict + + # -- training run assertions -- + trainer.run() + assert len(list(tmp_path.glob("model.ckpt*.pt"))) > 0 + + # Explicit cleanup to free memory on CI + import gc + + del trainer + gc.collect()