Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,30 @@
## Latest Changes

### v0.6.0 (2025-02-23)
OpenEquivariance v0.6.0 brings long-needed improvements to the
PyTorch frontend. We strongly encourage all users to upgrade
to PyTorch 2.10 and OEQ v0.6.0.

**Added**:
- OpenEquivariance triggers a build of the CUDA extension module
at `pip` install time and will use this precompiled extension if
the user has PyTorch >=2.10 installed. If PyTorch <2.10 is installed,
the JIT-compiled extension is used instead.
- PyTorch ABI support for C++ backend, using new features in PyTorch
2.10 to support stable, forward-compatible ahead-of-time
extensions.
- Replaced TorchBind classes with a new kernel cache, which greatly improves
flexibility for automatic mixed precision and AOTI compilation. An inference
test in C++ is included.
- `openequivariance_extjax` now carries a version number synchronized with
the main `openequivariance` package; keep the two packages at matching versions.

**Fixed**:
- `TensorProduct` and `TensorProductConv` are now converted correctly when
`.to()` is called on a parent PyTorch module that contains either of them
as a submodule.


### v0.5.4 (2025-02-01)
Improvements to JAX frontend.

Expand Down
10 changes: 10 additions & 0 deletions openequivariance/openequivariance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np

from pathlib import Path
import warnings
from importlib.metadata import version

from openequivariance.core.e3nn_lite import (
Expand Down Expand Up @@ -80,6 +81,15 @@ def torch_ext_so_path():
try:
import openequivariance_extjax
import openequivariance.jax as jax

# TODO-someday: enable
# extjax_version = version("openequivariance_extjax")
# if extjax_version != __version__:
# warnings.warn(
# f"openequivariance_extjax version {extjax_version} does not match "
# f"openequivariance version {__version__}. Ensure both versions match."
# )

except Exception as e:
error = e

Expand Down
29 changes: 27 additions & 2 deletions openequivariance/openequivariance/_torch/TensorProduct.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
from openequivariance import TPProblem
from openequivariance._torch import extlib
import torch
from openequivariance.core.utils import torch_to_oeq_dtype
from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum
from openequivariance.benchmark.logging_utils import getLogger
from openequivariance._torch.utils import reorder_torch, string_to_tensor
from openequivariance._torch.utils import (
reorder_torch,
string_to_tensor,
enum_to_torch_dtype,
)
from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin

import numpy as np
Expand Down Expand Up @@ -66,6 +70,27 @@ def to(self, *args, **kwargs):
torch.nn.Module.to(self, *args, **kwargs)
return self

def _apply(self, fn, recurse=True):
    """Intercept ``torch.nn.Module._apply`` (the mechanism behind ``.to()``,
    ``.cuda()``, ``.half()``, etc.) so that dtype conversions requested on a
    parent module are propagated to this module's kernel configuration, not
    just to its registered tensors.

    A throwaway scalar tensor is passed through ``fn`` to discover whether
    the conversion changes the floating-point dtype; if so, ``self.to(...)``
    is invoked to rebuild internal state for the new dtype. The ``_applying``
    flag guards against infinite recursion, since ``self.to`` itself routes
    back through ``_apply``.
    """
    # Re-entrant call triggered by our own self.to() below: fall through to
    # the default torch.nn.Module behavior.
    if getattr(self, "_applying", False):
        return super()._apply(fn, recurse)

    problem: TPProblem = self.input_args["problem"]
    irrep_dtype = problem.irrep_dtype

    # irrep_dtype may be stored as a numpy dtype or already as an enum
    # value; normalize to the enum before mapping to a torch dtype.
    if irrep_dtype in dtype_to_enum:
        irrep_dtype = dtype_to_enum[irrep_dtype]

    current_dtype = enum_to_torch_dtype[irrep_dtype]

    # Probe fn with a dummy scalar to learn what dtype it produces.
    dummy = torch.tensor(0.0, dtype=current_dtype)
    result = fn(dummy)

    if result.dtype != current_dtype:
        # try/finally ensures an exception inside self.to() cannot leave
        # _applying stuck at True, which would silently disable all future
        # dtype propagation for this module.
        self._applying = True
        try:
            self.to(result.dtype)
        finally:
            self._applying = False

    return super()._apply(fn, recurse)

def __getstate__(self):
return self.input_args

Expand Down
24 changes: 23 additions & 1 deletion openequivariance/openequivariance/_torch/TensorProductConv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
from openequivariance.core.LoopUnrollConv import LoopUnrollConv
from openequivariance._torch.TensorProduct import TensorProduct
from openequivariance import TPProblem
from openequivariance.core.utils import torch_to_oeq_dtype
from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum
from openequivariance._torch.utils import (
reorder_torch,
string_to_tensor,
enum_to_torch_dtype,
)

from openequivariance.benchmark.logging_utils import getLogger
Expand Down Expand Up @@ -109,6 +110,27 @@ def to(self, *args, **kwargs):
torch.nn.Module.to(self, *args, **kwargs)
return self

def _apply(self, fn, recurse=True):
    """Intercept ``torch.nn.Module._apply`` (the mechanism behind ``.to()``,
    ``.cuda()``, ``.half()``, etc.) so that dtype conversions requested on a
    parent module are propagated to this module's kernel configuration, not
    just to its registered tensors.

    A throwaway scalar tensor is passed through ``fn`` to discover whether
    the conversion changes the floating-point dtype; if so, ``self.to(...)``
    is invoked to rebuild internal state for the new dtype. The ``_applying``
    flag guards against infinite recursion, since ``self.to`` itself routes
    back through ``_apply``.
    """
    # Re-entrant call triggered by our own self.to() below: fall through to
    # the default torch.nn.Module behavior.
    if getattr(self, "_applying", False):
        return super()._apply(fn, recurse)

    problem: TPProblem = self.input_args["problem"]
    irrep_dtype = problem.irrep_dtype

    # irrep_dtype may be stored as a numpy dtype or already as an enum
    # value; normalize to the enum before mapping to a torch dtype.
    if irrep_dtype in dtype_to_enum:
        irrep_dtype = dtype_to_enum[irrep_dtype]

    current_dtype = enum_to_torch_dtype[irrep_dtype]

    # Probe fn with a dummy scalar to learn what dtype it produces.
    dummy = torch.tensor(0.0, dtype=current_dtype)
    result = fn(dummy)

    if result.dtype != current_dtype:
        # try/finally ensures an exception inside self.to() cannot leave
        # _applying stuck at True, which would silently disable all future
        # dtype propagation for this module.
        self._applying = True
        try:
            self.to(result.dtype)
        finally:
            self._applying = False

    return super()._apply(fn, recurse)

def __getstate__(self):
return self.input_args

Expand Down
95 changes: 87 additions & 8 deletions tests/batch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
import torch


@pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="module")
def dtype(request):
    """Numpy floating-point dtype under test, parametrized over F32 and F64.

    Module-scoped so fixtures derived from it are built once per precision
    per test module.
    """
    return request.param


class TPCorrectness:
def thresh(self, direction):
    """Return the error tolerance for the given pass direction
    ("fwd", "bwd", or "double_bwd")."""
    tolerances = {
        "fwd": 1e-5,
        "bwd": 3e-4,
        "double_bwd": 3e-4,
    }
    return tolerances[direction]
Expand All @@ -31,18 +36,10 @@ def check_result(self, result, fieldname):
f"{fieldname} observed error={error:.5f} >= {thresh}"
)

@pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="class")
def dtype(self, request):
return request.param

@pytest.fixture(scope="class")
def extra_tp_constructor_args(self):
    """Extra keyword arguments forwarded to the tensor-product constructor;
    empty by default."""
    return {}

@pytest.fixture(scope="class")
def with_jax(self, request):
return request.config.getoption("--jax")

@pytest.fixture(scope="class")
def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
cls = oeq.TensorProduct
Expand Down Expand Up @@ -274,3 +271,85 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
}
tp.to(switch_map[problem.irrep_dtype])
return tp, tp.config


class TestTorchToSubmodule:
    """Verify a TensorProduct nested inside another torch.nn.Module is
    converted correctly when ``.to()`` is invoked on the parent."""

    @pytest.fixture(scope="class")
    def parent_module_and_problem(self, dtype, with_jax):
        """Build a wrapper nn.Module holding a TensorProduct for one MACE
        problem at the parametrized precision."""
        if with_jax:
            pytest.skip("N/A for JAX")

        problem = mace_problems()[0].clone()
        problem.irrep_dtype, problem.weight_dtype = dtype, dtype

        class ParentModule(torch.nn.Module):
            def __init__(self, problem):
                super().__init__()
                self.tp = oeq.TensorProduct(problem)

            def forward(self, x, y, w):
                return self.tp(x, y, w)

        return ParentModule(problem), problem

    def _problem_dtype(self, problem):
        """Torch dtype matching the problem's numpy irrep dtype."""
        if problem.irrep_dtype == np.float32:
            return torch.float32
        return torch.float64

    def _make_inputs(self, problem, batch_size, rng, dtype, device):
        """Draw random in1 / in2 / weights tensors with the given dtype on
        the given device. Weight shape depends on problem.shared_weights."""

        def draw(shape):
            return torch.tensor(rng.uniform(size=shape), dtype=dtype, device=device)

        in1 = draw((batch_size, problem.irreps_in1.dim))
        in2 = draw((batch_size, problem.irreps_in2.dim))
        if problem.shared_weights:
            weights = draw((problem.weight_numel,))
        else:
            weights = draw((batch_size, problem.weight_numel))
        return in1, in2, weights

    def test_submodule_dtype_conversion(self, parent_module_and_problem):
        """Calling .to() on the parent module must switch the submodule's
        compute dtype, so outputs track the converted inputs."""
        parent, problem = parent_module_and_problem

        rng = np.random.default_rng(12345)
        batch_size, device = 10, "cuda"

        start_dtype = self._problem_dtype(problem)
        in1, in2, weights = self._make_inputs(
            problem, batch_size, rng, start_dtype, device
        )

        out_before = parent(in1, in2, weights)
        assert out_before.dtype == in1.dtype, (
            f"Expected output dtype {in1.dtype}, got {out_before.dtype}"
        )

        # Flip to the opposite precision via the parent module and re-run.
        target_dtype = {
            np.float32: torch.float64,
            np.float64: torch.float32,
        }[problem.irrep_dtype]
        parent.to(target_dtype)

        in1_b, in2_b, weights_b = self._make_inputs(
            problem, batch_size, rng, target_dtype, device
        )
        out_after = parent(in1_b, in2_b, weights_b)
        assert out_after.dtype == target_dtype, (
            f"Expected output dtype {target_dtype}, got {out_after.dtype}"
        )
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pytest

os.environ["JAX_ENABLE_X64"] = "True"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
Expand All @@ -12,3 +13,8 @@ def pytest_addoption(parser):
default=False,
help="Test the JAX frontend instead of PyTorch",
)


@pytest.fixture(scope="session")
def with_jax(request):
    """True when the test session was invoked with --jax, selecting the JAX
    frontend instead of PyTorch."""
    return request.config.getoption("--jax")
Loading