Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
75d832e
Neighborhood filter
ahijevyc Sep 7, 2024
b850bc4
Merge remote-tracking branch 'upstream/main' into neighborhood_filter
ahijevyc Sep 7, 2024
8ec0193
ruff recommendations
ahijevyc Sep 9, 2024
5605949
added Callable to Type checking
ahijevyc Sep 9, 2024
70ba961
Merge branch 'main' into ahijevyc/neighborhood_filter
ahijevyc Sep 9, 2024
0c7bc1e
np.vstack().T faster than np.c
ahijevyc Sep 9, 2024
d6d8a33
Fix some comments
ahijevyc Sep 9, 2024
47b9cda
Merge branch 'main' into ahijevyc/neighborhood_filter
ahijevyc Sep 17, 2024
f75db0d
Merge branch 'main' into ahijevyc/neighborhood_filter
aaronzedwick Oct 1, 2024
bddd2fa
Merge branch 'main' into ahijevyc/neighborhood_filter
ahijevyc Oct 29, 2024
a0b6361
Merge branch 'main' into ahijevyc/neighborhood_filter
ahijevyc Mar 17, 2025
6c59af7
missing imports
ahijevyc Mar 17, 2025
67d0c11
Merge branch 'main' into ahijevyc/neighborhood_filter
philipc2 Mar 18, 2025
8979745
Merge branch 'main' into ahijevyc/neighborhood_filter
ahijevyc Mar 19, 2025
a8875cd
Merge branch 'main' into ahijevyc/neighborhood_filter
ahijevyc May 28, 2025
f4af498
Merge branch 'UXARRAY:main' into ahijevyc/neighborhood_filter
ahijevyc Aug 4, 2025
45407aa
Merge branch 'main' into ahijevyc/neighborhood_filter
erogluorhan Sep 3, 2025
9fa100c
Merge branch 'main' into ahijevyc/neighborhood_filter
ahijevyc Jan 21, 2026
56e7821
Merge branch 'main' into ahijevyc/neighborhood_filter
erogluorhan Feb 26, 2026
14c37b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2026
a37b88f
Update dataset.py to address pre-commit errors
erogluorhan Feb 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 117 additions & 1 deletion uxarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from html import escape
from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping, Optional
from typing import TYPE_CHECKING, Any, Callable, Hashable, Literal, Mapping, Optional
from warnings import warn

import cartopy.crs as ccrs
Expand All @@ -14,6 +14,7 @@
from xarray.core.utils import UncachedAccessor

import uxarray
from uxarray.constants import GRID_DIMS
from uxarray.core.aggregation import _uxda_grid_aggregate
from uxarray.core.gradient import (
_calculate_edge_face_difference,
Expand Down Expand Up @@ -1760,6 +1761,7 @@ def isel(
ValueError
If more than one grid dimension is selected and `ignore_grid=False`.
"""
from uxarray.core.dataarray import UxDataArray
from uxarray.core.utils import _validate_indexers

indexers, grid_dims = _validate_indexers(
Expand Down Expand Up @@ -1959,6 +1961,120 @@ def get_dual(self):

return uxda

def neighborhood_filter(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implementation looks great! May we move the bulk of the logic into the uxarray.grid.neighbors module and call that helper from here?

We can keep the data-mapping checks here, and anything related to constructing and returining the final data array but the bulk of the computations would go inside a helper in the module mentioned above.

Copy link
Collaborator Author

@ahijevyc ahijevyc Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have to think about how to do that, but I am happy to defer to you.

self,
func: Callable = np.mean,
r: float = 1.0,
) -> UxDataArray:
"""Apply neighborhood filter
Parameters:
-----------
func: Callable, default=np.mean
Apply this function to neighborhood
r : float, default=1.
Radius of neighborhood. For spherical coordinates, the radius is in units of degrees,
and for cartesian coordinates, the radius is in meters.
Returns:
--------
destination_data : np.ndarray
Filtered data.
"""

if self._face_centered():
data_mapping = "face centers"
elif self._node_centered():
data_mapping = "nodes"
elif self._edge_centered():
data_mapping = "edge centers"
else:
raise ValueError(
"Data_mapping is not face, node, or edge. Could not define data_mapping."
)

# reconstruct because the cached tree could be built from
# face centers, edge centers or nodes.
tree = self.uxgrid.get_ball_tree(coordinates=data_mapping, reconstruct=True)
Comment on lines +1995 to +1996
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aaronzedwick

We should probably fix this logic in get_ball_tree(), since we shouldn't need to manually set reconstruct=False

        if self._ball_tree is None or reconstruct:
            self._ball_tree = BallTree(
                self,
                coordinates=coordinates,
                distance_metric=distance_metric,
                coordinate_system=coordinate_system,
                reconstruct=reconstruct,
            )
        else:
            if coordinates != self._ball_tree._coordinates:
                self._ball_tree.coordinates = coordinates

The coordinates != self._ball_tree._coordinates check should be included in the first if

Copy link
Collaborator Author

@ahijevyc ahijevyc Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. So, move the coordinates check to the if-clause like this?

                if (
                    self._ball_tree is None
                    or coordinates != self._ball_tree._coordinates
                    or reconstruct
                ):

                    self._ball_tree = BallTree(
                        self,
                        coordinates=coordinates,
                        distance_metric=distance_metric,
                        coordinate_system=coordinate_system,
                        reconstruct=reconstruct,
                    )

What if the coordinate_system is different? Would that also require a newly constructed tree?

Copy link
Collaborator Author

@ahijevyc ahijevyc Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whatever logic is fixed in Grid.get_ball_tree should also be applied to Grid.get_kdtree.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking coordinate system also (coordinate_system is not a hidden variable of _ball_tree; it has no underscore):

                if (
                    self._ball_tree is None
                    or coordinates != self._ball_tree._coordinates
                    or coordinate_system != self._ball_tree.coordinate_system
                    or reconstruct
                ):

                    self._ball_tree = BallTree(
                        self,
                        coordinates=coordinates,
                        distance_metric=distance_metric,
                        coordinate_system=coordinate_system,
                        reconstruct=reconstruct,
                    )


coordinate_system = tree.coordinate_system

if coordinate_system == "spherical":
if data_mapping == "nodes":
lon, lat = (
self.uxgrid.node_lon.values,
self.uxgrid.node_lat.values,
)
elif data_mapping == "face centers":
lon, lat = (
self.uxgrid.face_lon.values,
self.uxgrid.face_lat.values,
)
elif data_mapping == "edge centers":
lon, lat = (
self.uxgrid.edge_lon.values,
self.uxgrid.edge_lat.values,
)
else:
raise ValueError(
f"Invalid data_mapping. Expected 'nodes', 'edge centers', or 'face centers', "
f"but received: {data_mapping}"
)

dest_coords = np.vstack((lon, lat)).T

elif coordinate_system == "cartesian":
if data_mapping == "nodes":
x, y, z = (
self.uxgrid.node_x.values,
self.uxgrid.node_y.values,
self.uxgrid.node_z.values,
)
elif data_mapping == "face centers":
x, y, z = (
self.uxgrid.face_x.values,
self.uxgrid.face_y.values,
self.uxgrid.face_z.values,
)
elif data_mapping == "edge centers":
x, y, z = (
self.uxgrid.edge_x.values,
self.uxgrid.edge_y.values,
self.uxgrid.edge_z.values,
)
else:
raise ValueError(
f"Invalid data_mapping. Expected 'nodes', 'edge centers', or 'face centers', "
f"but received: {data_mapping}"
)
Comment on lines +1997 to +2047
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use #974 's new remap.utils._remap_grid_parse instead of this code block.


dest_coords = np.vstack((x, y, z)).T

else:
raise ValueError(
f"Invalid coordinate_system. Expected either 'spherical' or 'cartesian', but received {coordinate_system}"
)

neighbor_indices = tree.query_radius(dest_coords, r=r)

# Construct numpy array for filtered variable.
destination_data = np.empty(self.data.shape)

# Assert last dimension is a GRID dimension.
assert self.dims[-1] in GRID_DIMS, (
f"expected last dimension of uxDataArray {self.data.dims[-1]} "
f"to be one of {GRID_DIMS}"
)
# Apply function to indices on last axis.
for i, idx in enumerate(neighbor_indices):
if len(idx):
destination_data[..., i] = func(self.data[..., idx])

# Construct UxDataArray for filtered variable.
uxda_filter = self._copy()

uxda_filter.data = destination_data

return uxda_filter

def __getattribute__(self, name):
"""Intercept accessor method calls to return Ux-aware accessors."""
# Lazy import to avoid circular imports
Expand Down
38 changes: 37 additions & 1 deletion uxarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import sys
from html import escape
from typing import IO, Any, Mapping
from typing import IO, Any, Callable, Mapping
from warnings import warn

import numpy as np
Expand All @@ -13,6 +13,7 @@
from xarray.core.utils import UncachedAccessor

import uxarray
from uxarray.constants import GRID_DIMS
from uxarray.core.dataarray import UxDataArray
from uxarray.core.utils import _map_dims_to_ugrid, _open_dataset_with_fallback
from uxarray.formatting_html import dataset_repr
Expand Down Expand Up @@ -610,6 +611,41 @@ def to_array(self) -> UxDataArray:
xarr = super().to_array()
return UxDataArray(xarr, uxgrid=self.uxgrid)

def neighborhood_filter(
self,
func: Callable = np.mean,
r: float = 1.0,
):
"""Neighborhood function implementation for ``UxDataset``.
Parameters
---------
func : Callable = np.mean
Apply this function to neighborhood
r : float, default=1.
Radius of neighborhood. For spherical coordinates, the radius is in units of degrees,
and for cartesian coordinates, the radius is in meters.
"""

destination_uxds = self._copy()
# Loop through uxDataArrays in uxDataset
for var_name in self.data_vars:
uxda = self[var_name]

# Skip if uxDataArray has no GRID dimension.
grid_dims = [dim for dim in uxda.dims if dim in GRID_DIMS]
if len(grid_dims) == 0:
continue

# Put GRID dimension last for UxDataArray.neighborhood_filter.
remember_dim_order = uxda.dims
uxda = uxda.transpose(..., grid_dims[0])
# Filter uxDataArray.
uxda = uxda.neighborhood_filter(func, r)
# Restore old dimension order.
destination_uxds[var_name] = uxda.transpose(*remember_dim_order)

return destination_uxds

def to_xarray(self, grid_format: str = "UGRID") -> xr.Dataset:
"""
Converts a ``ux.UXDataset`` to a ``xr.Dataset``.
Expand Down
Loading