diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index 7d980a60d..5aa287b0c 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -2,7 +2,7 @@ import warnings from html import escape -from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping, Optional +from typing import TYPE_CHECKING, Any, Callable, Hashable, Literal, Mapping, Optional from warnings import warn import cartopy.crs as ccrs @@ -14,6 +14,7 @@ from xarray.core.utils import UncachedAccessor import uxarray +from uxarray.constants import GRID_DIMS from uxarray.core.aggregation import _uxda_grid_aggregate from uxarray.core.gradient import ( _calculate_edge_face_difference, @@ -1760,6 +1761,7 @@ def isel( ValueError If more than one grid dimension is selected and `ignore_grid=False`. """ + from uxarray.core.dataarray import UxDataArray from uxarray.core.utils import _validate_indexers indexers, grid_dims = _validate_indexers( @@ -1959,6 +1961,120 @@ def get_dual(self): return uxda + def neighborhood_filter( + self, + func: Callable = np.mean, + r: float = 1.0, + ) -> UxDataArray: + """Apply neighborhood filter + Parameters: + ----------- + func: Callable, default=np.mean + Apply this function to neighborhood + r : float, default=1. + Radius of neighborhood. For spherical coordinates, the radius is in units of degrees, + and for cartesian coordinates, the radius is in meters. + Returns: + -------- + destination_data : np.ndarray + Filtered data. + """ + + if self._face_centered(): + data_mapping = "face centers" + elif self._node_centered(): + data_mapping = "nodes" + elif self._edge_centered(): + data_mapping = "edge centers" + else: + raise ValueError( + "Data_mapping is not face, node, or edge. Could not define data_mapping." + ) + + # reconstruct because the cached tree could be built from + # face centers, edge centers or nodes. + tree = self.uxgrid.get_ball_tree(coordinates=data_mapping, reconstruct=True) + + coordinate_system = tree.coordinate_system + + if coordinate_system == "spherical": + if data_mapping == "nodes": + lon, lat = ( + self.uxgrid.node_lon.values, + self.uxgrid.node_lat.values, + ) + elif data_mapping == "face centers": + lon, lat = ( + self.uxgrid.face_lon.values, + self.uxgrid.face_lat.values, + ) + elif data_mapping == "edge centers": + lon, lat = ( + self.uxgrid.edge_lon.values, + self.uxgrid.edge_lat.values, + ) + else: + raise ValueError( + f"Invalid data_mapping. Expected 'nodes', 'edge centers', or 'face centers', " + f"but received: {data_mapping}" + ) + + dest_coords = np.vstack((lon, lat)).T + + elif coordinate_system == "cartesian": + if data_mapping == "nodes": + x, y, z = ( + self.uxgrid.node_x.values, + self.uxgrid.node_y.values, + self.uxgrid.node_z.values, + ) + elif data_mapping == "face centers": + x, y, z = ( + self.uxgrid.face_x.values, + self.uxgrid.face_y.values, + self.uxgrid.face_z.values, + ) + elif data_mapping == "edge centers": + x, y, z = ( + self.uxgrid.edge_x.values, + self.uxgrid.edge_y.values, + self.uxgrid.edge_z.values, + ) + else: + raise ValueError( + f"Invalid data_mapping. Expected 'nodes', 'edge centers', or 'face centers', " + f"but received: {data_mapping}" + ) + + dest_coords = np.vstack((x, y, z)).T + + else: + raise ValueError( + f"Invalid coordinate_system. Expected either 'spherical' or 'cartesian', but received {coordinate_system}" + ) + + neighbor_indices = tree.query_radius(dest_coords, r=r) + + # Construct numpy array for filtered variable. + destination_data = np.empty(self.data.shape) + + # Assert last dimension is a GRID dimension. + assert self.dims[-1] in GRID_DIMS, ( + f"expected last dimension of uxDataArray {self.data.dims[-1]} " + f"to be one of {GRID_DIMS}" + ) + # Apply function to indices on last axis. + for i, idx in enumerate(neighbor_indices): + if len(idx): + destination_data[..., i] = func(self.data[..., idx]) + + # Construct UxDataArray for filtered variable. + uxda_filter = self._copy() + + uxda_filter.data = destination_data + + return uxda_filter + def __getattribute__(self, name): """Intercept accessor method calls to return Ux-aware accessors.""" # Lazy import to avoid circular imports diff --git a/uxarray/core/dataset.py b/uxarray/core/dataset.py index 3dc1feffc..9e723943e 100644 --- a/uxarray/core/dataset.py +++ b/uxarray/core/dataset.py @@ -3,7 +3,7 @@ import os import sys from html import escape -from typing import IO, Any, Mapping +from typing import IO, Any, Callable, Mapping from warnings import warn import numpy as np @@ -13,6 +13,7 @@ from xarray.core.utils import UncachedAccessor import uxarray +from uxarray.constants import GRID_DIMS from uxarray.core.dataarray import UxDataArray from uxarray.core.utils import _map_dims_to_ugrid, _open_dataset_with_fallback from uxarray.formatting_html import dataset_repr @@ -610,6 +611,41 @@ def to_array(self) -> UxDataArray: xarr = super().to_array() return UxDataArray(xarr, uxgrid=self.uxgrid) + def neighborhood_filter( + self, + func: Callable = np.mean, + r: float = 1.0, + ): + """Neighborhood function implementation for ``UxDataset``. + Parameters + --------- + func : Callable = np.mean + Apply this function to neighborhood + r : float, default=1. + Radius of neighborhood. For spherical coordinates, the radius is in units of degrees, + and for cartesian coordinates, the radius is in meters. + """ + + destination_uxds = self._copy() + # Loop through uxDataArrays in uxDataset + for var_name in self.data_vars: + uxda = self[var_name] + + # Skip if uxDataArray has no GRID dimension. + grid_dims = [dim for dim in uxda.dims if dim in GRID_DIMS] + if len(grid_dims) == 0: + continue + + # Put GRID dimension last for UxDataArray.neighborhood_filter. + remember_dim_order = uxda.dims + uxda = uxda.transpose(..., grid_dims[0]) + # Filter uxDataArray. + uxda = uxda.neighborhood_filter(func, r) + # Restore old dimension order. + destination_uxds[var_name] = uxda.transpose(*remember_dim_order) + + return destination_uxds + def to_xarray(self, grid_format: str = "UGRID") -> xr.Dataset: """ Converts a ``ux.UXDataset`` to a ``xr.Dataset``.