
feat(pt): add full validation#5336

Open
OutisLi wants to merge 3 commits into deepmodeling:master from OutisLi:pr/val

Conversation


@OutisLi OutisLi commented Mar 24, 2026

Summary by CodeRabbit

  • New Features

    • Full validation during training with periodic multi-system evaluation and per-system E/F/V MAE/RMSE metrics.
    • Automatic top‑K best-checkpoint selection/rotation and configurable validation output logging.
    • New configuration options for validation (enable, start, frequency, metric, output, max retained checkpoints).
    • Added read-only access to underlying data system from dataset loader.
  • Tests

    • End-to-end and unit tests covering full-validation behavior, checkpoint rotation, config validation, and start-step resolution.
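The options listed above map to a new `validating` section of the training config; a minimal sketch follows. Only `full_validation`, `validation_metric`, and `max_best_ckpt` are named elsewhere in this PR, so the remaining key names are guesses and may differ from the actual schema.

```python
# Hypothetical config fragment for the new "validating" section. Key names
# other than full_validation, validation_metric, and max_best_ckpt are
# assumptions, not the PR's actual schema.
config = {
    "validating": {
        "full_validation": True,        # enable periodic full-dataset validation
        "start_step": 10000,            # assumed: first step eligible for validation
        "validation_freq": 1000,        # assumed: run every N display steps
        "validation_metric": "rmse_e",  # metric used to rank best checkpoints
        "output": "val.log",            # assumed: per-system E/F/V MAE/RMSE log
        "max_best_ckpt": 3,             # number of best checkpoints to retain
    }
}
```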

Copilot AI review requested due to automatic review settings March 24, 2026 03:23
@OutisLi OutisLi changed the title feat: add full validation feat(pt): add full validation Mar 24, 2026
@OutisLi OutisLi marked this pull request as draft March 24, 2026 03:23
@dosubot dosubot bot added the new feature label Mar 24, 2026

Copilot AI left a comment


Pull request overview

Adds “full validation” support to PyTorch training, including config argcheck validation and checkpoint rotation on best metric.

Changes:

  • Introduces FullValidator to run periodic full-dataset validation, log val.log, and optionally save/rotate best checkpoints.
  • Extends deepmd.utils.argcheck.normalize() to validate full-validation configs and supported metrics/prefactors.
  • Adds unit/integration tests covering metric parsing/start-step resolution, argcheck failures, and best-checkpoint rotation.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

Show a summary per file
  • source/tests/pt/test_validation.py: Adds unit tests for helper functions and argcheck validation for full validation.
  • source/tests/pt/test_training.py: Adds trainer-level tests for full validation behavior and rejection paths (spin/multi-task).
  • deepmd/utils/argcheck.py: Adds validating config schema and cross-field validation for full validation.
  • deepmd/pt/train/validation.py: Implements FullValidator and full-validation metric/logging utilities.
  • deepmd/pt/train/training.py: Wires FullValidator into the training loop and enforces runtime constraints.



coderabbitai bot commented Mar 24, 2026

📝 Walkthrough

Walkthrough

Adds a configurable full-validation system to training: new validator module, argument schema and checks, Trainer integration to run validations and manage top‑K best checkpoints, and unit tests covering behavior and config validation. Validation runs at configured intervals and writes val.log.

Changes

  • Validation implementation (deepmd/pt/train/validation.py): New module implementing FullValidator, FullValidationResult, metric parsing, metric computation (E/F/V MAE/RMSE), distributed coordination, checkpoint top‑K management, resolve_full_validation_start_step(), and SilentAutoBatchSize. Writes/updates val.log.
  • Trainer integration (deepmd/pt/train/training.py): Imports FullValidator, computes the start step from config["validating"], enforces runtime checks (no multi-task, no spin-energy, requires energy-std loss and validation data, zero_stage constraint), constructs FullValidator, and invokes it at display/logging moments.
  • CLI/config validation (deepmd/utils/argcheck.py): Adds the validating args subtree and helpers: FULL_VALIDATION_METRIC_PREFS, metric normalization/validation/accessors, validating_args(), and validate_full_validation_config() wired into normalize(). Enforces cross-field constraints for full validation.
  • Tests, validation logic (source/tests/pt/test_validation.py): New unit tests for start-step resolution, best‑checkpoint rotation/reconciliation, restart behavior, and validation-config rejection cases (missing validation data, invalid metric, zero prefactors, nonpositive max_best_ckpt).
  • Tests, trainer integration (source/tests/pt/test_training.py): Adds TestFullValidation with a checkpoint rotation test (patching evaluation to produce predictable metrics) and negative tests rejecting spin-energy loss and multi‑task configs.
  • Dataset utility (deepmd/pt/utils/dataset.py): Adds the read-only property DeepmdDataSetForLoader.data_system exposing the underlying DeepmdData.
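The per-system E/F/V MAE/RMSE metrics summarized above reduce to two aggregations over prediction errors. A minimal NumPy sketch (the function name and signature are illustrative, not the module's real API):

```python
import numpy as np

def per_system_metrics(pred, label) -> dict[str, float]:
    """MAE/RMSE for one validation system (illustrative, not the real API)."""
    err = np.asarray(pred, dtype=float) - np.asarray(label, dtype=float)
    return {
        "mae": float(np.mean(np.abs(err))),
        "rmse": float(np.sqrt(np.mean(err**2))),
    }
```

Per-system values like these would then be aggregated across systems, with one selected metric driving best-checkpoint rotation.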

Sequence Diagram(s)

sequenceDiagram
    participant Trainer
    participant FullValidator
    participant Model
    participant ValidData
    participant Checkpoint
    participant ValLog

    Trainer->>FullValidator: initialize(validating_params, validation_data, model, train_infos, ...)
    loop Training steps
        Trainer->>Trainer: train iterations
        alt display/logging moment
            Trainer->>FullValidator: run(step_id, display_step, lr, save_checkpoint)
            FullValidator->>FullValidator: should_run(display_step)?
            alt run validation
                FullValidator->>Model: set eval mode
                FullValidator->>ValidData: iterate systems
                loop per system
                    FullValidator->>Model: predict(inputs)
                    Model-->>FullValidator: outputs (E, F, V)
                    FullValidator->>FullValidator: compute per-system metrics
                end
                FullValidator->>FullValidator: aggregate metrics, select metric
                alt new best metric
                    FullValidator->>Checkpoint: save best.ckpt-<rank>-t-<step>.pt
                    FullValidator->>Checkpoint: prune older best ckpts
                end
                FullValidator->>ValLog: append metrics row
                FullValidator-->>Trainer: return FullValidationResult
            else skip
                FullValidator-->>Trainer: return None
            end
        end
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Suggested reviewers

  • njzjz
  • iProzd
  • wanghan-iapcm
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 51.67%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): check skipped because CodeRabbit's high-level summary is enabled.
  • Title check (✅ Passed): the title clearly and concisely identifies the main feature being added: full validation capability for the PyTorch trainer module.




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🧹 Nitpick comments (1)
deepmd/pt/train/validation.py (1)

464-470: Disable autograd for validation forwards.

eval() changes module behavior, but gradients are still tracked here. On a full-dataset validation pass that is unnecessary memory and latency overhead.

Proposed refactor
-            batch_output = self.model(
-                coord_input,
-                type_input,
-                box=box_input,
-                fparam=fparam_input,
-                aparam=aparam_input,
-            )
+            with torch.inference_mode():
+                batch_output = self.model(
+                    coord_input,
+                    type_input,
+                    box=box_input,
+                    fparam=fparam_input,
+                    aparam=aparam_input,
+                )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/train/validation.py` around lines 464 - 470, The validation forward
is running with gradients enabled; wrap the model inference call that produces
batch_output = self.model(...) in a no-grad context (preferably with
torch.inference_mode() or with torch.no_grad()) inside the validation routine
(the method where self.model is called for validation) so autograd is disabled
during validation forwards and reduces memory/latency overhead.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@deepmd/pt/train/training.py`:
- Around line 1421-1427: FullValidator.run currently receives self.save_model
directly, causing every stage-0 worker to build model_state and deepcopy
optimizer.state_dict() when a new best checkpoint is broadcast; change the call
so save_checkpoint is True only on the rank that should serialize (e.g. pass
save_checkpoint=(self.save_model and (self.global_rank == 0)) or use
torch.distributed.get_rank()==0 when self.stage == 0), i.e. gate the
save_checkpoint argument before calling self.full_validator.run to skip
serialization on nonzero ranks and avoid unnecessary deep copies.
- Around line 900-914: FullValidator.run() can deadlock other ranks if rank 0
raises during _evaluate()/save_checkpoint() because those ranks wait on
broadcast_object_list(); wrap the rank-0 evaluation/checkpoint block in a
try/except that captures any Exception, set a serializable error payload (e.g.,
tuple with True and the exception string), and immediately broadcast that
payload with broadcast_object_list() so all ranks are unblocked; on non-zero
ranks receive the payload, detect the error flag, and raise a matching exception
(or handle/clean up) so the failure is propagated instead of leaving ranks
blocked—modify deepmd/pt/train/validation.py FullValidator.run() to implement
this pattern around _evaluate() and save_checkpoint().
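The try/except-then-broadcast pattern this prompt describes can be sketched with the collective injected as a plain callable (standing in for torch.distributed.broadcast_object_list), which keeps the control flow visible without requiring an initialized process group:

```python
def rank0_step(rank: int, evaluate, broadcast):
    """Run `evaluate` on rank 0, broadcasting (is_error, payload) to all ranks.

    `broadcast` stands in for dist.broadcast_object_list: rank 0 sends its
    payload, other ranks receive it. Errors are serialized and re-raised on
    every rank so no peer is left blocking on the collective. Sketch only.
    """
    if rank == 0:
        try:
            payload = (False, evaluate())
        except Exception as exc:  # always broadcast, even on failure
            payload = (True, repr(exc))
        broadcast(payload)
    else:
        payload = broadcast(None)
    is_error, value = payload
    if is_error:
        raise RuntimeError(f"validation failed on rank 0: {value}")
    return value
```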

In `@deepmd/pt/train/validation.py`:
- Around line 255-259: The validator is disabled when start_step equals
num_steps due to a strict '<' check; update the initialization of self.enabled
(which uses self.full_validation, self.start_step, and num_steps) to allow
equality (use '<=' semantics) so full validation can run on the final training
step, and ensure the should_run() logic remains consistent with this change.
- Around line 307-328: The current code only calls self._evaluate() on rank 0
which deadlocks when self.zero_stage >= 2 because forward passes require all
ranks; change the control flow so that when self.zero_stage >= 2 you call
self._evaluate() on every rank (remove the rank==0-only guard for that case) and
still use save_path = [None] + dist.broadcast_object_list(save_path, src=0) to
propagate the chosen checkpoint; keep the existing rank-0-only actions (calling
self._prune_best_checkpoints and self._log_result) but ensure
save_checkpoint(Path(save_path[0]), ...) and the broadcast happen after every
rank has produced or received save_path; update the branches around
self._evaluate, save_path, dist.broadcast_object_list, save_checkpoint,
_prune_best_checkpoints and _log_result accordingly so distributed stage-2/3
training doesn't hang.
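The core of this fix is a participation rule: with ZeRO/FSDP stage >= 2 the parameters are sharded, so a forward pass is itself a collective and every rank must run _evaluate(); below stage 2, rank 0 can evaluate alone and broadcast the result. A tiny sketch of that rule (names are illustrative):

```python
def evaluation_ranks(zero_stage: int, world_size: int) -> list[int]:
    """Ranks that must run the validation forward pass (illustrative rule).

    Sharded parameters (stage >= 2) make the forward a collective, so all
    ranks participate; otherwise only rank 0 evaluates and broadcasts.
    """
    return list(range(world_size)) if zero_stage >= 2 else [0]
```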

In `@deepmd/utils/argcheck.py`:
- Around line 4180-4194: The code currently returns early on multi_task or
non-'ener' losses which lets validating.full_validation silently pass; instead,
check validating.get("full_validation") first and if true reject unsupported
modes: if multi_task is True or loss_params.get("type","ener") != "ener" raise a
ValueError explaining that full_validation is unsupported with multi-task or
non-'ener' losses. Also only run the validation_metric check (using
validating["validation_metric"], is_valid_full_validation_metric and
FULL_VALIDATION_METRIC_PREFS) when full_validation is enabled so invalid metrics
are rejected rather than silently ignored.
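The check this prompt asks for can be sketched as a pure function. The allowed-metric set and error messages below are assumptions; the real code consults FULL_VALIDATION_METRIC_PREFS.

```python
def validate_full_validation_config(validating: dict, multi_task: bool,
                                    loss_params: dict) -> None:
    """Cross-field checks sketched from the review comment (names assumed)."""
    if not validating.get("full_validation", False):
        return  # feature off: nothing to enforce
    if multi_task:
        raise ValueError("full_validation is unsupported with multi-task training")
    if loss_params.get("type", "ener") != "ener":
        raise ValueError("full_validation requires the 'ener' loss")
    metric = validating.get("validation_metric", "rmse_e")
    allowed = {"mae_e", "rmse_e", "mae_f", "rmse_f", "mae_v", "rmse_v"}  # assumed
    if metric not in allowed:
        raise ValueError(f"unsupported validation_metric: {metric!r}")
```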

In `@source/tests/pt/test_validation.py`:
- Around line 135-139: The test test_normalize_rejects_invalid_metric currently
catches the broad Exception; replace this with the concrete validation error
type that normalize() raises (e.g., ValidationError or the project-specific
ValidationError class) and update imports accordingly so the assertion uses
assertRaisesRegex(ValidationError, "validation_metric") against
normalize(config); keep the same regex and test flow but narrow the exception to
the specific validation error class.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 98de0e8f-e132-4770-ae44-25a71d16e73e

📥 Commits

Reviewing files that changed from the base of the PR and between 034e613 and 0f50be0.

📒 Files selected for processing (5)
  • deepmd/pt/train/training.py
  • deepmd/pt/train/validation.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_training.py
  • source/tests/pt/test_validation.py


codecov bot commented Mar 24, 2026

Codecov Report

❌ Patch coverage is 75.80645% with 105 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.24%. Comparing base (e97967b) to head (ee2df8e).

Files with missing lines (patch % / lines missing):
  • deepmd/pt/train/validation.py: 74.38% (93 missing ⚠️)
  • deepmd/utils/argcheck.py: 86.95% (6 missing ⚠️)
  • deepmd/pt/train/training.py: 77.27% (5 missing ⚠️)
  • deepmd/pt/utils/dataset.py: 66.66% (1 missing ⚠️)
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5336      +/-   ##
==========================================
- Coverage   82.28%   82.24%   -0.04%     
==========================================
  Files         797      798       +1     
  Lines       82100    82535     +435     
  Branches     4003     4003              
==========================================
+ Hits        67557    67885     +328     
- Misses      13336    13442     +106     
- Partials     1207     1208       +1     


@OutisLi OutisLi marked this pull request as ready for review March 26, 2026 06:34

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
deepmd/pt/train/validation.py (1)

318-338: ⚠️ Potential issue | 🟠 Major

Add runtime guard for zero_stage >= 2.

The documentation in argcheck.py states that zero_stage >= 2 is not supported with full validation, but there's no runtime enforcement. With FSDP2 (stage 2/3), model forward passes require collective participation from all ranks. Since only rank 0 enters _evaluate() (line 330), this will hang when other ranks block on the barrier or broadcast.

Proposed fix
         if not self.should_run(display_step):
             return None

+        if self.is_distributed and self.zero_stage >= 2:
+            raise ValueError(
+                "validating.full_validation does not support training.zero_stage >= 2. "
+                "FSDP2 requires all ranks to participate in forward passes."
+            )
+
         if self.is_distributed:
             dist.barrier()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/train/validation.py` around lines 318 - 338, Add a runtime guard
for zero_stage >= 2 before any distributed synchronization or calling _evaluate:
check the model/training config value (e.g. self.zero_stage or
self.config.zero_stage) right after should_run(...) and, if >= 2, ensure all
ranks take the same action (either return None on all ranks or raise a
RuntimeError on all ranks) and log a clear message; do not let only rank 0 call
_evaluate() while others wait on dist.barrier() — perform the guard before
is_distributed/dist.barrier() so no rank blocks.
🧹 Nitpick comments (1)
source/tests/pt/test_training.py (1)

818-822: Clarify or relax the val.log content assertions.

The test checks val_lines[0].split()[1] == "1000.0" and val_lines[1].split()[1] == "2000.0", which appear to expect the MAE values multiplied by 1000 (from mae_e_per_atom values of 1.0 and 2.0). This relies on implementation details of the log format that may change.

Consider either:

  1. Adding a comment explaining the expected format (e.g., # val.log format: step metric_meV ...)
  2. Using a more flexible assertion like checking the line count or that lines contain expected step numbers
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/pt/test_training.py` around lines 818 - 822, The assertions on
val.log are too brittle because they depend on exact formatted metric values;
update the test in test_training.py to either (a) add a clarifying comment above
the val.log checks describing the expected file format (e.g., "# val.log format:
step metric_meV ...") or (b) relax the assertions by checking structural
properties instead of exact strings — for example parse each non-comment line
into tokens via val_lines[i].split(), assert the step token equals the expected
steps (e.g., "1000" and "2000") and/or parse the metric token to float and
compare with the expected value using a numeric tolerance or scaled comparison
rather than exact string equality for val_lines[0].split()[1] and
val_lines[1].split()[1].
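Option (b) can be sketched as a tolerant parse of each non-comment row; the column layout assumed here (step first, metric second) is inferred from the test's indexing and may not match the actual val.log format.

```python
def parse_val_log_line(line: str) -> tuple[int, float]:
    """Split one val.log data row into (step, metric); layout assumed."""
    tokens = line.split()
    return int(tokens[0]), float(tokens[1])

def assert_metric_close(line: str, step: int, value: float, tol: float = 1e-9) -> None:
    """Structural assertion: correct step, metric within a numeric tolerance."""
    got_step, got_value = parse_val_log_line(line)
    assert got_step == step, f"expected step {step}, got {got_step}"
    assert abs(got_value - value) <= tol, f"metric {got_value} != {value}"
```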

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: bcbdf8fa-2d30-4b2d-905b-49e7034d1081

📥 Commits

Reviewing files that changed from the base of the PR and between 0f50be0 and ee2df8e.

📒 Files selected for processing (6)
  • deepmd/pt/train/training.py
  • deepmd/pt/train/validation.py
  • deepmd/pt/utils/dataset.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_training.py
  • source/tests/pt/test_validation.py
✅ Files skipped from review due to trivial changes (1)
  • source/tests/pt/test_validation.py
