from __future__ import annotations

import logging

import msgpack

import dask.config

from distributed.protocol import pickle
from distributed.protocol.compression import decompress, maybe_compress
from distributed.protocol.serialize import (
    Pickled,
    Serialize,
    Serialized,
    ToPickle,
    merge_and_deserialize,
    msgpack_decode_default,
    msgpack_encode_default,
    serialize_and_split,
)
from distributed.protocol.utils import msgpack_opts
from distributed.utils import ensure_memoryview

logger = logging.getLogger(__name__)


def dumps(  # type: ignore[no-untyped-def]
    msg, serializers=None, on_error="message", context=None, frame_split_size=None
) -> list:
    """Transform Python message to bytestream suitable for communication

    Developer Notes
    ---------------
    The approach here is to use `msgpack.dumps()` to serialize `msg` and
    write the result to the first output frame. If `msgpack.dumps()`
    encounters an object it cannot serialize, like a NumPy array, it is
    handled out-of-band by `_encode_default()` and appended to the output
    frame list.
    """
    try:
        if context and "compression" in context:
            compress_opts = {"compression": context["compression"]}
        else:
            compress_opts = {}

        def _inplace_compress_frames(header, frames):
            """Compress `frames` in place and record the codecs in `header`"""
            compression = list(header.get("compression", [None] * len(frames)))

            for i in range(len(frames)):
                if compression[i] is None:
                    compression[i], frames[i] = maybe_compress(
                        frames[i], **compress_opts
                    )

            header["compression"] = tuple(compression)

        def create_serialized_sub_frames(obj: Serialized | Serialize) -> list:
            """Create a list of frames representing a serialized object"""
            if isinstance(obj, Serialized):
                sub_header, sub_frames = obj.header, obj.frames
            else:
                sub_header, sub_frames = serialize_and_split(
                    obj,
                    serializers=serializers,
                    on_error=on_error,
                    context=context,
                    size=frame_split_size,
                )
            _inplace_compress_frames(sub_header, sub_frames)
            sub_header["num-sub-frames"] = len(sub_frames)
            sub_header = msgpack.dumps(
                sub_header, default=msgpack_encode_default, use_bin_type=True
            )
            return [sub_header] + sub_frames

        def create_pickled_sub_frames(obj: Pickled | ToPickle) -> list:
            """Create a list of frames representing a pickled object"""
            if isinstance(obj, Pickled):
                sub_header, sub_frames = obj.header, obj.frames
            else:
                sub_frames = []
                sub_header = {
                    "pickled-obj": pickle.dumps(
                        obj.data,
                        # In order to support len() and slicing, we convert
                        # `PickleBuffer` objects to memoryviews of bytes.
                        buffer_callback=lambda x: sub_frames.append(
                            ensure_memoryview(x)
                        ),
                    )
                }
            _inplace_compress_frames(sub_header, sub_frames)
            sub_header["num-sub-frames"] = len(sub_frames)
            sub_header = msgpack.dumps(sub_header)
            return [sub_header] + sub_frames

        frames = [None]

        def _encode_default(obj):
            # Objects msgpack cannot handle are serialized out-of-band and
            # replaced by a placeholder dict pointing at their sub-frames.
            if isinstance(obj, (Serialize, Serialized)):
                offset = len(frames)
                frames.extend(create_serialized_sub_frames(obj))
                return {"__Serialized__": offset}
            elif isinstance(obj, (ToPickle, Pickled)):
                offset = len(frames)
                frames.extend(create_pickled_sub_frames(obj))
                return {"__Pickled__": offset}
            else:
                return msgpack_encode_default(obj)

        frames[0] = msgpack.dumps(msg, default=_encode_default, use_bin_type=True)
        return frames

    except Exception:
        logger.critical("Failed to Serialize", exc_info=True)
        raise


def loads(frames, deserialize=True, deserializers=None):
    """Transform bytestream back into Python value"""
    allow_pickle = dask.config.get("distributed.scheduler.pickle")

    try:

        def _decode_default(obj):
            # Reverse `_encode_default()`: placeholder dicts point at the
            # out-of-band sub-frames that follow the msgpack frame.
            offset = obj.get("__Serialized__", 0)
            if offset > 0:
                sub_header = msgpack.loads(
                    frames[offset],
                    object_hook=msgpack_decode_default,
                    use_list=False,
                    **msgpack_opts,
                )
                offset += 1
                sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
                if deserialize:
                    if "compression" in sub_header:
                        sub_frames = decompress(sub_header, sub_frames)
                    return merge_and_deserialize(
                        sub_header, sub_frames, deserializers=deserializers
                    )
                else:
                    return Serialized(sub_header, sub_frames)

            offset = obj.get("__Pickled__", 0)
            if offset > 0:
                sub_header = msgpack.loads(frames[offset])
                offset += 1
                sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
                if allow_pickle:
                    return pickle.loads(
                        sub_header["pickled-obj"], buffers=sub_frames
                    )
                else:
                    raise ValueError(
                        "Unpickle on the Scheduler isn't allowed, set `distributed.scheduler.pickle=true`"
                    )

            return msgpack_decode_default(obj)

        return msgpack.loads(
            frames[0], object_hook=_decode_default, use_list=False, **msgpack_opts
        )

    except Exception:
        logger.critical("Failed to deserialize", exc_info=True)
        raise
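

# A minimal round-trip sketch, not part of the original module: it assumes a
# default dask/distributed configuration and merely shows how a message whose
# payload is wrapped in `Serialize` travels through `dumps` (frame 0 holds the
# msgpack-encoded message, later frames hold the out-of-band payload) and is
# reassembled by `loads`.  The names `example_msg` and `example_frames` are
# illustrative only.
if __name__ == "__main__":
    example_msg = {"op": "update", "data": Serialize([1, 2, 3])}
    example_frames = dumps(example_msg)
    assert loads(example_frames)["data"] == [1, 2, 3]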