import contextlib
import numbers
from itertools import chain, product
from numbers import Integral
from operator import getitem

import numpy as np

from dask.array.backends import array_creation_dispatch
from dask.array.core import (
    Array,
    asarray,
    broadcast_shapes,
    broadcast_to,
    normalize_chunks,
    slices_from_chunks,
)
from dask.array.creation import arange
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph
from dask.utils import derived_from, random_state_data


class RandomState:
    """
    Mersenne Twister pseudo-random number generator

    This object contains state to deterministically generate pseudo-random
    numbers from a variety of probability distributions.  It is identical to
    ``np.random.RandomState`` except that all functions also take a ``chunks=``
    keyword argument.

    Parameters
    ----------
    seed: Number
        Object to pass to RandomState to serve as deterministic seed
    RandomState: Callable[seed] -> RandomState
        A callable that, when provided with a ``seed`` keyword provides an
        object that operates identically to ``np.random.RandomState``.  When
        omitted, ``array_creation_dispatch.RandomState`` is used.  This might
        also be a function that returns a ``randomgen.RandomState``,
        ``mkl_random``, or ``cupy.random.RandomState`` object.

    Examples
    --------
    >>> import dask.array as da
    >>> state = da.random.RandomState(1234)  # a seed
    >>> x = state.normal(10, 0.1, size=3, chunks=(2,))
    >>> x.compute()
    array([10.01867852, 10.04812289,  9.89649746])

    See Also
    --------
    np.random.RandomState
    """

    def __init__(self, seed=None, RandomState=None):
        # The NumPy state is only used to derive per-chunk seeds (and to
        # shuffle indices in ``permutation``); the per-chunk sampling itself
        # is done by ``self._RandomState`` instances created inside tasks.
        self._numpy_state = np.random.RandomState(seed)
        if RandomState is None:
            self._RandomState = array_creation_dispatch.RandomState
        else:
            self._RandomState = RandomState

    def seed(self, seed=None):
        """Re-seed the internal NumPy state used to derive per-chunk seeds."""
        self._numpy_state.seed(seed)

    def _wrap(
        self, funcname, *args, size=None, chunks="auto", extra_chunks=(), **kwargs
    ):
        """Wrap numpy random function to produce dask.array random function

        Parameters
        ----------
        funcname : str
            Name of the method to look up on the per-chunk random state.
        *args, **kwargs
            Distribution parameters.  ``Array`` and ``np.ndarray`` values are
            broadcast to the output shape and sliced per chunk; anything else
            is passed through to every chunk unchanged.
        size : int, tuple or list, optional
            Requested output shape; it is broadcast against the shapes of all
            array-valued parameters.
        chunks : optional
            Chunk specification handed to ``normalize_chunks``.
        extra_chunks : tuple
            extra_chunks should be a chunks tuple to append to the end of
            chunks

        Returns
        -------
        Array
            A dask array with one ``_apply_random`` task per output chunk.
        """
        if size is not None and not isinstance(size, (tuple, list)):
            size = (size,)

        shapes = list(
            {
                ar.shape
                for ar in chain(args, kwargs.values())
                if isinstance(ar, (Array, np.ndarray))
            }
        )
        if size is not None:
            shapes.append(size)
        # broadcast to the final size(shape)
        size = broadcast_shapes(*shapes)
        chunks = normalize_chunks(
            chunks,
            size,  # ideally would use dtype here
            dtype=kwargs.get("dtype", np.float64),
        )
        slices = slices_from_chunks(chunks)

        def _broadcast_any(ar, shape, chunks):
            if isinstance(ar, Array):
                return broadcast_to(ar, shape).rechunk(chunks)
            if isinstance(ar, np.ndarray):
                return np.ascontiguousarray(np.broadcast_to(ar, shape))

        # Broadcast all arguments, get tiny versions as well
        # Start adding the relevant bits to the graph
        dsk = {}
        lookup = {}
        small_args = []
        dependencies = []
        for i, ar in enumerate(args):
            if isinstance(ar, (np.ndarray, Array)):
                res = _broadcast_any(ar, size, chunks)
                if isinstance(res, Array):
                    dependencies.append(res)
                    lookup[i] = res.name
                elif isinstance(res, np.ndarray):
                    name = f"array-{tokenize(res)}"
                    lookup[i] = name
                    dsk[name] = res
                # A single element stands in for the array when computing meta
                small_args.append(ar[tuple(0 for _ in ar.shape)])
            else:
                small_args.append(ar)

        small_kwargs = {}
        for key, ar in kwargs.items():
            if isinstance(ar, (np.ndarray, Array)):
                res = _broadcast_any(ar, size, chunks)
                if isinstance(res, Array):
                    dependencies.append(res)
                    lookup[key] = res.name
                elif isinstance(res, np.ndarray):
                    name = f"array-{tokenize(res)}"
                    lookup[key] = name
                    dsk[name] = res
                small_kwargs[key] = ar[tuple(0 for _ in ar.shape)]
            else:
                small_kwargs[key] = ar

        sizes = list(product(*chunks))
        seeds = random_state_data(len(sizes), self._numpy_state)
        token = tokenize(seeds, size, chunks, args, kwargs)
        name = f"{funcname}-{token}"

        keys = product(
            [name], *([range(len(bd)) for bd in chunks] + [[0]] * len(extra_chunks))
        )
        blocks = product(*[range(len(bd)) for bd in chunks])

        vals = []
        # Distinct loop names so the broadcast ``size`` is not clobbered.
        for chunk_seed, chunk_size, slc, block in zip(seeds, sizes, slices, blocks):
            arg = []
            for i, ar in enumerate(args):
                if i not in lookup:
                    arg.append(ar)
                elif isinstance(ar, Array):
                    arg.append((lookup[i],) + block)
                else:  # np.ndarray
                    arg.append((getitem, lookup[i], slc))
            kwrg = {}
            for k, ar in kwargs.items():
                if k not in lookup:
                    kwrg[k] = ar
                elif isinstance(ar, Array):
                    kwrg[k] = (lookup[k],) + block
                else:  # np.ndarray
                    kwrg[k] = (getitem, lookup[k], slc)
            vals.append(
                (
                    _apply_random,
                    self._RandomState,
                    funcname,
                    chunk_seed,
                    chunk_size,
                    arg,
                    kwrg,
                )
            )

        # Zero-length sample with the output's dimensionality, drawn with the
        # last chunk seed (the values the leaked loop variables used to hold).
        meta = _apply_random(
            self._RandomState,
            funcname,
            seeds[-1],
            (0,) * len(size),
            small_args,
            small_kwargs,
        )

        dsk.update(dict(zip(keys, vals)))

        graph = HighLevelGraph.from_collections(name, dsk, dependencies=dependencies)
        return Array(graph, name, chunks + extra_chunks, meta=meta)

    @derived_from(np.random.RandomState, skipblocks=1)
    def beta(self, a, b, size=None, chunks="auto", **kwargs):
        return self._wrap("beta", a, b, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def binomial(self, n, p, size=None, chunks="auto", **kwargs):
        return self._wrap("binomial", n, p, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def chisquare(self, df, size=None, chunks="auto", **kwargs):
        return self._wrap("chisquare", df, size=size, chunks=chunks, **kwargs)

    # ``choice`` is only defined when the decorator does not raise
    # AttributeError while building it.
    with contextlib.suppress(AttributeError):

        @derived_from(np.random.RandomState, skipblocks=1)
        def choice(self, a, size=None, replace=True, p=None, chunks="auto"):
            dependencies = []
            # Normalize and validate `a`
            if isinstance(a, Integral):
                # On windows the output dtype differs if p is provided or
                # absent, see https://github.com/numpy/numpy/issues/9867
                dummy_p = np.array([1]) if p is not None else p
                dtype = np.random.choice(1, size=(), p=dummy_p).dtype
                len_a = a
                if a < 0:
                    raise ValueError("a must be greater than 0")
            else:
                a = asarray(a)
                # Single chunk so that the first key refers to the whole array
                a = a.rechunk(a.shape)
                dtype = a.dtype
                if a.ndim != 1:
                    raise ValueError("a must be one dimensional")
                len_a = len(a)
                dependencies.append(a)
                a = a.__dask_keys__()[0]

            # Normalize and validate `p`
            if p is not None:
                if not isinstance(p, Array):
                    # If p is not a dask array, first check the sum is close
                    # to 1 before converting.
                    p = np.asarray(p)
                    if not np.isclose(p.sum(), 1, rtol=1e-7, atol=0):
                        raise ValueError("probabilities do not sum to 1")
                    p = asarray(p)
                else:
                    p = p.rechunk(p.shape)

                if p.ndim != 1:
                    raise ValueError("p must be one dimensional")
                if len(p) != len_a:
                    raise ValueError("a and p must have the same size")

                dependencies.append(p)
                p = p.__dask_keys__()[0]

            if size is None:
                size = ()
            elif not isinstance(size, (tuple, list)):
                size = (size,)

            chunks = normalize_chunks(chunks, size, dtype=np.float64)
            # ``chunks`` is empty for a scalar draw (size=()); guard the index
            # so that replace=False does not fail with IndexError there.
            if not replace and chunks and len(chunks[0]) > 1:
                err_msg = (
                    "replace=False is not currently supported for "
                    "dask.array.choice with multi-chunk output "
                    "arrays"
                )
                raise NotImplementedError(err_msg)
            sizes = list(product(*chunks))
            state_data = random_state_data(len(sizes), self._numpy_state)

            name = "da.random.choice-%s" % tokenize(
                state_data, size, chunks, a, replace, p
            )
            keys = product([name], *(range(len(bd)) for bd in chunks))
            dsk = {
                key: (_choice, state, a, chunk_size, replace, p)
                for key, state, chunk_size in zip(keys, state_data, sizes)
            }

            graph = HighLevelGraph.from_collections(
                name, dsk, dependencies=dependencies
            )
            return Array(graph, name, chunks, dtype=dtype)

    # @derived_from(np.random.RandomState, skipblocks=1)
    # def dirichlet(self, alpha, size=None, chunks="auto"):

    @derived_from(np.random.RandomState, skipblocks=1)
    def exponential(self, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("exponential", scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def f(self, dfnum, dfden, size=None, chunks="auto", **kwargs):
        return self._wrap("f", dfnum, dfden, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def gamma(self, shape, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("gamma", shape, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def geometric(self, p, size=None, chunks="auto", **kwargs):
        return self._wrap("geometric", p, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def gumbel(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("gumbel", loc, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def hypergeometric(self, ngood, nbad, nsample, size=None, chunks="auto", **kwargs):
        return self._wrap(
            "hypergeometric", ngood, nbad, nsample, size=size, chunks=chunks, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def laplace(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("laplace", loc, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def logistic(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("logistic", loc, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def lognormal(self, mean=0.0, sigma=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("lognormal", mean, sigma, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def logseries(self, p, size=None, chunks="auto", **kwargs):
        return self._wrap("logseries", p, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def multinomial(self, n, pvals, size=None, chunks="auto", **kwargs):
        # The trailing axis holds one count per category, hence the extra
        # single chunk of length ``len(pvals)``.
        return self._wrap(
            "multinomial",
            n,
            pvals,
            size=size,
            chunks=chunks,
            extra_chunks=((len(pvals),),),
            **kwargs,
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def negative_binomial(self, n, p, size=None, chunks="auto", **kwargs):
        return self._wrap("negative_binomial", n, p, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def noncentral_chisquare(self, df, nonc, size=None, chunks="auto", **kwargs):
        return self._wrap(
            "noncentral_chisquare", df, nonc, size=size, chunks=chunks, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def noncentral_f(self, dfnum, dfden, nonc, size=None, chunks="auto", **kwargs):
        return self._wrap(
            "noncentral_f", dfnum, dfden, nonc, size=size, chunks=chunks, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def normal(self, loc=0.0, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("normal", loc, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def pareto(self, a, size=None, chunks="auto", **kwargs):
        return self._wrap("pareto", a, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def permutation(self, x):
        from dask.array.slicing import shuffle_slice

        if isinstance(x, numbers.Number):
            x = arange(x, chunks="auto")

        # The shuffled index is built eagerly with the NumPy state.
        index = np.arange(len(x))
        self._numpy_state.shuffle(index)
        return shuffle_slice(x, index)

    @derived_from(np.random.RandomState, skipblocks=1)
    def poisson(self, lam=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("poisson", lam, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def power(self, a, size=None, chunks="auto", **kwargs):
        return self._wrap("power", a, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def randint(self, low, high=None, size=None, chunks="auto", dtype="l", **kwargs):
        return self._wrap(
            "randint", low, high, size=size, chunks=chunks, dtype=dtype, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def random_integers(self, low, high=None, size=None, chunks="auto", **kwargs):
        return self._wrap(
            "random_integers", low, high, size=size, chunks=chunks, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def random_sample(self, size=None, chunks="auto", **kwargs):
        return self._wrap("random_sample", size=size, chunks=chunks, **kwargs)

    random = random_sample

    @derived_from(np.random.RandomState, skipblocks=1)
    def rayleigh(self, scale=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("rayleigh", scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_cauchy(self, size=None, chunks="auto", **kwargs):
        return self._wrap("standard_cauchy", size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_exponential(self, size=None, chunks="auto", **kwargs):
        return self._wrap("standard_exponential", size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_gamma(self, shape, size=None, chunks="auto", **kwargs):
        return self._wrap("standard_gamma", shape, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_normal(self, size=None, chunks="auto", **kwargs):
        return self._wrap("standard_normal", size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def standard_t(self, df, size=None, chunks="auto", **kwargs):
        return self._wrap("standard_t", df, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def tomaxint(self, size=None, chunks="auto", **kwargs):
        return self._wrap("tomaxint", size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def triangular(self, left, mode, right, size=None, chunks="auto", **kwargs):
        return self._wrap(
            "triangular", left, mode, right, size=size, chunks=chunks, **kwargs
        )

    @derived_from(np.random.RandomState, skipblocks=1)
    def uniform(self, low=0.0, high=1.0, size=None, chunks="auto", **kwargs):
        return self._wrap("uniform", low, high, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def vonmises(self, mu, kappa, size=None, chunks="auto", **kwargs):
        return self._wrap("vonmises", mu, kappa, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def wald(self, mean, scale, size=None, chunks="auto", **kwargs):
        return self._wrap("wald", mean, scale, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def weibull(self, a, size=None, chunks="auto", **kwargs):
        return self._wrap("weibull", a, size=size, chunks=chunks, **kwargs)

    @derived_from(np.random.RandomState, skipblocks=1)
    def zipf(self, a, size=None, chunks="auto", **kwargs):
        return self._wrap("zipf", a, size=size, chunks=chunks, **kwargs)


def _choice(state_data, a, size, replace, p):
    """Draw one chunk of ``choice`` samples from a NumPy state seeded with *state_data*."""
    state = np.random.RandomState(state_data)
    return state.choice(a, size=size, replace=replace, p=p)


def _apply_random(RandomState, funcname, state_data, size, args, kwargs):
    """Apply RandomState method with seed

    Builds ``RandomState(state_data)`` (falling back to
    ``np.random.RandomState`` when *RandomState* is None), looks up
    *funcname* on it and calls it with ``*args, size=size, **kwargs``.
    """
    if RandomState is None:
        RandomState = np.random.RandomState
    state = RandomState(state_data)
    func = getattr(state, funcname)
    return func(*args, size=size, **kwargs)


# One lazily created default RandomState per array-creation backend
_cached_random_states = {}


def _make_api(attr):
    """Return a module-level function forwarding to ``RandomState.<attr>``.

    The returned wrapper resolves the current ``array_creation_dispatch``
    backend on every call and reuses one cached ``RandomState`` per backend.
    """

    def wrapper(*args, **kwargs):
        backend = array_creation_dispatch.backend
        if backend not in _cached_random_states:
            # Cache the default RandomState object for this backend
            _cached_random_states[backend] = RandomState()
        return getattr(
            _cached_random_states[backend],
            attr,
        )(*args, **kwargs)

    wrapper.__name__ = getattr(RandomState, attr).__name__
    wrapper.__doc__ = getattr(RandomState, attr).__doc__
    return wrapper


seed = _make_api("seed")

beta = _make_api("beta")
binomial = _make_api("binomial")
chisquare = _make_api("chisquare")
if hasattr(RandomState(), "choice"):
    choice = _make_api("choice")
exponential = _make_api("exponential")
f = _make_api("f")
gamma = _make_api("gamma")
geometric = _make_api("geometric")
gumbel = _make_api("gumbel")
hypergeometric = _make_api("hypergeometric")
laplace = _make_api("laplace")
logistic = _make_api("logistic")
lognormal = _make_api("lognormal")
logseries = _make_api("logseries")
multinomial = _make_api("multinomial")
negative_binomial = _make_api("negative_binomial")
noncentral_chisquare = _make_api("noncentral_chisquare")
noncentral_f = _make_api("noncentral_f")
normal = _make_api("normal")
pareto = _make_api("pareto")
permutation = _make_api("permutation")
poisson = _make_api("poisson")
power = _make_api("power")
rayleigh = _make_api("rayleigh")
random_sample = _make_api("random_sample")
random = random_sample
randint = _make_api("randint")
random_integers = _make_api("random_integers")
triangular = _make_api("triangular")
uniform = _make_api("uniform")
vonmises = _make_api("vonmises")
wald = _make_api("wald")
weibull = _make_api("weibull")
zipf = _make_api("zipf")

# Standard distributions

standard_cauchy = _make_api("standard_cauchy")
standard_exponential = _make_api("standard_exponential")
standard_gamma = _make_api("standard_gamma")
standard_normal = _make_api("standard_normal")
standard_t = _make_api("standard_t")