Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,16 @@ def get_has_efield(self) -> bool:
"""Check if the model has efield."""
return False

def get_has_spin(self) -> bool:
"""Check if the model has spin atom types."""
return hasattr(self.dp, "spin")

def get_use_spin(self) -> list[bool]:
"""Get the per-type spin usage of this model."""
if hasattr(self.dp, "spin"):
return self.dp.spin.use_spin.tolist()
return []

def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model."""
return 0
Expand Down
8 changes: 8 additions & 0 deletions deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,14 @@
model_type = data.get("type", "standard")
if model_type == "standard":
model_type = data.get("fitting", {}).get("type", "ener")
if model_type == "spin_ener":
# SpinModel is not a BaseModel subclass and cannot be
# registered via the plugin registry. Dispatch directly.
from deepmd.dpmodel.model.spin_model import (
SpinModel,
)

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.dpmodel.model.spin_model
begins an import cycle.
Comment thread
wanghan-iapcm marked this conversation as resolved.
Dismissed

return SpinModel.deserialize(data)
return cls.get_class_by_type(model_type).deserialize(data)
raise NotImplementedError(f"Not implemented in class {cls.__name__}")

Expand Down
7 changes: 5 additions & 2 deletions deepmd/dpmodel/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,12 +549,15 @@ def __getattr__(self, name: str) -> Any:

def serialize(self) -> dict:
return {
"type": "spin_ener",
"backbone_model": self.backbone_model.serialize(),
"spin": self.spin.serialize(),
}

@classmethod
def deserialize(cls, data: dict) -> "SpinModel":
data = data.copy()
data.pop("type", None)
backbone_model_obj = make_model(
DPAtomicModel, T_Bases=(NativeOP, BaseModel)
).deserialize(data["backbone_model"])
Expand Down Expand Up @@ -646,7 +649,7 @@ def call_common(
) = self.process_spin_output(
atype,
model_ret[f"{var_name}_derv_c"],
add_mag=False,
add_mag=True,
virtual_scale=False,
)
# Always compute mask_mag from atom types (even when forces are unavailable)
Expand Down Expand Up @@ -823,7 +826,7 @@ def call_common_lower(
extended_atype,
model_ret[f"{var_name}_derv_c"],
nloc,
add_mag=False,
add_mag=True,
virtual_scale=False,
)
# Always compute mask_mag from atom types (even when forces are unavailable)
Expand Down
35 changes: 35 additions & 0 deletions deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ModelOutputDef,
OutputVariableDef,
get_deriv_name,
get_deriv_name_mag,
get_hessian_name,
get_reduce_name,
)
Expand Down Expand Up @@ -128,6 +129,21 @@ def communicate_extended_output(
model_ret[kk_derv_r],
)
new_ret[kk_derv_r] = force
if vdef.magnetic:
kk_derv_r_mag = get_deriv_name_mag(kk)[0]
if model_ret.get(kk_derv_r_mag) is not None:
force_mag = xp.zeros(
vldims + derv_r_ext_dims,
dtype=vv.dtype,
device=device,
)
force_mag = xp_scatter_sum(
force_mag,
1,
mapping,
model_ret[kk_derv_r_mag],
)
new_ret[kk_derv_r_mag] = force_mag
else:
# name holders
new_ret[kk_derv_r] = None
Expand Down Expand Up @@ -235,10 +251,29 @@ def communicate_extended_output(
)
new_ret[kk_derv_c] = virial
new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1)
if vdef.magnetic:
kk_derv_c_mag = get_deriv_name_mag(kk)[1]
if model_ret.get(kk_derv_c_mag) is not None:
virial_mag = xp.zeros(
vldims + derv_c_ext_dims,
dtype=vv.dtype,
device=device,
)
virial_mag = xp_scatter_sum(
virial_mag,
1,
mapping,
model_ret[kk_derv_c_mag],
)
new_ret[kk_derv_c_mag] = virial_mag
else:
new_ret[kk_derv_c] = None
new_ret[kk_derv_c + "_redu"] = None
if not do_atomic_virial:
# pop atomic virial, because it is not correctly calculated.
new_ret.pop(kk_derv_c)
# Slice mask_mag from extended to local atoms
if "mask_mag" in model_ret:
nloc = new_ret[next(iter(model_output_def.keys_outp()))].shape[1]
new_ret["mask_mag"] = model_ret["mask_mag"][:, :nloc]
return new_ret
23 changes: 23 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,17 @@ def get_has_spin(self) -> bool:
"""Check if the model has spin atom types."""
return False

def get_use_spin(self) -> list[bool]:
"""Get the per-type spin usage of this model.

Returns
-------
list[bool]
A list of bool indicating whether each atom type uses spin.
Empty list if the model does not have spin.
"""
return []

def get_has_hessian(self) -> bool:
"""Check if the model has hessian."""
return False
Expand Down Expand Up @@ -705,6 +716,18 @@ def has_spin(self) -> bool:
"""Check if the model has spin."""
return self.deep_eval.get_has_spin()

@property
def use_spin(self) -> list[bool]:
"""Get the per-type spin usage of this model.

Returns
-------
list[bool]
A list of bool indicating whether each atom type uses spin.
Empty list if the model does not have spin.
"""
return self.deep_eval.get_use_spin()

@property
def has_hessian(self) -> bool:
"""Check if the model has hessian."""
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,13 @@ def get_has_spin(self) -> bool:
"""Check if the model has spin atom types."""
return self._has_spin

def get_use_spin(self) -> list[bool]:
"""Get the per-type spin usage of this model."""
if self._has_spin:
model = self.dp.model["Default"]
return model.spin.use_spin.tolist()
return []
Comment thread
wanghan-iapcm marked this conversation as resolved.

def get_has_hessian(self) -> bool:
"""Check if the model has hessian."""
return self._has_hessian
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,13 @@ def get_has_spin(self) -> bool:
"""Check if the model has spin atom types."""
return self._has_spin

def get_use_spin(self) -> list[bool]:
"""Get the per-type spin usage of this model."""
if self._has_spin:
model = self.dp.model["Default"]
return model.spin.use_spin.tolist()
return []

def get_has_hessian(self) -> bool:
"""Check if the model has hessian."""
return self._has_hessian
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,12 +637,15 @@ def forward_common_lower(

def serialize(self) -> dict:
return {
"type": "spin_ener",
"backbone_model": self.backbone_model.serialize(),
"spin": self.spin.serialize(),
}

@classmethod
def deserialize(cls, data: dict[str, Any]) -> "SpinModel":
data = data.copy()
data.pop("type", None)
backbone_model_obj = make_model(DPAtomicModel).deserialize(
data["backbone_model"]
)
Expand Down
10 changes: 9 additions & 1 deletion deepmd/pt/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,15 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
"""
if not model_file.endswith(".pth"):
raise ValueError("PyTorch backend only supports converting .pth file")
model = BaseModel.deserialize(data["model"])
model_data = data["model"]
if model_data.get("type") == "spin_ener":
from deepmd.pt.model.model.spin_model import (
SpinEnergyModel,
)

model = SpinEnergyModel.deserialize(model_data)
else:
model = BaseModel.deserialize(model_data)
# JIT will happy in this way...
model.model_def_script = json.dumps(data["model_def_script"])
if "min_nbor_dist" in data.get("@variables", {}):
Expand Down
Loading
Loading