# Last active August 7, 2021 11:16
# Avro Serializer/Deserializer with Numpy Support
from ast import literal_eval

from avro.io import DatumReader, DatumWriter, BinaryEncoder, BinaryDecoder
from avro.schema import Names, SchemaFromJSONData
import numpy as np
import yaml
class BinaryDatumWriter(object):
    """Write Avro datums to a binary buffer, converting numpy arrays.

    *schema* is a parsed Avro schema, or a ``dict``/``str`` schema
    description that is parsed with ``load_schema``.
    *buf* is the writable binary stream handed to ``BinaryEncoder``.
    """

    def __init__(self, schema, buf):
        if isinstance(schema, (dict, str)):
            schema = load_schema(schema)
        self.schema = schema
        self._writer = DatumWriter(schema)
        self._encoder = BinaryEncoder(buf)
        self.buf = buf

    def write(self, datum):
        """Encode *datum* (a dict keyed by field name) onto the buffer.

        ``numpy.ndarray`` values are converted first: raw bytes when the
        field's schema type is ``fixed`` or ``binary``, a plain list
        otherwise.  The caller's dictionary is left untouched; the
        conversions are applied to a shallow copy.
        """
        encoded = dict(datum)
        for (key, value) in datum.items():
            if not isinstance(value, np.ndarray):
                continue
            field = self.schema.field_map[key]
            if field.type.props['type'] in ['fixed', 'binary']:
                encoded[key] = value.tobytes()
            else:
                # Without this branch the bytes produced above would be
                # overwritten by the list form for every array.
                encoded[key] = value.tolist()
        self._writer.write(encoded, self._encoder)
class BinaryDatumReader(object):
    """Read Avro datums from a binary buffer.

    *writer_schema* and the optional *reader_schema* are parsed Avro
    schemas, or ``dict``/``str`` descriptions parsed with ``load_schema``.
    *buf* is the readable binary stream handed to ``BinaryDecoder``.
    ``self.schema`` is the reader schema when given, else the writer schema.
    """

    def __init__(self, writer_schema, buf, reader_schema=None):
        if isinstance(writer_schema, (dict, str)):
            writer_schema = load_schema(writer_schema)
        if isinstance(reader_schema, (dict, str)):
            reader_schema = load_schema(reader_schema)
        self.schema = reader_schema or writer_schema
        self._reader = DatumReader(writer_schema, reader_schema)
        self._decoder = BinaryDecoder(buf)
        self.buf = buf

    def read(self):
        """Decode and return the next datum from the buffer.

        The datum is returned as decoded by the Avro reader; ndarray
        fields are still raw values until passed through ``unpack_datum``.
        """
        return self._reader.read(self._decoder)
def unpack_datum(schema, datum):
    """Rebuild numpy arrays in a decoded *datum*, in place.

    *schema* is a parsed Avro record schema, or a ``dict``/``str``
    description parsed with ``load_schema``.  Every field whose type
    carries ``logicalType == 'ndarray'`` and whose Avro type is ``fixed``
    or ``binary`` is turned into an array of the declared ``dtype``; when
    the field also declares a ``shape`` (a Python tuple literal stored as
    a string) the array is reshaped to it.

    Returns the same *datum* dictionary.
    Raises ``SyntaxError`` when a ``shape`` string cannot be parsed.
    """
    if isinstance(schema, (dict, str)):
        schema = load_schema(schema)
    for field in schema.fields:
        props = field.type.props
        if props.get('logicalType', None) != 'ndarray':
            continue
        if props['type'] not in ['fixed', 'binary']:
            # NOTE(review): non-byte encodings are left as decoded --
            # confirm this matches the intended handling of list-encoded
            # arrays.
            continue
        dtype = np.dtype(props['dtype'])
        name = field.name
        datum[name] = np.frombuffer(datum[name], dtype=dtype)
        if 'shape' in props:
            try:
                shape = literal_eval(props['shape'])
            except SyntaxError:
                msg = 'Could not parse shape: "%s"' % props['shape']
                raise SyntaxError(msg)
            datum[name] = datum[name].reshape(shape)
    return datum
def load_schema(schema):
    """Return a parsed Avro schema built from *schema*.

    *schema* is either a ``dict`` of schema data, used as is, or text that
    is first parsed with the YAML loader.
    """
    if not isinstance(schema, dict):
        # safe_load only builds plain Python containers; a bare yaml.load()
        # can construct arbitrary objects and, on current PyYAML releases,
        # refuses to run without an explicit Loader argument.
        schema = yaml.safe_load(schema)
    return SchemaFromJSONData(schema, Names())
if __name__ == '__main__':
    # Round-trip demo / rough benchmark for the writer and reader above.
    import time
    import io

    schema = {"namespace": "example.avro",
              "type": "record",
              "name": "User",
              "fields": [
                  {"name": "name", "type": "string"},
                  {"name": "favorite_number", "type": ["int", "null"]},
                  {"name": "favorite_color", "type": ["string", "null"]},
                  {"name": "yo",
                   "type": {"type": "array",
                            "items": ["null", "string"]}},
                  {"name": "fixed_big",
                   "type": {"type": "fixed", "name": "image",
                            "size": 1920 * 1024,
                            "logicalType": "ndarray", "dtype": "uint8",
                            "shape": "(1920,1024)"}},
              ]}
    buf = io.BytesIO()
    bdata = np.ones((1920, 1024), dtype=np.uint8)

    writer = BinaryDatumWriter(schema, buf)
    writer.write({"name": "Alyssa", "yo": [None, "hey"],
                  "favorite_number": 256, "fixed_big": bdata})

    # t0 is set once; every duration printed below is measured from here.
    t0 = time.time()
    for i in range(1000):
        writer.write({"name": "Ben", "favorite_number": 7,
                      "yo": [None],
                      "favorite_color": "red", "fixed_big": bdata})
    print('Write to buf:', time.time() - t0)

    with open('test.dat', 'wb') as fid:
        fid.write(buf.getvalue())
    print('Write to disk:', time.time() - t0)

    # The stream position sits at the end after writing; rewind so the
    # decoder starts at the first datum.
    buf.seek(0)
    reader = BinaryDatumReader(schema, buf)
    d1 = reader.read()
    for i in range(1000):
        reader.read()
    print('Read from buf:', time.time() - t0)

    print(unpack_datum(schema, d1))
import sys
import sqlite3
try:
    import cPickle as pickle
except ImportError:
    # cPickle only exists on Python 2; fall back to the standard module.
    import pickle
import numpy as np

# True when running on Python 3 or later.
PY3 = sys.version_info[0] >= 3
if PY3:
    # Python 3 has no ``long``/``unicode`` builtins; alias them so the
    # isinstance checks in this module work on both major versions.
    long = int
    unicode = str
# SQLite storage classes:
#   NULL. The value is a NULL value.
#   INTEGER. The value is a signed integer, stored in 1, 2, 3, 4, 6, or 8 bytes depending on the magnitude of the value.
#   REAL. The value is a floating point value, stored as an 8-byte IEEE floating point number.
#   TEXT. The value is a text string, stored using the database encoding (UTF-8, UTF-16BE or UTF-16LE).
#   BLOB. The value is a blob of data, stored exactly as it was input.
def get_dtype(value):
    """Return the SQLite storage class name used for *value*.

    ``None`` maps to ``'NULL'``, integers to ``'INTEGER'``, floats to
    ``'REAL'``, bytes to ``'BLOB'`` and everything else to ``'TEXT'``.
    """
    if value is None:
        return 'NULL'
    # Checked in order; the first matching entry wins.
    storage_classes = (
        ((int, long), 'INTEGER'),
        (float, 'REAL'),
        (bytes, 'BLOB'),
    )
    for python_types, sql_name in storage_classes:
        if isinstance(value, python_types):
            return sql_name
    return 'TEXT'
def convert(value):
    """Return *value* in a form that can be bound as a SQL parameter.

    ``None``, integers, floats, bytes and text are returned unchanged.
    Any other object is serialized with ``pickle.dumps`` using the highest
    available protocol (``protocol=-1``).
    """
    if value is None:
        return None
    if isinstance(value, (int, long, float, bytes, unicode)):
        return value
    return pickle.dumps(value, protocol=-1)
class SqlHandler(object):
    """Manage nested data reading and writing from dictionaries to SQL.

    A nested dictionary stored under key ``k`` of table ``t`` lives in a
    child table named ``t__k`` that carries the same primary key column.

    *db* is an open ``sqlite3`` connection, or a ``str`` that is passed to
    ``sqlite3.connect``.  The handler never commits; that is left to the
    owner of the connection (``self.db``).

    Table and column names are interpolated into the SQL text, so they
    must come from trusted code.
    """

    def __init__(self, db):
        if isinstance(db, str):
            db = sqlite3.connect(db)
        self.db = db
        self.cursor = self.db.cursor()
        # Names of the tables known to exist, and a per-table cache of
        # column names (primary key first for tables made by create_table).
        self.tables = self.list_tables()
        self.field_names = dict()

    def create_table(self, datum, name, primary_key, foreign=False):
        """Create a hierarchical table in a DB from a dictionary.

        *datum*  dictionary
        *name*  name of the top level table
        *primary_key*  which key in the dictionary to use as primary
        *foreign*  Used internally by sub-tables to indicate linkage

        Column types are taken from the values in *datum*.  Nested
        dictionaries get the primary key value copied into them and are
        created as child tables.
        """
        # Type the key column from its value, not from the column name
        # (which is always a string).
        if primary_key in datum:
            key_type = get_dtype(datum[primary_key])
        else:
            key_type = 'TEXT'
        lines = ['create table %s\n(%s %s primary key' %
                 (name, primary_key, key_type)]
        for (key, value) in datum.items():
            if key == primary_key:
                continue
            if not isinstance(value, dict):
                lines.append('%s %s' % (key, get_dtype(value)))
        if foreign:
            lines.append('foreign key (%s) references %s(%s)'
                         % (primary_key, name.split('__')[0], primary_key))
        if name not in self.tables:
            self.cursor.execute(',\n'.join(lines) + ')')
            # read_table discovers child tables through self.tables, so
            # the new table has to be recorded here.
            self.tables.append(name)
        self.field_names[name] = self.list_fields(name)
        for (key, value) in datum.items():
            if isinstance(value, dict):
                value[primary_key] = datum[primary_key]
                self.create_table(value, '%s__%s' % (name, key),
                                  primary_key, foreign=True)

    def remove_table(self, name, cmd='drop'):
        """Remove a table and its children.

        *cmd* is the SQL verb placed in front of ``table <name>``.  When
        it is ``drop`` the removed tables are also forgotten by the
        handler's caches.
        """
        self.cursor.execute('%s table %s' % (cmd, name))
        removed = [name]
        for table in self.list_tables():
            if table.startswith(name + '__'):
                self.cursor.execute('%s table %s' % (cmd, table))
                removed.append(table)
        if cmd.strip().lower() == 'drop':
            for table in removed:
                if table in self.tables:
                    self.tables.remove(table)
                self.field_names.pop(table, None)

    def clear_table(self, name):
        """Clear a table and its children, keeping the tables themselves."""
        # SQLite has no TRUNCATE statement; a DELETE without a WHERE
        # clause empties the table.
        self.cursor.execute('delete from %s' % name)
        for table in self.list_tables():
            if table.startswith(name + '__'):
                self.cursor.execute('delete from %s' % table)

    def add_data(self, datum, tbl_name):
        """Add datum to a given table hierarchically.

        *datum*  a dictionary
        *tbl_name*  name of a valid, existing table

        Assumes the data is in the right format for the table.  Nested
        dictionaries get the primary key value copied into them and are
        inserted into the matching child tables.
        """
        if tbl_name not in self.field_names:
            self.field_names[tbl_name] = self.list_fields(tbl_name)
        # The first column is taken to be the primary key.
        primary_key = self.field_names[tbl_name][0]
        # Nested dictionaries are not columns of this table.
        keys = [key for (key, value) in datum.items()
                if not isinstance(value, dict)]
        query = "insert into {0} ({1}) values ({2})"
        query = query.format(tbl_name, ",".join(keys),
                             ",".join("?" * len(keys)))
        self.cursor.execute(query, [convert(datum[k]) for k in keys])
        for (key, value) in datum.items():
            if isinstance(value, dict):
                value[primary_key] = datum[primary_key]
                self.add_data(value, '%s__%s' % (tbl_name, key))

    def read_table(self, tbl_name, which='*', criteria=None):
        """Read data from a hierarchical table as a list of dictionaries.

        *tbl_name* is the top level table in the hierarchy
        *which* can be a list of field names or a csv string of names
        *criteria* is a valid SQL WHERE statement

        *tbl_name* and *which* can be a path to nested data
        in /path/format.

        *criteria* can include AND or OR and use any valid Comparison
        or Logical Operators like >, <, =, LIKE, NOT, etc.  It is placed
        in the query verbatim (so it must come from trusted code) and is
        also applied to every child table that is read.

        Returns ``None`` when nothing was selected.
        """
        if tbl_name.startswith('/'):
            tbl_name = tbl_name[1:]
        tbl_name = tbl_name.replace('/', '__')
        if tbl_name not in self.field_names:
            self.field_names[tbl_name] = self.list_fields(tbl_name)
        fields = self.field_names[tbl_name]
        if isinstance(which, list):
            which = ','.join(which)
        which = [w.strip() for w in which.split(',')]
        if which == ['*']:
            qwhich = fields
        else:
            qwhich = [w for w in which if w in fields]
        if qwhich:
            query = 'SELECT %s from %s' % (','.join(qwhich), tbl_name)
            if criteria:
                # Keep inner whitespace: removing it would fuse keywords
                # such as AND/OR/LIKE with their operands.
                query += ' where %s' % criteria.strip()
            self.cursor.execute(query)
            values = self.cursor.fetchall()
            data = [dict(zip(qwhich, v)) for v in values]
        else:
            data = None
        for table in self.tables:
            # Direct children only: exactly one more '__' level.
            if (table.startswith(tbl_name + '__')
                    and table.count('__') == tbl_name.count('__') + 1):
                data = self._read_subtable(table, which, criteria, data)
        return data

    def _read_subtable(self, table, which, criteria, data):
        """Retrieve data from a child table and merge it into *data*."""
        key = table.rpartition('__')[-1]
        sub_which = []
        for s in which:
            if s.startswith('/'):
                s = s[1:]
            if '/' in s:
                # Drop the leading path component for the next level.
                # NOTE(review): the component is not compared with *key*
                # -- confirm that is intended when siblings exist.
                s = s.partition('/')[2]
            sub_which.append(s)
        if key in sub_which:
            # The child itself was requested: return all of its fields.
            sub_data = self.read_table(table, '*', criteria)
        else:
            sub_data = self.read_table(table, sub_which, criteria)
        if not sub_data:
            return data
        if not data:
            data = [{key: d} for d in sub_data]
        else:
            # Rows are paired by position.
            for (datum, sub_datum) in zip(data, sub_data):
                datum[key] = sub_datum
        return data

    def list_tables(self):
        """Get a list of the table names in the DB"""
        query = "SELECT name FROM sqlite_master WHERE type = 'table'"
        self.cursor.execute(query)
        return [t[0] for t in self.cursor.fetchall()]

    def list_fields(self, tbl_name):
        """Get a list of the field names for a given table name"""
        self.cursor.execute("pragma table_info(%s)" % tbl_name)
        return [r[1] for r in self.cursor.fetchall()]
def unpack(datum):
    """Restore pickled values in *datum*, in place, and return it.

    Nested dictionaries are processed recursively.  Each ``bytes`` value
    is passed to ``pickle.loads``; when that fails the value is kept as
    it is, since a bytes value may be plain binary data rather than a
    pickle.

    NOTE: ``pickle.loads`` can execute arbitrary code, so only use this
    on data read from a trusted database.
    """
    for (key, value) in datum.items():
        if isinstance(value, dict):
            datum[key] = unpack(value)
        elif isinstance(value, bytes):
            try:
                datum[key] = pickle.loads(value)
            except Exception:
                # Deliberate best effort: not a pickle, leave the bytes.
                pass
    return datum
if __name__ == '__main__':
    # Demo: store two nested records and read them back in several ways.
    import os

    # Start from an empty database file so the tables are created anew.
    if os.path.exists('test.sqlite'):
        os.remove('test.sqlite')
    db = sqlite3.connect('test.sqlite')

    bdata = np.ones((10, 10), dtype=np.uint8)
    data = {"name": "Alyssa", "yo": [None, "hey"],
            "favorite_number": 256, "fixed_big": bdata,
            'inner': dict(b=10, inner2=dict(foo=3))}

    sh = SqlHandler(db)
    sh.create_table(data, 'test2', 'name')
    sh.add_data(data, 'test2')

    # Second row: a text value for a column whose type came from an int.
    data['name'] = 'Bob'
    data['favorite_number'] = 'forty'
    sh.add_data(data, 'test2')

    print(sh.read_table('test2', which='favorite_number'))
    print(sh.read_table('test2', which='inner'))
    print(sh.read_table('test2', which='/inner/b'))
    print(sh.read_table('test2', which='/inner/inner2'))
    print(sh.read_table('test2', which='inner/inner2/foo'))

    # Single quotes mark a SQL string literal; double quotes denote an
    # identifier and only work as a string through a SQLite legacy fallback.
    bob = sh.read_table('test2', criteria="name='Bob'")[0]
    print('Bob:', unpack(bob))
