"""Encapsulated manager for in-memory tasks on a worker. This module covers: - spill/unspill data depending on the 'distributed.worker.memory.target' threshold - spill/unspill data depending on the 'distributed.worker.memory.spill' threshold - pause/unpause the worker depending on the 'distributed.worker.memory.pause' threshold - kill the worker depending on the 'distributed.worker.memory.terminate' threshold This module does *not* cover: - Changes in behaviour in Worker, Scheduler, task stealing, Active Memory Manager, etc. caused by the Worker being in paused status - Worker restart after it's been killed - Scheduler-side heuristics regarding memory usage, e.g. the Active Memory Manager See also: - :mod:`distributed.spill`, which implements the spill-to-disk mechanism and is wrapped by this module. Unlike this module, :mod:`distributed.spill` is agnostic to the Worker. - :mod:`distributed.active_memory_manager`, which runs on the scheduler side """ from __future__ import annotations import asyncio import logging import os import sys import warnings from collections.abc import Callable, MutableMapping from contextlib import suppress from functools import partial from typing import TYPE_CHECKING, Any, Container, Literal, cast import psutil import dask.config from dask.system import CPU_COUNT from dask.utils import format_bytes, parse_bytes, parse_timedelta from distributed import system from distributed.compatibility import WINDOWS, PeriodicCallback from distributed.core import Status from distributed.metrics import monotonic from distributed.spill import ManualEvictProto, SpillBuffer from distributed.utils import has_arg, log_errors from distributed.utils_perf import ThrottledGC if TYPE_CHECKING: # Circular imports from distributed.nanny import Nanny from distributed.worker import Worker class WorkerMemoryManager: """Management of worker memory usage Parameters ---------- worker Worker to manage For meaning of the remaining parameters, see the matching parameter names in :class:`~.distributed.worker.Worker`. Notes ----- If data is a callable and has the argument ``worker_local_directory`` in its signature, it will be filled with the worker's attr:``local_directory``. 
""" data: MutableMapping[str, object] # {task key: task payload} memory_limit: int | None memory_target_fraction: float | Literal[False] memory_spill_fraction: float | Literal[False] memory_pause_fraction: float | Literal[False] max_spill: int | Literal[False] memory_monitor_interval: float _throttled_gc: ThrottledGC def __init__( self, worker: Worker, *, nthreads: int, memory_limit: str | float = "auto", # This should be None most of the times, short of a power user replacing the # SpillBuffer with their own custom dict-like data: ( MutableMapping[str, Any] # pre-initialised | Callable[[], MutableMapping[str, Any]] # constructor # constructor, passed worker.local_directory | Callable[[str], MutableMapping[str, Any]] | tuple[ Callable[..., MutableMapping[str, Any]], dict[str, Any] ] # (constructor, kwargs to constructor) | None # create internally ) = None, # Deprecated parameters; use dask.config instead memory_target_fraction: float | Literal[False] | None = None, memory_spill_fraction: float | Literal[False] | None = None, memory_pause_fraction: float | Literal[False] | None = None, ): self.logger = logging.getLogger("distributed.worker.memory") self.memory_limit = parse_memory_limit( memory_limit, nthreads, logger=self.logger ) self.memory_target_fraction = _parse_threshold( "distributed.worker.memory.target", "memory_target_fraction", memory_target_fraction, ) self.memory_spill_fraction = _parse_threshold( "distributed.worker.memory.spill", "memory_spill_fraction", memory_spill_fraction, ) self.memory_pause_fraction = _parse_threshold( "distributed.worker.memory.pause", "memory_pause_fraction", memory_pause_fraction, ) max_spill = dask.config.get("distributed.worker.memory.max-spill") self.max_spill = False if max_spill is False else parse_bytes(max_spill) if isinstance(data, MutableMapping): self.data = data elif callable(data): if has_arg(data, "worker_local_directory"): data = cast("Callable[[str], MutableMapping[str, Any]]", data) self.data = data(worker.local_directory) else: data = cast("Callable[[], MutableMapping[str, Any]]", data) self.data = data() elif isinstance(data, tuple): func, kwargs = data if not callable(func): raise ValueError("Expecting a callable") if has_arg(func, "worker_local_directory"): self.data = func( worker_local_directory=worker.local_directory, **kwargs ) else: self.data = func(**kwargs) elif self.memory_limit and ( self.memory_target_fraction or self.memory_spill_fraction ): if self.memory_target_fraction: target = int( self.memory_limit * (self.memory_target_fraction or self.memory_spill_fraction) ) else: target = sys.maxsize self.data = SpillBuffer( os.path.join(worker.local_directory, "storage"), target=target, max_spill=self.max_spill, ) else: self.data = {} self.memory_monitor_interval = parse_timedelta( dask.config.get("distributed.worker.memory.monitor-interval"), default=False, ) assert isinstance(self.memory_monitor_interval, (int, float)) if self.memory_limit and ( self.memory_spill_fraction is not False or self.memory_pause_fraction is not False ): assert self.memory_monitor_interval is not None pc = PeriodicCallback( # Don't store worker as self.worker to avoid creating a circular # dependency. We could have alternatively used a weakref. partial(self.memory_monitor, worker), self.memory_monitor_interval * 1000, ) worker.periodic_callbacks["memory_monitor"] = pc self._throttled_gc = ThrottledGC(logger=self.logger) @log_errors async def memory_monitor(self, worker: Worker) -> None: """Track this process's memory usage and act accordingly. 

    @log_errors
    async def memory_monitor(self, worker: Worker) -> None:
        """Track this process's memory usage and act accordingly.

        If process memory rises above the spill threshold (70%), start dumping data
        to disk until it goes below the target threshold (60%).
        If process memory rises above the pause threshold (80%), stop execution of
        new tasks.
        """
        # Don't use psutil directly; instead read from the same API that is used to
        # send info to the Scheduler (e.g. for the benefit of the Active Memory
        # Manager) and which can be easily mocked in unit tests.
        memory = worker.monitor.get_process_memory()
        self._maybe_pause_or_unpause(worker, memory)
        await self._maybe_spill(worker, memory)

    def _maybe_pause_or_unpause(self, worker: Worker, memory: int) -> None:
        if self.memory_pause_fraction is False:
            return

        assert self.memory_limit
        frac = memory / self.memory_limit
        # Pause worker threads if above 80% memory use
        if frac > self.memory_pause_fraction:
            # Try to free some memory while in paused state
            self._throttled_gc.collect()
            if worker.status == Status.running:
                self.logger.warning(
                    "Worker is at %d%% memory usage. Pausing worker. "
                    "Process memory: %s -- Worker memory limit: %s",
                    int(frac * 100),
                    format_bytes(memory),
                    format_bytes(self.memory_limit)
                    if self.memory_limit is not None
                    else "None",
                )
                worker.status = Status.paused
        elif worker.status == Status.paused:
            self.logger.warning(
                "Worker is at %d%% memory usage. Resuming worker. "
                "Process memory: %s -- Worker memory limit: %s",
                int(frac * 100),
                format_bytes(memory),
                format_bytes(self.memory_limit)
                if self.memory_limit is not None
                else "None",
            )
            worker.status = Status.running

    async def _maybe_spill(self, worker: Worker, memory: int) -> None:
        if self.memory_spill_fraction is False:
            return

        # SpillBuffer, or a duck-typed compatible MutableMapping which offers the
        # fast property and the evict() method. Dask-CUDA uses this.
        if not hasattr(self.data, "fast") or not hasattr(self.data, "evict"):
            return
        data = cast(ManualEvictProto, self.data)
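
        # For reference, a sketch of the duck-typed protocol expected here (see
        # distributed.spill.ManualEvictProto); ``CustomBuffer`` is a hypothetical
        # name:
        #
        #     class CustomBuffer(MutableMapping[str, object]):
        #         @property
        #         def fast(self): ...  # in-memory portion; tested for truthiness
        #         def evict(self) -> int: ...  # bytes evicted, or -1 on failure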

        assert self.memory_limit
        frac = memory / self.memory_limit
        if frac <= self.memory_spill_fraction:
            return

        total_spilled = 0
        self.logger.debug(
            "Worker is at %.0f%% memory usage. Start spilling data to disk.",
            frac * 100,
        )

        # Implement a hysteresis cycle where spilling starts at the spill threshold
        # and stops at the target threshold. Note that here the target threshold
        # defines process memory, whereas normally it defines reported managed
        # memory (e.g. output of sizeof() ). If target=False, disable hysteresis.
        target = self.memory_limit * (
            self.memory_target_fraction or self.memory_spill_fraction
        )
        count = 0
        need = memory - target
        last_checked_for_pause = last_yielded = monotonic()

        while memory > target:
            if not data.fast:
                self.logger.warning(
                    "Unmanaged memory use is high. This may indicate a memory leak "
                    "or the memory may not be released to the OS; see "
                    "https://distributed.dask.org/en/latest/worker-memory.html#memory-not-released-back-to-the-os "
                    "for more information. "
                    "-- Unmanaged memory: %s -- Worker memory limit: %s",
                    format_bytes(memory),
                    format_bytes(self.memory_limit),
                )
                break

            weight = data.evict()
            if weight == -1:
                # Failed to evict:
                # disk full, spill size limit exceeded, or pickle error
                break

            total_spilled += weight
            count += 1

            memory = worker.monitor.get_process_memory()
            if total_spilled > need and memory > target:
                # Issue a GC to ensure that the evicted data is actually freed from
                # memory and taken into account by the monitor before trying to
                # evict even more data.
                self._throttled_gc.collect()
                memory = worker.monitor.get_process_memory()

            now = monotonic()

            # Spilling may potentially take multiple seconds; we may pass the pause
            # threshold in the meantime.
            if now - last_checked_for_pause > self.memory_monitor_interval:
                self._maybe_pause_or_unpause(worker, memory)
                last_checked_for_pause = now

            # Increase spilling aggressiveness when the fast buffer is filled with a
            # lot of small values. This artificially chokes the rest of the event
            # loop - namely, the reception of new data from other workers. While
            # this is somewhat of an ugly hack, DO NOT tweak this without a thorough
            # cycle of stress testing.
            # See: https://github.com/dask/distributed/issues/6110.
            if now - last_yielded > 0.5:
                # See also the related metrics:
                # - disk-load-duration
                # - get-data-load-duration
                # - disk-write-target-duration
                # - disk-write-spill-duration
                worker.digest_metric("disk-write-spill-duration", now - last_yielded)
                await asyncio.sleep(0)
                last_yielded = monotonic()

        now = monotonic()
        if now - last_yielded > 0.005:
            worker.digest_metric("disk-write-spill-duration", now - last_yielded)

        if count:
            self.logger.debug(
                "Moved %d tasks worth %s to disk",
                count,
                format_bytes(total_spilled),
            )

    def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
        info = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
        del info["logger"]
        info["data"] = dict.fromkeys(self.data)
        return info


class NannyMemoryManager:
    memory_limit: int | None
    memory_terminate_fraction: float | Literal[False]
    memory_monitor_interval: float | None
    _last_terminated_pid: int

    def __init__(
        self,
        nanny: Nanny,
        *,
        memory_limit: str | float = "auto",
    ):
        self.logger = logging.getLogger("distributed.nanny.memory")
        self.memory_limit = parse_memory_limit(
            memory_limit, nanny.nthreads, logger=self.logger
        )
        self.memory_terminate_fraction = dask.config.get(
            "distributed.worker.memory.terminate"
        )
        self.memory_monitor_interval = parse_timedelta(
            dask.config.get("distributed.worker.memory.monitor-interval"),
            default=False,
        )
        assert isinstance(self.memory_monitor_interval, (int, float))
        self._last_terminated_pid = -1

        if self.memory_limit and self.memory_terminate_fraction is not False:
            pc = PeriodicCallback(
                partial(self.memory_monitor, nanny),
                self.memory_monitor_interval * 1000,
            )
            nanny.periodic_callbacks["memory_monitor"] = pc
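
    # Worked example (comment only; 0.95 is the default value of the
    # distributed.worker.memory.terminate config key): with memory_limit = 16 GiB,
    # the nanny sends SIGTERM once the worker process exceeds
    # 0.95 * 16 GiB = 15.2 GiB, then escalates to SIGKILL if the process is still
    # alive on the next monitor iteration.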

    def memory_monitor(self, nanny: Nanny) -> None:
        """Track worker's memory. Restart if it goes above terminate fraction."""
        if (
            nanny.status != Status.running
            or nanny.process is None
            or nanny.process.process is None
            or nanny.process.process.pid is None
        ):
            return  # pragma: nocover
        process = nanny.process.process

        try:
            memory = psutil.Process(process.pid).memory_info().rss
        except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied):
            return  # pragma: nocover

        if memory / self.memory_limit <= self.memory_terminate_fraction:
            return

        if self._last_terminated_pid != process.pid:
            self.logger.warning(
                f"Worker {nanny.worker_address} (pid={process.pid}) exceeded "
                f"{self.memory_terminate_fraction * 100:.0f}% memory budget. "
                "Restarting...",
            )
            self._last_terminated_pid = process.pid
            process.terminate()
        else:
            # We already sent SIGTERM to the worker, but the process is still alive
            # since the previous iteration of the memory_monitor - for example, some
            # user code may have tampered with signal handlers.
            # Send SIGKILL for immediate termination.
            #
            # Note that this should not be a disk-related issue. Unlike in a regular
            # worker shutdown, where the worker cleans up its own spill directory,
            # in case of SIGTERM no atexit or weakref.finalize callback is triggered
            # whatsoever; instead, the nanny cleans up the spill directory *after*
            # the worker has been shut down and before starting a new one.
            # This is important, as spill directory cleanup may potentially take
            # tens of seconds and, if the worker did it, any task that was running
            # and leaking would continue to do so for the whole duration of the
            # cleanup, increasing the risk of going beyond 100%.
            self.logger.warning(
                f"Worker {nanny.worker_address} (pid={process.pid}) is slow to %s",
                # On Windows, kill() is an alias to terminate()
                "terminate; trying again"
                if WINDOWS
                else "accept SIGTERM; sending SIGKILL",
            )
            process.kill()


def parse_memory_limit(
    memory_limit: str | float | None,
    nthreads: int,
    total_cores: int = CPU_COUNT,
    *,
    logger: logging.Logger,
) -> int | None:
    if memory_limit is None:
        return None

    orig = memory_limit
    if memory_limit == "auto":
        memory_limit = int(system.MEMORY_LIMIT * min(1, nthreads / total_cores))
    with suppress(ValueError, TypeError):
        memory_limit = float(memory_limit)

    if isinstance(memory_limit, float) and memory_limit <= 1:
        memory_limit = int(memory_limit * system.MEMORY_LIMIT)

    if isinstance(memory_limit, str):
        memory_limit = parse_bytes(memory_limit)
    else:
        memory_limit = int(memory_limit)

    assert isinstance(memory_limit, int)
    if memory_limit == 0:
        return None

    if system.MEMORY_LIMIT < memory_limit:
        logger.warning(
            "Ignoring provided memory limit %s due to system memory limit of %s",
            orig,
            format_bytes(system.MEMORY_LIMIT),
        )
        return system.MEMORY_LIMIT
    else:
        return memory_limit


def _parse_threshold(
    config_key: str,
    deprecated_param_name: str,
    deprecated_param_value: float | Literal[False] | None,
) -> float | Literal[False]:
    if deprecated_param_value is not None:
        warnings.warn(
            f"Parameter {deprecated_param_name} has been deprecated and will be "
            f"removed in a future version; please use dask config key {config_key} "
            "instead",
            FutureWarning,
        )
        return deprecated_param_value
    return dask.config.get(config_key)


def _warn_deprecated(w: Nanny | Worker, name: str) -> None:
    warnings.warn(
        f"The `{type(w).__name__}.{name}` attribute has been moved to "
        f"`{type(w).__name__}.memory_manager.{name}`",
        FutureWarning,
    )


class DeprecatedMemoryManagerAttribute:
    name: str

    def __set_name__(self, owner: type, name: str) -> None:
        self.name = name

    def __get__(self, instance: Nanny | Worker | None, owner: type) -> Any:
        if instance is None:
            # This is triggered by Sphinx
            return None  # pragma: nocover
        _warn_deprecated(instance, self.name)
        return getattr(instance.memory_manager, self.name)

    def __set__(self, instance: Nanny | Worker, value: Any) -> None:
        _warn_deprecated(instance, self.name)
        setattr(instance.memory_manager, self.name, value)


class DeprecatedMemoryMonitor:
    def __get__(self, instance: Nanny | Worker | None, owner: type) -> Any:
        if instance is None:
            # This is triggered by Sphinx
            return None  # pragma: nocover
        _warn_deprecated(instance, "memory_monitor")
        return partial(instance.memory_manager.memory_monitor, instance)
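
# Usage sketch (comment only; see distributed.worker and distributed.nanny for the
# actual attachment points): Worker and Nanny attach these descriptors as class
# attributes so that legacy attribute access keeps working with a FutureWarning,
# e.g.
#
#     class Worker(...):
#         memory_limit = DeprecatedMemoryManagerAttribute()
#         memory_monitor = DeprecatedMemoryMonitor()
#
# so that ``worker.memory_limit`` warns and forwards to
# ``worker.memory_manager.memory_limit``.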