Last active
June 7, 2022 15:28
-
-
Save bivald/f8e0a7625af2eabbf7c5fa055da91d61 to your computer and use it in GitHub Desktop.
Convert a Parquet file with dictionary/categorical columns into an Arrow IPC file, processing one batch at a time so the full dataset is never held in memory
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
from pyarrow import fs | |
import hashlib | |
import pyarrow as pa | |
import pyarrow.parquet as pq | |
# Input Parquet file and output Arrow IPC file paths.
input_file = 'input.parq'
output_file = 'data.arrow'

parquet_file = pq.ParquetFile(input_file)
schema = parquet_file.schema_arrow

# Figure out which columns are dictionary-encoded (categorical).
# pa.types.is_dictionary is the public API (pa.lib.DictionaryType is private).
categories_columns = [
    field.name for field in schema if pa.types.is_dictionary(field.type)
]

# Read the full category list for each dictionary column. Columns are read
# one at a time (and dropped immediately) so the whole table is never held
# in memory at once.
dictionary_values = {}
for column in categories_columns:
    print("Reading parquet column", column)
    df = pd.read_parquet(input_file, columns=[column])
    dictionary_values[column] = df[column].cat.categories.tolist()
    del df  # We take them one by one so we never keep the full data in memory

local = fs.LocalFileSystem()
# Create a DictionaryArray using indices and dictionary values
def create_dictionary_array_indices(column_name, arrow_array):
    """Re-encode one batch's values as a DictionaryArray over the global
    category list for this column.

    Parameters
    ----------
    column_name : str
        Column name; key into the module-level ``dictionary_values`` dict.
    arrow_array : pyarrow.Array
        Decoded (non-dictionary) values of this column for one record batch.

    Returns
    -------
    pyarrow.DictionaryArray
        Int32 indices into ``dictionary_values[column_name]``, with missing
        values (None / NaN) preserved as nulls.
    """
    categories = dictionary_values[column_name]
    # Build a value -> index map once: O(1) lookups instead of
    # list.index(), which is O(n) per value.
    position = {value: idx for idx, value in enumerate(categories)}

    indices = []
    for value in arrow_array.to_pylist():
        # Only genuine missing values become null: None, or NaN (NaN != NaN).
        # Falsy-but-valid categories such as "" or 0 must keep their index
        # (the original `if not value` check silently nulled them out).
        if value is None or value != value:
            indices.append(None)
        else:
            indices.append(position[value])

    indices = pd.array(indices, dtype=pd.Int32Dtype())
    return pa.DictionaryArray.from_arrays(indices, categories)
# Stream the Parquet file into an Arrow IPC file batch by batch, re-encoding
# categorical columns so every batch indexes into one shared dictionary.
# emit_dictionary_deltas lets later batches send only dictionary additions
# instead of re-sending the whole dictionary each time.
with local.open_output_stream(output_file) as file:
    with pa.RecordBatchFileWriter(
        file, schema, options=pa.ipc.IpcWriteOptions(emit_dictionary_deltas=True)
    ) as writer:
        for i, record_batch in enumerate(parquet_file.iter_batches(), start=1):
            # iter_batches() yields batches of `batch_size` rows, which does
            # not necessarily line up with row groups — so report only the
            # batch number (the original "/num_row_groups" denominator could
            # disagree with the actual batch count).
            print(f"Batch {i}")
            columns = []
            for column in schema:
                if column.name not in categories_columns:
                    # Non-categorical columns pass through unchanged.
                    columns.append(record_batch[column.name])
                else:
                    # Re-encode against the global dictionary for this column.
                    columns.append(
                        create_dictionary_array_indices(column.name, record_batch[column.name])
                    )
            writer.write_batch(pa.record_batch(columns, schema=schema))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment