Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.safe_gradient import (
safe_for_norm,
)
from deepmd.pt.utils.spin import (
concat_switch_virtual,
)
Expand Down Expand Up @@ -473,9 +476,7 @@ def forward(
sw = sw.masked_fill(~nlist_mask, 0.0)

# get angle nlist (maybe smaller)
a_dist_mask = (torch.linalg.norm(diff, dim=-1) < self.a_rcut)[
:, :, : self.a_sel
]
a_dist_mask = (safe_for_norm(diff, dim=-1) < self.a_rcut)[:, :, : self.a_sel]
a_nlist = nlist[:, :, : self.a_sel]
a_nlist = torch.where(a_dist_mask, a_nlist, -1)
_, a_diff, a_sw = prod_env_mat(
Expand Down Expand Up @@ -512,11 +513,11 @@ def forward(
edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1)
if self.edge_init_use_dist:
# nb x nloc x nnei x 1
edge_input = torch.linalg.norm(diff, dim=-1, keepdim=True)
edge_input = safe_for_norm(diff, dim=-1, keepdim=True)

# nf x nloc x a_nnei x 3
normalized_diff_i = a_diff / (
torch.linalg.norm(a_diff, dim=-1, keepdim=True) + 1e-6
safe_for_norm(a_diff, dim=-1, keepdim=True) + 1e-6
)
# nf x nloc x 3 x a_nnei
normalized_diff_j = torch.transpose(normalized_diff_i, 2, 3)
Expand Down
5 changes: 4 additions & 1 deletion deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.safe_gradient import (
safe_for_norm,
)
from deepmd.pt.utils.spin import (
concat_switch_virtual,
)
Expand Down Expand Up @@ -446,7 +449,7 @@ def forward(
if not self.direct_dist:
g2, h2 = torch.split(dmatrix, [1, 3], dim=-1)
else:
g2, h2 = torch.linalg.norm(diff, dim=-1, keepdim=True), diff
g2, h2 = safe_for_norm(diff, dim=-1, keepdim=True), diff
g2 = g2 / self.rcut
h2 = h2 / self.rcut
# nb x nloc x nnei x ng2
Expand Down
36 changes: 36 additions & 0 deletions deepmd/pt/utils/safe_gradient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Safe versions of some functions that have problematic gradients.

Check https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
for more information.
"""

import torch


def safe_for_sqrt(x: torch.Tensor) -> torch.Tensor:
    """Safe version of sqrt that has a gradient of 0 at x = 0.

    ``torch.sqrt`` has an infinite derivative at 0 (and NaN gradients for
    negative inputs).  The double-``where`` trick evaluates ``sqrt`` only on a
    guarded copy of the input, so autograd never sees the problematic branch.
    Non-positive entries yield a value of 0 with a gradient of 0.
    """
    positive = x > 0.0
    # Replace non-positive entries by 1 *before* sqrt so that the backward
    # pass of sqrt never touches them.
    guarded = torch.where(positive, x, torch.ones_like(x))
    return torch.where(positive, torch.sqrt(guarded), torch.zeros_like(x))


def safe_for_norm(
x: torch.Tensor,
dim: int | None = None,
keepdim: bool = False,
ord: float = 2.0,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
ruff check deepmd/pt/utils/safe_gradient.py
ruff format --check deepmd/pt/utils/safe_gradient.py

Repository: deepmodeling/deepmd-kit

Length of output: 578


Rename ord argument to avoid Ruff A002 error.

Line 22 shadows Python's builtin ord, which Ruff flags as A002. This will cause CI to fail per the coding guidelines.

🔧 Proposed fix
 def safe_for_norm(
     x: torch.Tensor,
     dim: int | None = None,
     keepdim: bool = False,
-    ord: float = 2.0,
+    norm_ord: float = 2.0,
 ) -> torch.Tensor:
@@
-        norm = torch.linalg.vector_norm(x_safe, ord=ord)
+        norm = torch.linalg.vector_norm(x_safe, ord=norm_ord)
@@
-    norm = torch.linalg.vector_norm(x_safe, ord=ord, dim=dim, keepdim=keepdim)
+    norm = torch.linalg.vector_norm(x_safe, ord=norm_ord, dim=dim, keepdim=keepdim)
🧰 Tools
🪛 Ruff (0.15.7)

[error] 22-22: Function argument ord is shadowing a Python builtin

(A002)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/utils/safe_gradient.py` at line 22, Rename the parameter named
"ord" in the function in deepmd/pt/utils/safe_gradient.py to a non-builtin name
(e.g., "norm_ord" or "p") to avoid shadowing Python's ord builtin; update the
function signature and every use of that parameter within the function (and any
internal helper closures) to the new name, and propagate the rename to any local
callers in the same module so references remain consistent.

) -> torch.Tensor:
"""Safe version of vector_norm that has a gradient of 0 at x = 0."""
if dim is None:
mask = torch.sum(torch.square(x)) > 0
x_safe = torch.where(mask, x, torch.ones_like(x))
norm = torch.linalg.vector_norm(x_safe, ord=ord)
return torch.where(mask, norm, torch.zeros_like(norm))

mask = torch.sum(torch.square(x), dim=(dim,), keepdim=True) > 0
mask_out = mask if keepdim else mask.squeeze(dim)

x_safe = torch.where(mask, x, torch.ones_like(x))
norm = torch.linalg.vector_norm(x_safe, ord=ord, dim=dim, keepdim=keepdim)
return torch.where(mask_out, norm, torch.zeros_like(norm))
64 changes: 64 additions & 0 deletions source/tests/pt/model/test_dpa_hessian_finite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import unittest

import numpy as np
import torch

from deepmd.pt.model.model import (
get_model,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)

from ...seed import (
GLOBAL_SEED,
)
from .test_permutation import (
model_dpa2,
model_dpa3,
)

dtype = torch.float64


class TestDPAHessianFinite(unittest.TestCase):
    """Check that DPA model energy Hessians contain no NaN/inf entries."""

    def _build_inputs(self):
        """Create a small periodic 5-atom system: (coord, atype, cell)."""
        natoms = 5
        box = 4.0 * torch.eye(3, dtype=dtype, device=env.DEVICE)
        rng = torch.Generator(device=env.DEVICE).manual_seed(GLOBAL_SEED)
        positions = 3.0 * torch.rand(
            [1, natoms, 3], dtype=dtype, device=env.DEVICE, generator=rng
        )
        species = torch.tensor(
            [[0, 0, 0, 1, 1]], dtype=torch.int64, device=env.DEVICE
        )
        return positions.view(1, natoms * 3), species, box.view(1, 9)

    def _assert_hessian_finite(self, model_params):
        """Build the model, evaluate the energy Hessian, assert finiteness."""
        model = get_model(copy.deepcopy(model_params)).to(env.DEVICE)
        model.enable_hessian()
        model.requires_hessian("energy")
        coord, atype, cell = self._build_inputs()
        result = model.forward_common(coord, atype, box=cell)
        hessian = to_numpy_array(result["energy_derv_r_derv_r"])
        self.assertTrue(np.isfinite(hessian).all())

    def test_dpa2_direct_dist_hessian_is_finite(self):
        """DPA-2 with direct_dist edge features must yield a finite Hessian."""
        params = copy.deepcopy(model_dpa2)
        params["descriptor"]["repformer"]["direct_dist"] = True
        params["hessian_mode"] = True
        self._assert_hessian_finite(params)

    def test_dpa3_hessian_is_finite(self):
        """DPA-3 in double precision must yield a finite Hessian."""
        params = copy.deepcopy(model_dpa3)
        params["descriptor"]["precision"] = "float64"
        params["fitting_net"]["precision"] = "float64"
        params["hessian_mode"] = True
        self._assert_hessian_finite(params)


if __name__ == "__main__":
unittest.main()
Loading