Conversation
Pull request overview
Adds “full validation” support to PyTorch training, including config argcheck validation and checkpoint rotation on best metric.
Changes:
- Introduces `FullValidator` to run periodic full-dataset validation, log `val.log`, and optionally save/rotate best checkpoints.
- Extends `deepmd.utils.argcheck.normalize()` to validate full-validation configs and supported metrics/prefactors.
- Adds unit/integration tests covering metric parsing/start-step resolution, argcheck failures, and best-checkpoint rotation.
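As a sketch, a config enabling this feature might look like the following JSON fragment; the key names (`full_validation`, `validation_metric`, `start_step`) are inferred from the review comments in this thread and may not match the final schema exactly:

```json
{
    "training": {
        "validation_data": {
            "systems": ["path/to/validation/system"]
        },
        "validating": {
            "full_validation": true,
            "validation_metric": "mae_e_per_atom",
            "start_step": 1000
        }
    }
}
```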
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
Summary per file:
| File | Description |
|---|---|
| source/tests/pt/test_validation.py | Adds unit tests for helper functions and argcheck validation for full validation. |
| source/tests/pt/test_training.py | Adds trainer-level tests for full validation behavior and rejection paths (spin/multi-task). |
| deepmd/utils/argcheck.py | Adds validating config schema and cross-field validation for full validation. |
| deepmd/pt/train/validation.py | Implements FullValidator and full-validation metric/logging utilities. |
| deepmd/pt/train/training.py | Wires FullValidator into the training loop and enforces runtime constraints. |
📝 Walkthrough
Adds a configurable full-validation system to training: a new validator module, argument schema and checks, Trainer integration to run validations and manage top‑K best checkpoints, and unit tests covering behavior and config validation. Validation runs at configured intervals and writes `val.log`.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer
    participant FullValidator
    participant Model
    participant ValidData
    participant Checkpoint
    participant ValLog
    Trainer->>FullValidator: initialize(validating_params, validation_data, model, train_infos, ...)
    loop Training steps
        Trainer->>Trainer: train iterations
        alt display/logging moment
            Trainer->>FullValidator: run(step_id, display_step, lr, save_checkpoint)
            FullValidator->>FullValidator: should_run(display_step)?
            alt run validation
                FullValidator->>Model: set eval mode
                FullValidator->>ValidData: iterate systems
                loop per system
                    FullValidator->>Model: predict(inputs)
                    Model-->>FullValidator: outputs (E, F, V)
                    FullValidator->>FullValidator: compute per-system metrics
                end
                FullValidator->>FullValidator: aggregate metrics, select metric
                alt new best metric
                    FullValidator->>Checkpoint: save best.ckpt-<rank>-t-<step>.pt
                    FullValidator->>Checkpoint: prune older best ckpts
                end
                FullValidator->>ValLog: append metrics row
                FullValidator-->>Trainer: return FullValidationResult
            else skip
                FullValidator-->>Trainer: return None
            end
        end
    end
```
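The "prune older best ckpts" step in the diagram can be sketched as a small helper. This is an illustrative implementation, not the actual `FullValidator` code; the filename pattern `best.ckpt-<rank>-t-<step>.pt` is taken from the diagram above, and the keep-most-recent policy is an assumption:

```python
import re

# Matches the best-checkpoint names from the diagram, e.g. "best.ckpt-0-t-2000.pt".
_BEST_RE = re.compile(r"best\.ckpt-(\d+)-t-(\d+)\.pt$")

def prune_best_checkpoints(filenames: list[str], top_k: int) -> list[str]:
    """Return the best-checkpoint filenames that should be deleted,
    keeping only the top_k checkpoints with the highest step."""
    best = []
    for name in filenames:
        m = _BEST_RE.search(name)
        if m:
            best.append((int(m.group(2)), name))
    best.sort(key=lambda pair: pair[0], reverse=True)  # newest first
    return [name for _, name in best[top_k:]]
```

In the real validator the returned names would then be unlinked from the checkpoint directory, on rank 0 only.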
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 6
🧹 Nitpick comments (1)
deepmd/pt/train/validation.py (1)
464-470: Disable autograd for validation forwards.
`eval()` changes module behavior, but gradients are still tracked here. On a full-dataset validation pass that is unnecessary memory and latency overhead.

Proposed refactor:

```diff
-        batch_output = self.model(
-            coord_input,
-            type_input,
-            box=box_input,
-            fparam=fparam_input,
-            aparam=aparam_input,
-        )
+        with torch.inference_mode():
+            batch_output = self.model(
+                coord_input,
+                type_input,
+                box=box_input,
+                fparam=fparam_input,
+                aparam=aparam_input,
+            )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/validation.py` around lines 464 - 470, The validation forward is running with gradients enabled; wrap the model inference call that produces batch_output = self.model(...) in a no-grad context (preferably with torch.inference_mode() or with torch.no_grad()) inside the validation routine (the method where self.model is called for validation) so autograd is disabled during validation forwards and reduces memory/latency overhead.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt/train/training.py`:
- Around line 1421-1427: FullValidator.run currently receives self.save_model
directly, causing every stage-0 worker to build model_state and deepcopy
optimizer.state_dict() when a new best checkpoint is broadcast; change the call
so save_checkpoint is True only on the rank that should serialize (e.g. pass
save_checkpoint=(self.save_model and (self.global_rank == 0)) or use
torch.distributed.get_rank()==0 when self.stage == 0), i.e. gate the
save_checkpoint argument before calling self.full_validator.run to skip
serialization on nonzero ranks and avoid unnecessary deep copies.
- Around line 900-914: FullValidator.run() can deadlock other ranks if rank 0
raises during _evaluate()/save_checkpoint() because those ranks wait on
broadcast_object_list(); wrap the rank-0 evaluation/checkpoint block in a
try/except that captures any Exception, set a serializable error payload (e.g.,
tuple with True and the exception string), and immediately broadcast that
payload with broadcast_object_list() so all ranks are unblocked; on non-zero
ranks receive the payload, detect the error flag, and raise a matching exception
(or handle/clean up) so the failure is propagated instead of leaving ranks
blocked—modify deepmd/pt/train/validation.py FullValidator.run() to implement
this pattern around _evaluate() and save_checkpoint().
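The deadlock-safe pattern this prompt describes can be sketched in plain Python, with callables standing in for `torch.distributed.broadcast_object_list`; the `(is_error, value)` payload shape is an assumption, not the project's actual protocol:

```python
# Rank 0 always broadcasts a payload, even on failure, so ranks waiting on the
# broadcast are never stranded; the error is then re-raised on every rank.

def run_validation_rank0(evaluate, broadcast):
    try:
        result = evaluate()
        payload = (False, result)  # (is_error, value)
    except Exception as exc:
        payload = (True, str(exc))
    broadcast(payload)  # stand-in for dist.broadcast_object_list([payload], src=0)
    if payload[0]:
        raise RuntimeError(f"validation failed on rank 0: {payload[1]}")
    return payload[1]

def run_validation_other_rank(receive):
    # receive() stands in for the matching broadcast on a non-zero rank.
    is_error, value = receive()
    if is_error:
        raise RuntimeError(f"validation failed on rank 0: {value}")
    return value
```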
In `@deepmd/pt/train/validation.py`:
- Around line 255-259: The validator is disabled when start_step equals
num_steps due to a strict '<' check; update the initialization of self.enabled
(which uses self.full_validation, self.start_step, and num_steps) to allow
equality (use '<=' semantics) so full validation can run on the final training
step, and ensure the should_run() logic remains consistent with this change.
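The off-by-one described above reduces to a single comparison when computing the enabled flag; a minimal sketch under assumed names (`full_validation`, `start_step`, `num_steps`):

```python
# With '<', start_step == num_steps silently disables full validation;
# '<=' lets a validation scheduled for the final training step still run.

def validator_enabled(full_validation: bool, start_step: int, num_steps: int) -> bool:
    return full_validation and start_step <= num_steps
```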
- Around line 307-328: The current code only calls self._evaluate() on rank 0
which deadlocks when self.zero_stage >= 2 because forward passes require all
ranks; change the control flow so that when self.zero_stage >= 2 you call
self._evaluate() on every rank (remove the rank==0-only guard for that case) and
still initialize `save_path = [None]` and call `dist.broadcast_object_list(save_path, src=0)` to
propagate the chosen checkpoint; keep the existing rank-0-only actions (calling
self._prune_best_checkpoints and self._log_result) but ensure
save_checkpoint(Path(save_path[0]), ...) and the broadcast happen after every
rank has produced or received save_path; update the branches around
self._evaluate, save_path, dist.broadcast_object_list, save_checkpoint,
_prune_best_checkpoints and _log_result accordingly so distributed stage-2/3
training doesn't hang.
In `@deepmd/utils/argcheck.py`:
- Around line 4180-4194: The code currently returns early on multi_task or
non-'ener' losses which lets validating.full_validation silently pass; instead,
check validating.get("full_validation") first and if true reject unsupported
modes: if multi_task is True or loss_params.get("type","ener") != "ener" raise a
ValueError explaining that full_validation is unsupported with multi-task or
non-'ener' losses. Also only run the validation_metric check (using
validating["validation_metric"], is_valid_full_validation_metric and
FULL_VALIDATION_METRIC_PREFS) when full_validation is enabled so invalid metrics
are rejected rather than silently ignored.
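The cross-field check this prompt asks for can be sketched as follows; the function name, key names, and the supported-metric set are illustrative stand-ins, not the actual `argcheck.py` code:

```python
# Assumed set of supported metrics, for illustration only.
SUPPORTED_METRICS = {"rmse_e", "mae_e", "mae_e_per_atom", "mae_f", "mae_v"}

def check_full_validation(validating: dict, loss_params: dict, multi_task: bool) -> None:
    if not validating.get("full_validation"):
        return  # checks only apply when the feature is enabled
    if multi_task or loss_params.get("type", "ener") != "ener":
        raise ValueError(
            "full_validation is unsupported with multi-task or non-'ener' losses"
        )
    metric = validating.get("validation_metric")
    if metric not in SUPPORTED_METRICS:
        raise ValueError(f"unsupported validation_metric: {metric!r}")
```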
In `@source/tests/pt/test_validation.py`:
- Around line 135-139: The test test_normalize_rejects_invalid_metric currently
catches the broad Exception; replace this with the concrete validation error
type that normalize() raises (e.g., ValidationError or the project-specific
ValidationError class) and update imports accordingly so the assertion uses
assertRaisesRegex(ValidationError, "validation_metric") against
normalize(config); keep the same regex and test flow but narrow the exception to
the specific validation error class.
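The narrowing described in the last comment looks like this in unittest form. This is a self-contained illustration: `ValidationError` and `normalize` below are stand-ins for the project's own classes, not imports from deepmd:

```python
import unittest

class ValidationError(Exception):
    """Stand-in for the project's validation error class."""

def normalize(config):
    # Stand-in for the metric check in deepmd.utils.argcheck.normalize.
    if config.get("validation_metric") not in {"mae_e", "mae_f"}:
        raise ValidationError(
            f"invalid validation_metric: {config.get('validation_metric')!r}"
        )
    return config

class TestNormalize(unittest.TestCase):
    def test_rejects_invalid_metric(self):
        # Asserting on the concrete error class: a bare Exception match would
        # also pass on unrelated failures such as a KeyError or TypeError.
        with self.assertRaisesRegex(ValidationError, "validation_metric"):
            normalize({"validation_metric": "bogus"})
```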
---
Nitpick comments:
In `@deepmd/pt/train/validation.py`:
- Around line 464-470: The validation forward is running with gradients enabled;
wrap the model inference call that produces batch_output = self.model(...) in a
no-grad context (preferably with torch.inference_mode() or with torch.no_grad())
inside the validation routine (the method where self.model is called for
validation) so autograd is disabled during validation forwards and reduces
memory/latency overhead.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 98de0e8f-e132-4770-ae44-25a71d16e73e
📒 Files selected for processing (5)
- deepmd/pt/train/training.py
- deepmd/pt/train/validation.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_training.py
- source/tests/pt/test_validation.py
Codecov Report: ❌ Patch coverage is …
Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master    #5336      +/-   ##
==========================================
- Coverage   82.28%   82.24%   -0.04%
==========================================
  Files         797      798       +1
  Lines       82100    82535     +435
  Branches     4003     4003
==========================================
+ Hits        67557    67885     +328
- Misses      13336    13442     +106
- Partials     1207     1208       +1
```

☔ View full report in Codecov by Sentry.
♻️ Duplicate comments (1)
deepmd/pt/train/validation.py (1)
318-338: ⚠️ Potential issue | 🟠 Major
Add runtime guard for `zero_stage >= 2`.
The documentation in `argcheck.py` states that `zero_stage >= 2` is not supported with full validation, but there's no runtime enforcement. With FSDP2 (stage 2/3), model forward passes require collective participation from all ranks. Since only rank 0 enters `_evaluate()` (line 330), this will hang when other ranks block on the barrier or broadcast.

Proposed fix:

```diff
     if not self.should_run(display_step):
         return None
 
+    if self.is_distributed and self.zero_stage >= 2:
+        raise ValueError(
+            "validating.full_validation does not support training.zero_stage >= 2. "
+            "FSDP2 requires all ranks to participate in forward passes."
+        )
+
     if self.is_distributed:
         dist.barrier()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/train/validation.py` around lines 318 - 338, Add a runtime guard for zero_stage >= 2 before any distributed synchronization or calling _evaluate: check the model/training config value (e.g. self.zero_stage or self.config.zero_stage) right after should_run(...) and, if >= 2, ensure all ranks take the same action (either return None on all ranks or raise a RuntimeError on all ranks) and log a clear message; do not let only rank 0 call _evaluate() while others wait on dist.barrier() — perform the guard before is_distributed/dist.barrier() so no rank blocks.
🧹 Nitpick comments (1)
source/tests/pt/test_training.py (1)
818-822: Clarify or relax the val.log content assertions.
The test checks `val_lines[0].split()[1] == "1000.0"` and `val_lines[1].split()[1] == "2000.0"`, which appear to expect the MAE values multiplied by 1000 (from `mae_e_per_atom` values of 1.0 and 2.0). This relies on implementation details of the log format that may change. Consider either:
- Adding a comment explaining the expected format (e.g., `# val.log format: step metric_meV ...`)
- Using a more flexible assertion, like checking the line count or that lines contain the expected step numbers
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt/test_training.py` around lines 818 - 822, The assertions on val.log are too brittle because they depend on exact formatted metric values; update the test in test_training.py to either (a) add a clarifying comment above the val.log checks describing the expected file format (e.g., "# val.log format: step metric_meV ...") or (b) relax the assertions by checking structural properties instead of exact strings — for example parse each non-comment line into tokens via val_lines[i].split(), assert the step token equals the expected steps (e.g., "1000" and "2000") and/or parse the metric token to float and compare with the expected value using a numeric tolerance or scaled comparison rather than exact string equality for val_lines[0].split()[1] and val_lines[1].split()[1].
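Option (b) from the prompt above can be sketched as a small helper; the two-column `step metric` layout of `val.log` lines is an assumption about the format:

```python
import math

def assert_val_line(line: str, step: int, metric: float, rel_tol: float = 1e-6) -> None:
    """Check a val.log line by parsed values instead of exact string equality."""
    tokens = line.split()
    assert int(float(tokens[0])) == step, f"unexpected step in line: {line!r}"
    assert math.isclose(float(tokens[1]), metric, rel_tol=rel_tol), (
        f"metric {tokens[1]} != {metric} in line: {line!r}"
    )
```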
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@deepmd/pt/train/validation.py`:
- Around line 318-338: Add a runtime guard for zero_stage >= 2 before any
distributed synchronization or calling _evaluate: check the model/training
config value (e.g. self.zero_stage or self.config.zero_stage) right after
should_run(...) and, if >= 2, ensure all ranks take the same action (either
return None on all ranks or raise a RuntimeError on all ranks) and log a clear
message; do not let only rank 0 call _evaluate() while others wait on
dist.barrier() — perform the guard before is_distributed/dist.barrier() so no
rank blocks.
---
Nitpick comments:
In `@source/tests/pt/test_training.py`:
- Around line 818-822: The assertions on val.log are too brittle because they
depend on exact formatted metric values; update the test in test_training.py to
either (a) add a clarifying comment above the val.log checks describing the
expected file format (e.g., "# val.log format: step metric_meV ...") or (b)
relax the assertions by checking structural properties instead of exact strings
— for example parse each non-comment line into tokens via val_lines[i].split(),
assert the step token equals the expected steps (e.g., "1000" and "2000") and/or
parse the metric token to float and compare with the expected value using a
numeric tolerance or scaled comparison rather than exact string equality for
val_lines[0].split()[1] and val_lines[1].split()[1].
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: bcbdf8fa-2d30-4b2d-905b-49e7034d1081
📒 Files selected for processing (6)
- deepmd/pt/train/training.py
- deepmd/pt/train/validation.py
- deepmd/pt/utils/dataset.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_training.py
- source/tests/pt/test_validation.py
✅ Files skipped from review due to trivial changes (1)
- source/tests/pt/test_validation.py