import numpy as np from ... import defines from .base import NumpyColumn class NumpyStringColumn(NumpyColumn): null_value = '' default_encoding = defines.STRINGS_ENCODING def __init__(self, encoding=default_encoding, **kwargs): self.encoding = encoding super(NumpyStringColumn, self).__init__(**kwargs) def read_items(self, n_items, buf): return np.array( buf.read_strings(n_items, encoding=self.encoding), dtype=self.dtype ) def write_items(self, items, buf): return buf.write_strings(items.tolist(), encoding=self.encoding) class NumpyByteStringColumn(NumpyColumn): null_value = b'' def read_items(self, n_items, buf): return np.array(buf.read_strings(n_items), dtype=self.dtype) def write_items(self, items, buf): return buf.write_strings(items.tolist()) class NumpyFixedString(NumpyStringColumn): def __init__(self, length, **kwargs): self.length = length super(NumpyFixedString, self).__init__(**kwargs) def read_items(self, n_items, buf): return np.array(buf.read_fixed_strings( n_items, self.length, encoding=self.encoding ), dtype=self.dtype) def write_items(self, items, buf): return buf.write_fixed_strings( items.tolist(), self.length, encoding=self.encoding ) class NumpyByteFixedString(NumpyByteStringColumn): def __init__(self, length, **kwargs): self.length = length super(NumpyByteFixedString, self).__init__(**kwargs) def read_items(self, n_items, buf): return np.array( buf.read_fixed_strings(n_items, self.length), dtype=self.dtype ) def write_items(self, items, buf): return buf.write_fixed_strings(items.tolist(), self.length) def create_string_column(spec, column_options): client_settings = column_options['context'].client_settings strings_as_bytes = client_settings['strings_as_bytes'] encoding = client_settings.get( 'strings_encoding', NumpyStringColumn.default_encoding ) if spec == 'String': cls = NumpyByteStringColumn if strings_as_bytes else NumpyStringColumn return cls(encoding=encoding, **column_options) else: length = int(spec[12:-1]) cls = NumpyByteFixedString if strings_as_bytes else NumpyFixedString return cls(length, encoding=encoding, **column_options)