"""Check that custom and built-in dask collections satisfy ``dask.typing`` protocols.

Small wrapper collections are defined below; each forwards most of the dask
collection interface to an existing collection (``based_on``). The tests check
them with ``isinstance`` against ``DaskCollection`` / ``HLGDaskCollection``
and compute them through functions annotated with those protocols.
"""

from __future__ import annotations

from collections.abc import Hashable, Mapping, Sequence
from typing import Any

import pytest

import dask
import dask.threaded
from dask.base import DaskMethodsMixin, dont_optimize, tokenize
from dask.context import globalmethod
from dask.delayed import Delayed, delayed
from dask.typing import (
    DaskCollection,
    HLGDaskCollection,
    PostComputeCallable,
    PostPersistCallable,
)

try:
    from IPython.display import DisplayObject
except ImportError:
    # IPython is optional; fall back to ``Any`` so the return annotation of
    # ``Inheriting.visualize`` still names something.
    DisplayObject = Any

da = pytest.importorskip("dask.array")
db = pytest.importorskip("dask.bag")
dds = pytest.importorskip("dask.datasets")
dd = pytest.importorskip("dask.dataframe")


def finalize(x: Sequence[Any]) -> Any:
    """Post-compute callback used by the wrappers: return the first result."""
    return x[0]


def get1(dsk: Mapping, keys: Sequence[Hashable] | Hashable, **kwargs: Any) -> Any:
    """Scheduler for ``HLGCollection``; delegates to ``dask.threaded.get``."""
    return dask.threaded.get(dsk, keys, **kwargs)


def get2(dsk: Mapping, keys: Sequence[Hashable] | Hashable, **kwargs: Any) -> Any:
    """Scheduler for ``NotHLGCollection``; delegates to ``dask.get``."""
    return dask.get(dsk, keys, **kwargs)


class Inheriting(DaskCollection):
    """Collection that subclasses the ``DaskCollection`` protocol explicitly.

    Graph, keys, post-persist and tokenization are forwarded to ``based_on``;
    ``compute``/``persist``/``visualize`` are implemented by hand rather than
    obtained from ``DaskMethodsMixin``.
    """

    def __init__(self, based_on: DaskCollection) -> None:
        self.based_on = based_on

    def __dask_graph__(self) -> Mapping:
        return self.based_on.__dask_graph__()

    def __dask_keys__(self) -> list[Hashable]:
        return self.based_on.__dask_keys__()

    def __dask_postcompute__(self) -> tuple[PostComputeCallable, tuple]:
        return finalize, ()

    def __dask_postpersist__(self) -> tuple[PostPersistCallable, tuple]:
        return self.based_on.__dask_postpersist__()

    def __dask_tokenize__(self) -> Hashable:
        return tokenize(self.based_on)

    __dask_scheduler__ = staticmethod(dask.threaded.get)

    # NOTE(review): this reuses the "hlgcollection_optim" config key even though
    # this class is not itself an HLG collection — presumably copy-paste; confirm.
    __dask_optimize__ = globalmethod(
        dont_optimize,
        key="hlgcollection_optim",
        falsey=dont_optimize,
    )

    def compute(self, **kwargs) -> Any:
        """Compute this collection through ``dask.compute``."""
        return dask.compute(self, **kwargs)

    def persist(self, **kwargs) -> Inheriting:
        """Persist the wrapped collection and re-wrap the result."""
        return Inheriting(self.based_on.persist(**kwargs))

    def visualize(
        self,
        filename: str = "mydask",
        format: str | None = None,
        optimize_graph: bool = False,
        **kwargs: Any,
    ) -> DisplayObject | None:
        """Render the task graph through ``dask.visualize``."""
        return dask.visualize(
            self,
            filename=filename,
            format=format,
            optimize_graph=optimize_graph,
            **kwargs,
        )


class HLGCollection(DaskMethodsMixin):
    """Wrapper that also exposes ``__dask_layers__``.

    It is expected to satisfy both ``DaskCollection`` and ``HLGDaskCollection``.
    """

    def __init__(self, based_on: HLGDaskCollection) -> None:
        self.based_on = based_on

    def __dask_graph__(self) -> Mapping:
        return self.based_on.__dask_graph__()

    def __dask_layers__(self) -> Sequence[str]:
        return self.based_on.__dask_layers__()

    def __dask_keys__(self) -> list[Hashable]:
        return self.based_on.__dask_keys__()

    def __dask_postcompute__(self) -> tuple[PostComputeCallable, tuple]:
        return finalize, ()

    def __dask_postpersist__(self) -> tuple[PostPersistCallable, tuple]:
        return self.based_on.__dask_postpersist__()

    def __dask_tokenize__(self) -> Hashable:
        return tokenize(self.based_on)

    __dask_scheduler__ = staticmethod(get1)

    __dask_optimize__ = globalmethod(
        dont_optimize,
        key="hlgcollection_optim",
        falsey=dont_optimize,
    )


class NotHLGCollection(DaskMethodsMixin):
    """Wrapper identical in shape to ``HLGCollection`` but without ``__dask_layers__``.

    It should satisfy ``DaskCollection`` but not ``HLGDaskCollection``.
    """

    def __init__(self, based_on: DaskCollection) -> None:
        self.based_on = based_on

    def __dask_graph__(self) -> Mapping:
        return self.based_on.__dask_graph__()

    def __dask_keys__(self) -> list[Hashable]:
        return self.based_on.__dask_keys__()

    def __dask_postcompute__(self) -> tuple[PostComputeCallable, tuple]:
        return finalize, ()

    def __dask_postpersist__(self) -> tuple[PostPersistCallable, tuple]:
        return self.based_on.__dask_postpersist__()

    def __dask_tokenize__(self) -> Hashable:
        return tokenize(self.based_on)

    __dask_scheduler__ = staticmethod(get2)

    __dask_optimize__ = globalmethod(
        dont_optimize,
        key="collection_optim",
        falsey=dont_optimize,
    )


def increment_(x: int) -> int:
    """Return ``x + 1``; wrapped with ``delayed`` below."""
    return x + 1


increment: Delayed = delayed(increment_)


def assert_isinstance(coll: DaskCollection, protocol: Any) -> None:
    """Assert that ``coll`` is an instance of the (runtime-checkable) ``protocol``."""
    assert isinstance(coll, protocol)


@pytest.mark.parametrize("protocol", [DaskCollection, HLGDaskCollection])
def test_isinstance_core(protocol):
    arr = da.ones(10)
    bag = db.from_sequence([1, 2, 3, 4, 5], npartitions=2)
    df = dds.timeseries()
    dobj = increment(2)

    assert_isinstance(arr, protocol)
    assert_isinstance(bag, protocol)
    assert_isinstance(df, protocol)
    assert_isinstance(dobj, protocol)


def test_isinstance_custom() -> None:
    a = da.ones(10)
    hlgc = HLGCollection(a)
    nhlgc = NotHLGCollection(a)

    assert isinstance(hlgc, DaskCollection)
    assert isinstance(nhlgc, DaskCollection)

    # ``HLGCollection`` defines ``__dask_layers__``; ``NotHLGCollection`` does not.
    # (Previously the first assertion here duplicated the
    # ``isinstance(nhlgc, DaskCollection)`` check above, so the positive HLG case
    # was never exercised.)
    assert isinstance(hlgc, HLGDaskCollection)
    assert not isinstance(nhlgc, HLGDaskCollection)


def compute(coll: DaskCollection) -> Any:
    """Call ``coll.compute()`` through a ``DaskCollection``-annotated parameter."""
    return coll.compute()


def compute2(coll: DaskCollection) -> Any:
    """Same as ``compute``; kept as a second protocol-typed call site."""
    return coll.compute()


def test_parameter_passing() -> None:
    from dask.array import Array

    a: Delayed = increment(2)
    hlgc = HLGCollection(a)
    assert compute(hlgc) == 3
    assert compute2(hlgc) == 3

    d: Delayed = increment(3)
    assert compute(d) == 4
    assert compute2(d) == 4

    array: Array = da.ones(10)
    assert compute(array).shape == (10,)
    assert compute2(array).shape == (10,)


def test_inheriting_class() -> None:
    inheriting: Inheriting = Inheriting(increment(2))
    assert isinstance(inheriting, Inheriting)