Last active
January 26, 2023 08:56
-
-
Save AlenkaF/95fb41f461fb792396bb20dd502b4112 to your computer and use it in GitHub Desktop.
Example of tensor extension with tests in PyArrow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
import json | |
import math | |
import numpy as np | |
import pyarrow as pa | |
class TensorType(pa.ExtensionType):
    """Arrow extension type storing fixed-shape tensors.

    Each value is stored as a fixed-size list of ``prod(shape)`` elements;
    the shape and memory-layout order are carried in the type metadata.

    Parameters
    ----------
    value_type : pa.DataType
        Arrow type of the tensor elements.
    shape : tuple of int
        Shape of each tensor value.
    order : str
        Layout of the flattened data: 'C' (row-major) or 'F' (column-major).
    """

    def __init__(self, value_type, shape, order):
        self._value_type = value_type
        self._shape = tuple(shape)
        self._order = order
        # One tensor per value: a fixed-size list of prod(shape) elements.
        size = math.prod(shape)
        pa.ExtensionType.__init__(self, pa.list_(self._value_type, size),
                                  'arrow.tensor')

    @property
    def dtype(self):
        """Arrow value type of the tensor elements."""
        return self._value_type

    @property
    def shape(self):
        """Shape of each tensor value."""
        return self._shape

    @property
    def order(self):
        """
        Order of sorting, can be row or column ('C', 'F')
        """
        return self._order

    def __arrow_ext_serialize__(self):
        # Encode the shape as a JSON list so deserialization needs no eval;
        # the element type is recovered from the storage type instead.
        metadata = {"shape": list(self._shape), "order": self._order}
        return json.dumps(metadata).encode()

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # Return an instance of this subclass given the serialized metadata.
        # Raise (rather than assert, which -O strips) on malformed metadata.
        metadata = json.loads(serialized.decode())
        if "shape" not in metadata or "order" not in metadata:
            raise ValueError(f"invalid tensor metadata: {metadata!r}")
        shape = metadata["shape"]
        if isinstance(shape, str):
            # Legacy format: shape serialized as a stringified tuple.
            shape = ast.literal_eval(shape)
        return cls(storage_type.value_type, tuple(shape), metadata["order"])

    def __arrow_ext_class__(self):
        return TensorArray
class TensorArray(pa.ExtensionArray):
    """Extension array whose values are fixed-shape tensors."""

    def to_numpy_tensor(self):
        """Return the stored data as a numpy ndarray.

        The flat storage is reshaped using the shape and layout order
        recorded on the type.
        """
        flat_array = self.storage.flatten().to_numpy()
        return flat_array.reshape(self.type.shape, order=self.type.order)

    @staticmethod
    def from_numpy_tensor(obj):
        """Build a one-element TensorArray from a numpy ndarray.

        Parameters
        ----------
        obj : np.ndarray
            Tensor to store; its dtype, shape and memory order are recorded
            on the resulting TensorType.
        """
        arrow_type = pa.from_numpy_dtype(obj.dtype)
        order = 'F' if np.isfortran(obj) else 'C'
        # Flatten in the SAME order that to_numpy_tensor will use for the
        # reshape; flattening Fortran data in the default C order would
        # scramble the round trip.
        flat = obj.flatten(order=order)
        return pa.ExtensionArray.from_storage(
            TensorType(arrow_type, obj.shape, order),
            pa.array([flat], pa.list_(arrow_type, obj.size))
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pyarrow as pa | |
import pytest | |
@pytest.fixture
def registered_tensor_type():
    """Register 'arrow.tensor' for one test and unregister afterwards."""
    ext_type = TensorType(pa.int8(), (2, 2, 3), 'C')
    ext_class = ext_type.__arrow_ext_class__()
    pa.register_extension_type(ext_type)
    yield ext_type, ext_class
    # Teardown: tolerate the type already being gone.
    try:
        pa.unregister_extension_type('arrow.tensor')
    except KeyError:
        pass
def test_generic_ext_type():
    """Extension name and storage type reflect the constructor arguments."""
    ext_type = TensorType(pa.int8(), (2, 3), 'C')
    assert ext_type.extension_name == "arrow.tensor"
    assert ext_type.storage_type == pa.list_(pa.int8(), 6)
def test_tensor_ext_class_methods():
    """Round-trip a tensor through to_numpy_tensor / from_numpy_tensor."""
    ext_type = TensorType(pa.float32(), (2, 2, 3), 'C')
    values = [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]
    storage = pa.array([values], pa.list_(pa.float32(), 12))
    arr = pa.ExtensionArray.from_storage(ext_type, storage)

    expected = np.array(
        [[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], dtype=np.float32)
    np.testing.assert_array_equal(arr.to_numpy_tensor(), expected)

    roundtrip = TensorArray.from_numpy_tensor(expected)
    assert isinstance(roundtrip.type, TensorType)
    assert roundtrip.type.dtype == pa.float32()
    assert roundtrip.type.shape == (2, 2, 3)
    assert roundtrip.type.order == 'C'
def ipc_write_batch(batch):
    """Serialize one record batch into an IPC stream buffer."""
    sink = pa.BufferOutputStream()
    writer = pa.RecordBatchStreamWriter(sink, batch.schema)
    writer.write_batch(batch)
    writer.close()
    return sink.getvalue()
def ipc_read_batch(buf):
    """Read the next record batch back out of an IPC stream buffer."""
    return pa.RecordBatchStreamReader(buf).read_next_batch()
def test_generic_ext_type_ipc(registered_tensor_type):
    """TensorType/TensorArray survive an IPC write/read round trip."""
    ext_type, ext_class = registered_tensor_type
    payload = [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]
    storage = pa.array([payload], pa.list_(pa.int8(), 12))
    arr = pa.ExtensionArray.from_storage(ext_type, storage)
    batch = pa.RecordBatch.from_arrays([arr], ["ext"])
    # the built array has exactly the expected class
    assert type(arr) == ext_class

    buf = ipc_write_batch(batch)
    del batch
    result = ipc_read_batch(buf).column(0)
    # the deserialized array class is the expected one
    assert type(result) == ext_class
    assert result.type.extension_name == "arrow.tensor"
    assert arr.storage.to_pylist() == [payload]
    # we get back an actual TensorType with the registered parameters
    assert isinstance(result.type, TensorType)
    assert result.type.dtype == pa.int8()
    assert result.type.shape == (2, 2, 3)
    assert result.type.order == 'C'

    # a parametrization different from how the type was registered
    ext_type_uint = ext_type.__class__(pa.uint8(), (2, 3), 'C')
    assert ext_type_uint.extension_name == "arrow.tensor"
    assert ext_type_uint.dtype == pa.uint8()
    assert ext_type_uint.shape == (2, 3)
    assert ext_type_uint.order == 'C'

    storage = pa.array([[1, 2, 3, 4, 5, 6]], pa.list_(pa.uint8(), 6))
    arr = pa.ExtensionArray.from_storage(ext_type_uint, storage)
    buf = ipc_write_batch(pa.RecordBatch.from_arrays([arr], ["ext"]))
    result = ipc_read_batch(buf).column(0)
    assert isinstance(result.type, TensorType)
    assert result.type.dtype == pa.uint8()
    assert result.type.shape == (2, 3)
    assert result.type.order == 'C'
    assert type(result) == ext_class
# def test_generic_ext_type_ipc_unknown(registered_tensor_type): | |
# def test_generic_ext_type_equality(): | |
# def test_generic_ext_type_register(registered_tensor_type): |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment