Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
fail-fast: false
matrix:
include:
- {os: windows-latest, python: "3.11", dask-version: "2025.2.0", name: "Dask 2025.2.0"}
- {os: windows-latest, python: "3.11", dask-version: "2025.12.0", name: "Dask 2025.12.0"}
- {os: windows-latest, python: "3.13", dask-version: "latest", name: "Dask latest"}
- {os: ubuntu-latest, python: "3.11", dask-version: "latest", name: "Dask latest"}
- {os: ubuntu-latest, python: "3.13", dask-version: "latest", name: "Dask latest"}
Expand Down
17 changes: 17 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,23 @@ SpatialData is a data framework that comprises a FAIR storage format and a colle

Please see our publication {cite}`marconatoSpatialDataOpenUniversal2024` for citation and to learn more.

:::{note}
With dask >= 2025.2.0, users may encounter an error as described in [#1064](https://github.com/scverse/spatialdata/issues/1064). While we have implemented fixes in SpatialData,
users may still encounter this error when performing operations on the `Points` data themselves. To prevent it, users can use a context manager we created.

```python
from spatialdata import disable_dask_tune_optimization
import contextlib
...

with disable_dask_tune_optimization() if data.npartitions > 1 else contextlib.nullcontext():
<your operation on points dask dataframe>
```

This will disable dask graph optimization if the dataframe has more than 1 partition and otherwise keep it enabled. This solves
the problem discussed in this [dask issue](https://github.com/dask/dask/issues/12193). We are looking into an upstream fix.
:::

[//]: # "numfocus-fiscal-sponsor-attribution"

spatialdata is part of the scverse® project ([website](https://scverse.org), [governance](https://scverse.org/about/roles)) and is fiscally sponsored by [NumFOCUS](https://numfocus.org/).
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dependencies = [
"annsel>=0.1.2",
"click",
"dask-image",
"dask>=2025.2.0,<2026.1.2",
"dask>=2025.12.0,<2026.1.2",
"distributed<2026.1.2",
"datashader",
"fsspec[s3,http]",
Expand Down
3 changes: 2 additions & 1 deletion src/spatialdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"transformations",
"datasets",
"dataloader",
"disable_dask_tune_optimization",
"concatenate",
"rasterize",
"rasterize_bins",
Expand Down Expand Up @@ -72,5 +73,5 @@
from spatialdata._io._utils import get_dask_backing_files
from spatialdata._io.format import SpatialDataFormatType
from spatialdata._io.io_zarr import read_zarr
from spatialdata._utils import get_pyramid_levels, unpad_raster
from spatialdata._utils import disable_dask_tune_optimization, get_pyramid_levels, unpad_raster
from spatialdata.config import settings
13 changes: 11 additions & 2 deletions src/spatialdata/_core/operations/transform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
import itertools
import warnings
from functools import singledispatch
Expand All @@ -17,6 +18,7 @@

from spatialdata._core.spatialdata import SpatialData
from spatialdata._types import ArrayLike
from spatialdata._utils import disable_dask_tune_optimization
from spatialdata.models import SpatialElement, get_axes_names, get_model
from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM, get_channel_names
from spatialdata.transformations._utils import _get_scale, compute_coordinates, scale_radii
Expand Down Expand Up @@ -439,8 +441,15 @@ def _(
)
axes = get_axes_names(data)
arrays = []
for ax in axes:
arrays.append(data[ax].to_dask_array(lengths=True).reshape(-1, 1))

# Workaround to prevent partition collapse and missing dependency problem for now.
with disable_dask_tune_optimization() if data.npartitions > 1 else contextlib.nullcontext():
for ax in axes:
# TODO We have to pass on the lengths explicitly as automatic determination with dask graph optimization
# leads to collapse of the partitions. However, this causes a missing dependency problem, which for now is
# prevented by setting the optimization to False when performing this operation.
arrays.append(data[ax].to_dask_array(lengths=[len(part) for part in data.partitions]).reshape(-1, 1))

xdata = DataArray(da.concatenate(arrays, axis=1), coords={"points": range(len(data)), "dim": list(axes)})
xtransformed = transformation._transform_coordinates(xdata)
transformed = data.drop(columns=list(axes)).copy()
Expand Down
13 changes: 13 additions & 0 deletions src/spatialdata/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import re
import warnings
from collections.abc import Callable, Generator
from contextlib import contextmanager
from itertools import islice
from typing import Any, TypeVar

import numpy as np
import pandas as pd
from anndata import AnnData
from dask import array as da
from dask import config
from dask.array import Array as DaskArray
from xarray import DataArray, Dataset, DataTree

Expand All @@ -20,6 +22,17 @@
RT = TypeVar("RT")


@contextmanager
def disable_dask_tune_optimization() -> Generator[None, None, None]:
    """Temporarily disable dask graph "tune" optimization.

    Intended for operations on dask dataframes with ``npartitions > 1``, where the
    optimization can collapse partitions and cause missing-dependency errors.

    Yields
    ------
    None
        While the context is active, ``optimization.tune.active`` is ``False``;
        the previous configuration is restored on exit, even if the body raises.
    """
    # dask.config.set used as a context manager snapshots and restores the prior
    # configuration itself. This avoids the KeyError that a direct read of
    # config.config["optimization"]["tune"]["active"] raises when the key is not
    # set, and avoids permanently inserting the key on restore.
    with config.set({"optimization.tune.active": False}):
        yield


def _parse_list_into_array(array: list[Number] | ArrayLike) -> ArrayLike:
if isinstance(array, list):
array = np.array(array)
Expand Down
42 changes: 41 additions & 1 deletion tests/core/operations/test_transform.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import contextlib
import math
import tempfile
from pathlib import Path

import numpy as np
import pytest
from dask import config
from geopandas.testing import geom_almost_equals
from xarray import DataArray, DataTree

from spatialdata import transform
from spatialdata._core.data_extent import are_extents_equal, get_extent
from spatialdata._core.spatialdata import SpatialData
from spatialdata._utils import unpad_raster
from spatialdata._utils import disable_dask_tune_optimization, unpad_raster
from spatialdata.models import Image2DModel, PointsModel, ShapesModel, get_axes_names
from spatialdata.transformations.operations import (
align_elements_using_landmarks,
Expand Down Expand Up @@ -586,6 +588,44 @@ def test_transform_elements_and_entire_spatial_data_object(full_sdata: SpatialDa
_ = full_sdata.transform_to_coordinate_system("my_space", maintain_positioning=maintain_positioning)


def test_transform_points_with_multiple_partitions(full_sdata: SpatialData, tmp_path: str):
    """A multi-partition points element must round-trip through zarr and transform without error."""
    zarr_path = Path(tmp_path) / "tmp.zarr"
    points_in_memory = full_sdata["points_0"].compute()

    # Repartition the element so the multi-partition code path is exercised.
    repartitioned = full_sdata["points_0"].repartition(npartitions=4)
    full_sdata["points_0"] = PointsModel.parse(
        repartitioned,
        transformations={"global": get_transformation(full_sdata["points_0"])},
    )
    assert points_in_memory.equals(full_sdata["points_0"].compute())

    full_sdata.write(zarr_path)
    full_sdata = SpatialData.read(zarr_path)

    # The transform itself just needs to succeed...
    transformed = transform(full_sdata["points_0"], to_coordinate_system="global")
    # ...and the result must still be computable.
    transformed.compute()


@pytest.mark.parametrize(
    "tune,partition",
    [
        (True, None),
        (False, 4),
    ],
)
def test_dask_tune_contextmanager(full_sdata: SpatialData, partition: int | None, tune: bool):
    """The context manager disables tune optimization only for multi-partition dataframes."""
    # Force the multi-partition code path when a partition count is requested.
    if partition is not None:
        full_sdata["points_0"] = PointsModel.parse(
            full_sdata["points_0"].repartition(npartitions=partition),
            transformations={"global": get_transformation(full_sdata["points_0"])},
        )

    points = full_sdata["points_0"]
    manager = disable_dask_tune_optimization() if points.npartitions > 1 else contextlib.nullcontext()
    with manager:
        # Inside the context the tune flag must match the expectation for this case.
        assert config.config["optimization"]["tune"]["active"] is tune


@pytest.mark.parametrize("maintain_positioning", [True, False])
def test_transform_elements_and_entire_spatial_data_object_multi_hop(
full_sdata: SpatialData, maintain_positioning: bool
Expand Down