
fix(pt): address hessian review comments#5358

Closed
njzjz-bot wants to merge 2 commits into deepmodeling:master from njzjz-bothub:fix-hessian-nan-dpa3

Conversation

@njzjz-bot
Contributor

@njzjz-bot njzjz-bot commented Mar 30, 2026

Problem

  • Address review comments on the NaN Hessian fix for PyTorch DPA2/DPA3.
  • Ensure safe_for_norm matches vector-norm semantics when dim is not provided.

Change

  • Use torch.linalg.vector_norm(...) in safe_for_norm for both dim=None and dimensioned calls.
  • Reuse the precomputed mask instead of recomputing the squared sum.
  • Add focused PT regression tests to assert DPA2 (repformer.direct_dist=True) and DPA3 Hessians remain finite.

Notes

  • Follow-up branch for comments on fix(pt): fix NaN Hessian in DPA2 and DPA3 #5351.
  • Local full test execution is blocked in this environment by missing DeePMD build artifacts / test runtime setup, so this patch is scoped to the review feedback itself.

Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4)

Summary by CodeRabbit

Release Notes

  • Bug Fixes

    • Enhanced numerical stability in Hessian gradient computations across descriptor layers to ensure finite output values and prevent numerical errors.
  • Tests

    • Added test coverage for Hessian finiteness across DPA2 and DPA3 models with varying precision settings.

njzjz and others added 2 commits March 29, 2026 01:16
Use vector_norm semantics in safe_for_norm and add focused
regression tests to verify DPA2/DPA3 Hessians stay finite.

Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4)
@dosubot dosubot bot added the bug label Mar 30, 2026
@coderabbitai
Contributor

coderabbitai bot commented Mar 30, 2026

📝 Walkthrough

This PR introduces a safe_for_norm utility function in a new module to compute norms while avoiding problematic gradients at zero, and integrates it into the repflows and repformers descriptor modules by replacing direct torch.linalg.norm calls. A new test module validates Hessian finiteness for DPA models under these changes.

Changes

Changed files, by cohort:

  • Safe Gradient Utilities (deepmd/pt/utils/safe_gradient.py): new module with safe_for_sqrt and safe_for_norm helper functions that compute square roots and norms while masking zero-magnitude inputs to ensure zero (rather than NaN) gradients at zero.
  • Descriptor Module Refactoring (deepmd/pt/model/descriptor/repflows.py, deepmd/pt/model/descriptor/repformers.py): replaced direct torch.linalg.norm calls with the safe_for_norm wrapper in angle-neighbor cutoff masking, edge-feature initialization, and angular-difference normalization to prevent NaN gradients.
  • Test Validation (source/tests/pt/model/test_dpa_hessian_finite.py): new test module validating Hessian finiteness for DPA2 and DPA3 configurations with safe-norm integration and Hessian computation enabled.
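
A toy version of the kind of Hessian-finiteness check described above might look like the following. This is a hedged sketch, not the actual test file: `toy_energy`, the coordinate values, and the simplified `safe_for_norm` stand-in are all invented for illustration.

```python
import torch

def safe_for_norm(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # simplified stand-in for deepmd/pt/utils/safe_gradient.safe_for_norm
    is_zero = (x == 0).all(dim=dim, keepdim=True)
    x_safe = torch.where(is_zero, torch.ones_like(x), x)
    n = torch.linalg.vector_norm(x_safe, dim=dim)
    return torch.where(is_zero.squeeze(dim), torch.zeros_like(n), n)

def toy_energy(coords: torch.Tensor) -> torch.Tensor:
    # pairwise-distance "energy"; the i == j pairs give zero difference
    # vectors, which would poison the second derivatives of a plain
    # torch.linalg.vector_norm with NaN
    diff = coords.unsqueeze(0) - coords.unsqueeze(1)  # (n, n, 3)
    return safe_for_norm(diff, dim=-1).sum()

coords = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=torch.float64)
hess = torch.autograd.functional.hessian(toy_energy, coords)
assert torch.isfinite(hess).all()
```

The real tests presumably drive full DPA2/DPA3 models rather than a toy energy, but the assertion is the same: every entry of the Hessian must be finite.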

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Suggested reviewers

  • iProzd
  • caic99
  • wanghan-iapcm
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 25.00%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title check — ✅ Passed: the title directly references addressing Hessian review comments, which aligns with the PR's primary objective of fixing NaN Hessian issues in PyTorch DPA2/DPA3 by refactoring norm computations.



Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
deepmd/pt/model/descriptor/repflows.py (1)

479-479: Compute norm only for the sliced neighbor window to avoid extra work.

Line 479 computes norms over all neighbors and slices afterward. You can slice first and reduce unnecessary FLOPs when self.a_sel < nnei.

♻️ Suggested change

```diff
-        a_dist_mask = (safe_for_norm(diff, dim=-1) < self.a_rcut)[:, :, : self.a_sel]
+        a_dist_mask = (
+            safe_for_norm(diff[:, :, : self.a_sel], dim=-1) < self.a_rcut
+        )
```
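
The reordering is a pure refactor: because the norm reduces only over the last axis, slicing the neighbor window first and then taking the norm yields exactly the same mask as computing all norms and slicing afterwards. A quick sanity check with made-up shapes (plain `vector_norm` stands in for `safe_for_norm` here; the shapes and cutoff are invented):

```python
import torch

torch.manual_seed(0)
# invented shapes: nframes=2, nloc=3, nnei=6, a_sel=4
diff = torch.randn(2, 3, 6, 3)
a_rcut, a_sel = 1.0, 4

# original: norm over all nnei neighbors, then slice the mask
mask_full = (torch.linalg.vector_norm(diff, dim=-1) < a_rcut)[:, :, :a_sel]
# suggested: slice the neighbor window first, then take the norm
mask_sliced = torch.linalg.vector_norm(diff[:, :, :a_sel], dim=-1) < a_rcut

assert torch.equal(mask_full, mask_sliced)
```

The saving is the dropped norm computation over the `nnei - a_sel` neighbors that were sliced away anyway.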


📥 Commits

Reviewing files that changed from the base of the PR and between 14c349b and 5f0ae43.

📒 Files selected for processing (4)
  • deepmd/pt/model/descriptor/repflows.py
  • deepmd/pt/model/descriptor/repformers.py
  • deepmd/pt/utils/safe_gradient.py
  • source/tests/pt/model/test_dpa_hessian_finite.py

```python
x: torch.Tensor,
dim: int | None = None,
keepdim: bool = False,
ord: float = 2.0,
```
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
ruff check deepmd/pt/utils/safe_gradient.py
ruff format --check deepmd/pt/utils/safe_gradient.py
```

Repository: deepmodeling/deepmd-kit



Rename ord argument to avoid Ruff A002 error.

Line 22 shadows Python's builtin ord, which Ruff flags as A002. This will cause CI to fail per the coding guidelines.

🔧 Proposed fix

```diff
 def safe_for_norm(
     x: torch.Tensor,
     dim: int | None = None,
     keepdim: bool = False,
-    ord: float = 2.0,
+    norm_ord: float = 2.0,
 ) -> torch.Tensor:
@@
-        norm = torch.linalg.vector_norm(x_safe, ord=ord)
+        norm = torch.linalg.vector_norm(x_safe, ord=norm_ord)
@@
-    norm = torch.linalg.vector_norm(x_safe, ord=ord, dim=dim, keepdim=keepdim)
+    norm = torch.linalg.vector_norm(x_safe, ord=norm_ord, dim=dim, keepdim=keepdim)
```
🧰 Tools
🪛 Ruff (0.15.7)

[error] 22-22: Function argument ord is shadowing a Python builtin

(A002)


@njzjz-bot
Contributor Author

Closing this PR because the intended target is njzjz:fix-hessian-nan-dpa3, not upstream deepmodeling/deepmd-kit:master.

Superseded by: njzjz#227

— OpenClaw 2026.3.8 (model: custom-chat-jinzhezeng-group/gpt-5.4)

@njzjz-bot njzjz-bot closed this Mar 30, 2026
@codecov

codecov bot commented Mar 30, 2026

Codecov Report

❌ Patch coverage is 85.71429% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.18%. Comparing base (2a82988) to head (5f0ae43).
⚠️ Report is 6 commits behind head on master.

Files with missing lines:
  • deepmd/pt/utils/safe_gradient.py — patch coverage 81.25%, 3 lines missing ⚠️

Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #5358      +/-   ##
==========================================
- Coverage   82.26%   82.18%   -0.09%     
==========================================
  Files         799      811      +12     
  Lines       82563    83237     +674     
  Branches     4066     4066              
==========================================
+ Hits        67924    68410     +486     
- Misses      13424    13611     +187     
- Partials     1215     1216              +1
```

☔ View full report in Codecov by Sentry.
