Last active
September 9, 2023 23:56
-
-
Save tomdean/a58f5e7bb1ad87a5af50678d41773669 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import OrderedDict | |
import datetime | |
from typing import Iterator, List, Sized, Union | |
import numpy as np | |
import pandas as pd | |
from psycopg2.extensions import QuotedString | |
from sqlalchemy import and_, exists, MetaData, Table, Column as SAColumn | |
import logging | |
#: Module-level logger. Named after this module (instead of using the root
#: logger) so applications can filter or configure its output separately;
#: records still propagate to handlers installed on the root logger.
log = logging.getLogger(__name__)

#: Number of rows to insert per batch transaction
BATCH_SIZE = 50000
def to_python_type(column):
    """Return the Python type used to cast values of `column`.

    A column whose type renders as the string ``UUID`` maps to :class:`str`;
    any other column defers to the ``python_type`` of its type object.

    :param column: Object exposing a ``type`` attribute.
    """
    column_type = column.type
    is_uuid = str(column_type) == 'UUID'
    return str if is_uuid else column_type.python_type
def to_str(val):
    """Return `val` as a quoted string literal for ad-hoc SQL.

    Bytes are decoded as UTF-8 first. The value is then converted with
    :func:`str`, encoded as UTF-8 and quoted with psycopg2's
    :class:`QuotedString`; the quoted result is returned as text.

    :param val: Value to quote.
    """
    text = val.decode('utf-8') if isinstance(val, bytes) else val
    encoded = (str(text) or "").encode('utf-8')
    quoted = QuotedString(encoded).getquoted()
    return quoted.decode('utf-8')
def to_date(value):
    """Return `value` as a quoted ISO-8601 date literal.

    :param value: Scalar accepted by :func:`pandas.to_datetime`.
    :return: Quoted ``YYYY-MM-DD`` text, or ``NULL`` when the parsed value
        is null (``None``, ``NaN``, ``NaT``).
    """
    dt = pd.to_datetime(value)
    if pd.isnull(dt):
        # Emit SQL NULL instead of quoting the text 'None'.
        return 'NULL'
    return to_str(dt.date().isoformat())
def to_datetime(value):
    """Return `value` as a quoted ISO-8601 datetime literal.

    :param value: Scalar accepted by :func:`pandas.to_datetime`.
    :return: Quoted ISO-8601 text, or ``NULL`` when the parsed value is
        null (``None``, ``NaN``, ``NaT``).
    """
    dt = pd.to_datetime(value)
    if pd.isnull(dt):
        # Emit SQL NULL instead of quoting the text 'None'.
        return 'NULL'
    return to_str(dt.to_pydatetime().isoformat())
class Column:
    """Wrapper to cast Python values for use in ad-hoc SQL.

    Example::

        columns = [Column('id', int), Column('amount', float)]
    """

    def __init__(self, name: str, python_type: type):
        """
        :param name: Name of the column.
        :param python_type: Python type e.g. int, str, float.
        """
        self.name = name
        self.python_type = python_type

    def escape(self, value) -> str:
        """Escape a value for use in a Postgres ad-hoc SQL statement.

        Null values become the text ``NULL``. Date/time values, and any
        value of a column typed as date or datetime, are rendered by
        ``to_datetime``; plain dates by ``to_date``; string columns by
        ``to_str``. For every other column the value is cast with
        ``python_type`` and returned as-is (e.g. an ``int``), so callers
        are expected to stringify the result.

        :param value: Raw value to escape.
        """
        if pd.isnull(value):
            return 'NULL'

        cast = self.python_type
        temporal_value = isinstance(
            value, (datetime.datetime, np.datetime64, pd.Timestamp))
        if temporal_value or cast in (datetime.date, datetime.datetime):
            cast = to_datetime
        elif isinstance(value, datetime.date):
            cast = to_date
        elif issubclass(self.python_type, str):
            cast = to_str
        return cast(value)

    def __eq__(self, b):
        try:
            return self.name == b.name and self.python_type == b.python_type
        except AttributeError:
            # Not column-like: let Python fall back to its default comparison
            # rather than failing on the missing attribute.
            return NotImplemented

    def __hash__(self):
        # Defining __eq__ alone would make instances unhashable; hash on the
        # same fields that equality compares.
        return hash((self.name, self.python_type))

    def __repr__(self):
        return '{}<name={}, type={}>'.format(
            self.__class__.__name__, self.name, self.python_type.__name__)
class ColumnCollection(OrderedDict):
    """Ordered mapping of column name to column object."""

    def __init__(self, columns: list):
        """
        :param columns: Objects exposing a ``name`` attribute; insertion
            order follows the order of this list.
        """
        super().__init__()
        for column in columns:
            self[column.name] = column
class BulkInsertFromIterator:
    """Bulk insert into Postgres from an iterator in fixed-size batches.

    Example::

        bulk = BulkInsertFromIterator(
            'schema.table',
            iter([{'id': 1, 'name': 'Python'}, {'id': 2, 'name': 'PyPy'}]),
            [Column('id', int), Column('name', str)]
        )
        bulk.execute(db.engine.raw_connection)
    """

    def __init__(self, table, data: Iterator, columns: list,
                 batch_size: int=BATCH_SIZE, header: bool=False):
        """
        :param table: Name of the table.
        :param data: Iterator (or list) containing the data to insert.
        :param columns: List of :class:`Column` objects, or of tuples that
            are expanded into ``Column(*tuple)``.
        :param batch_size: Rows to insert per batch.
        :param header: True if the first row is a header.
        :raises TypeError: If `data` is neither a list nor an iterator.
        :raises ValueError: If `columns` is empty.
        """
        self.table = table
        self.data = data
        self.columns = columns
        self.batch_size = batch_size
        self.header = header

        if isinstance(self.data, list):
            self.data = iter(self.data)
        if not isinstance(self.data, Iterator):
            raise TypeError('Expected Iterator, got {}'.format(
                self.data.__class__))
        if not self.columns:
            raise ValueError('Columns cannot be empty')
        # Only the first entry is inspected to decide whether to wrap tuples.
        if isinstance(self.columns[0], tuple):
            self.columns = [Column(*c) for c in self.columns]

    def batch_execute(self, conn):
        """Insert data in batches of `batch_size`.

        This is a generator: after each batch it yields the running total
        of inserted rows.

        :param conn: Callable returning a DB API 2.0 connection object; it
            is handed to :meth:`BulkInsertQuery.execute` for every batch.
        :raises ValueError: If the header names an unknown column, or if
            there is no data to insert.
        """
        def batches(data, batch_size):
            """Return batches of length `batch_size` from any object that
            supports iteration without knowing length."""
            rv = []
            for idx, line in enumerate(data):
                if idx != 0 and idx % batch_size == 0:
                    yield rv
                    rv = []
                rv.append(line)
            # The trailing batch is yielded even when empty, so empty input
            # surfaces as an error from the query instead of a silent no-op.
            yield rv

        columns = ColumnCollection(self.columns)
        if self.header:
            # A bare next() would leak StopIteration out of this generator.
            header = next(self.data, None)
            if header is None:
                raise ValueError('No data provided')
            names = list(header)
            unknown = [name for name in names if name not in columns]
            if unknown:
                raise ValueError('Unknown columns in header: {}'.format(
                    ', '.join(map(str, unknown))))
            # Re-order the columns to follow the header row.
            self.columns = [columns[name] for name in names]
            columns = ColumnCollection(self.columns)

        total = 0
        query = BulkInsertQuery(self.table, columns)
        for batch in batches(self.data, self.batch_size):
            total += query.execute(conn, batch) or 0
            yield total

    def execute(self, conn):
        """Execute all batches and return the total number of rows inserted.

        :param conn: Callable returning a DB API 2.0 connection object.
        """
        # Totals are cumulative, so the largest one is the final count.
        return max(self.batch_execute(conn))
class BulkInsertQuery:
    """Execute a multi-row INSERT statement.

    This does not take advantage of parameterized queries, but escapes
    string values manually in :class:`Column`.
    """

    def __init__(self, table: str, columns):
        """
        :param table: Name of the table being inserted into.
        :param columns: Ordered mapping of column name to an object with an
            ``escape(value)`` method; required for type coercion.
        """
        self.table = table
        self.columns = columns
        # Iterating the mapping yields the column names in order.
        self.query = 'INSERT INTO {} ({}) VALUES '.format(
            table, ', '.join(columns))

    def execute(self, conn, rows: list) -> int:
        """Execute a single multi-row INSERT for `rows`.

        :param conn: Function that returns a database connection
        :param rows: Rows to insert. Each row is either a list/tuple in the
            same order as :attr:`columns`, or an object indexable by column
            name (e.g. a dict).
        :return: Number of rows in the statement.
        :raises ValueError: If `rows` is empty or the first row does not
            have one value per column.
        """
        rows = list(rows)
        if not rows:
            raise ValueError('No data provided')
        if len(self.columns) != len(rows[0]):
            raise ValueError('Expecting {} columns, found {}'.format(
                len(self.columns), len(rows[0])))

        # Build the statement first so an escaping error never opens a
        # connection.
        statement = self.query + ', '.join(self.escape_rows(rows))

        connection = conn()
        try:
            cursor = connection.cursor()
            try:
                cursor.execute(statement)
                connection.commit()
            finally:
                cursor.close()
        finally:
            # Runs even if cursor() itself fails, so the connection is never
            # leaked.
            connection.close()
        return len(rows)

    def escape_rows(self, rows: list):
        """Escape values for use in non-parameterized SQL queries.

        :param rows: Rows to escape; see :meth:`execute` for accepted shapes.
        :return: A new list with one ``(v1, v2, ...)`` string per row. The
            input list is left unmodified.
        :raises ValueError: If a list/tuple row does not have one value per
            column.
        """
        names = list(self.columns)
        escaped = []
        for row in rows:
            if isinstance(row, (list, tuple)):
                # Positional row: values follow the column order.
                if len(row) != len(names):
                    raise ValueError('Expecting {} columns, found {}'.format(
                        len(names), len(row)))
                values = row
            else:
                # Anything else is looked up by column name.
                values = [row[name] for name in names]
            cells = [self.columns[name].escape(value)
                     for name, value in zip(names, values)]
            escaped.append('({})'.format(', '.join(map(str, cells))))
        return escaped
def as_columns(columns) -> List[Column]:
    """Normalize a mixed sequence of column specs into :class:`Column`s.

    Accepted entries are :class:`Column` instances (kept as-is), tuples
    (expanded into ``Column(*entry)``), strings (a ``str``-typed column of
    that name) and SQLAlchemy columns (name plus the type reported by
    ``to_python_type``). Entries of any other type are skipped.

    :param columns: Iterable of column specs.
    """
    wrapped = []
    for item in columns:
        if isinstance(item, Column):
            wrapped.append(item)
        elif isinstance(item, tuple):
            wrapped.append(Column(*item))
        elif isinstance(item, str):
            wrapped.append(Column(item, str))
        elif isinstance(item, SAColumn):
            wrapped.append(Column(item.name, to_python_type(item)))
    return wrapped
def from_sqlalchemy_table(table: Table, data: Iterator, columns: List[str],
                          batch_size: int=BATCH_SIZE) -> BulkInsertFromIterator:
    """Return a :class:`BulkInsertFromIterator` based on the metadata
    of a SQLAlchemy table.

    Example::

        batch = from_sqlalchemy_table(
            Rating.__table__,
            data,
            ['rating_id', 'repo_id', 'login_id', 'rating']
        )

    :param table: A :class:`sqlalchemy.Table` instance.
    :param data: An iterator.
    :param columns: List of column names to use.
    :param batch_size: Number of rows to insert per SQL statement
    :raises TypeError: If `table` is not a :class:`sqlalchemy.Table`.
    :raises ValueError: If a name in `columns` is not a column of `table`.
    """
    if not isinstance(table, Table):
        raise TypeError('Expected sqlalchemy.Table, got {}'.format(table))

    wrapped = []
    for name in columns:
        column = table.columns.get(name)
        if column is None:
            # Fail with the offending name rather than an AttributeError on
            # None further down.
            raise ValueError('Unknown column {!r} for table {}'.format(
                name, table))
        wrapped.append(Column(str(column.name), to_python_type(column)))
    return BulkInsertFromIterator(table, data, wrapped, batch_size, False)
def create_staging_table(engine, table: Table) -> Table:
    """Create a copy of the table to store intermediary results.

    The copy lives in the ``staging`` schema. Primary keys and other unique
    constraints are removed, and any existing staging table of the same
    name is dropped before the new one is created.

    :param engine: SQLAlchemy engine
    :param table: SQLAlchemy table to clone schema from
    :return: The newly created staging table.
    """
    staging = table.tometadata(MetaData(), schema="staging")

    # Remove constraints to prevent errors
    for column in (c for c in staging.columns if c.primary_key):
        column.primary_key = False
    staging.indexes = []
    staging.constraints = []
    staging.primary_key = None

    message = 'Creating staging table {}.{}'.format(
        staging.schema, staging.name)
    log.info(message)

    staging.drop(engine, checkfirst=True)
    staging.create(engine)
    return staging
def stage_and_merge(engine, target: Table, rows: Union[Iterator, Sized]):
    """Write data to an intermediary staging table before adding to `target`.

    Only staged rows whose primary key is not already present in `target`
    are copied over. The staging table is dropped afterwards, whether or
    not the copy succeeded.

    :param engine: A instance of :class:`sqlalchemy.engine.Engine`
    :param target: Table to write the results
    :param rows: Data to write
    """
    if isinstance(rows, Sized) and len(rows) > 0:
        log.info('Staging Rows: {}'.format(len(rows)))

    # Drop & recreate the staging table
    source = create_staging_table(engine, target)

    # Insert data into a temporary staging table prior to copying to the target
    try:
        staged_columns = as_columns(source.columns)
        BulkInsertFromIterator(source, rows, staged_columns).execute(
            engine.raw_connection)

        key_matches = [
            source.c[column.name] == target.c[column.name]
            for column in target.columns
            if column.primary_key
        ]

        # Only insert rows that do not exist in the target table
        already_present = exists().where(and_(*key_matches))
        query = source.select().distinct().where(~already_present)
        statement = target.insert().from_select(source.c, query)
        result = engine.execute(statement)
        log.info('Updated Row Count: {}'.format(result.rowcount))
    finally:
        source.drop(engine)
def stage_and_replace(engine, target: Table, rows: Union[Iterator, Sized]):
    """Replace the contents of `target` with `rows` via a staging table.

    The rows are first bulk-loaded into a staging table. `target` is then
    dropped (if it exists), re-created, and filled with the distinct staged
    rows. The staging table is dropped afterwards, whether or not the
    replacement succeeded.

    :param engine: A instance of :class:`sqlalchemy.engine.Engine`
    :param target: Table to re-create and fill
    :param rows: Data to write
    """
    if isinstance(rows, Sized) and len(rows) > 0:
        log.info('Staging Rows: {}'.format(len(rows)))

    # Drop & recreate the staging table
    source = create_staging_table(engine, target)

    # Insert data into a temporary staging table prior to copying to the target
    try:
        bulk = BulkInsertFromIterator(source, rows, as_columns(source.columns))
        bulk.execute(engine.raw_connection)

        query = source.select().distinct()

        # Re-create and fill the target inside one transaction, so a failed
        # insert rolls back instead of leaving the target dropped or empty
        # (this relies on the backend supporting transactional DDL, as
        # Postgres does).
        with engine.begin() as conn:
            target.drop(conn, checkfirst=True)
            target.create(conn)
            result = conn.execute(
                target.insert().from_select(source.c, query))
        log.info('Updated Row Count: {}'.format(result.rowcount))
    finally:
        source.drop(engine)
def determine_columns(table: Table, rows):
    """Return the columns of `table` that the given rows provide.

    When the first row is a dict, only columns whose name is one of its
    keys are returned. In every other case - non-dict rows, empty input, or
    input that cannot be indexed (such as an iterator) - all columns of the
    table are returned.

    :param table: Table whose columns are wrapped via ``as_columns``.
    :param rows: Data about to be inserted; only ``rows[0]`` is inspected.
    """
    columns = as_columns(table.columns)
    try:
        first = rows[0]
    except (TypeError, IndexError, KeyError):
        # Nothing to peek at without consuming the data: assume full rows.
        return columns
    if not isinstance(first, dict):
        return columns
    keys = first.keys()
    return [column for column in columns if column.name in keys]
def stage_and_update(engine, target: Table, rows: Union[Iterator, Sized]):
    """Write data to an intermediary staging table, then upsert into `target`.

    Rows in `target` whose primary key matches a staged row are deleted,
    after which every staged row is inserted. The staging table is dropped
    afterwards, whether or not the update succeeded.

    :param engine: A instance of :class:`sqlalchemy.engine.Engine`
    :param target: Table to write the results
    :param rows: Data to write
    """
    if isinstance(rows, Sized) and len(rows) > 0:
        log.info('Staging Rows: {}'.format(len(rows)))

    # Drop & recreate the staging table
    source = create_staging_table(engine, target)

    # Insert data into a temporary staging table prior to copying to the target
    try:
        # Resolved inside the try block so the staging table is dropped even
        # if column detection fails.
        columns = determine_columns(source, rows)
        bulk = BulkInsertFromIterator(source, rows, columns)
        bulk.execute(engine.raw_connection)

        keys = [c for c in target.columns if c.primary_key]
        where = [source.c[c.name] == target.c[c.name] for c in keys]

        delete = target.delete().where(exists().where(and_(*where)))
        insert = source.select()

        # Delete and re-insert in one transaction, so a failed insert rolls
        # the delete back instead of losing the rows it removed.
        with engine.begin() as conn:
            # Delete from target table before appending
            conn.execute(delete)
            # Copy rows from staging table to target
            result = conn.execute(
                target.insert().from_select(source.c, insert))
        log.info('Updated Row Count: {}'.format(result.rowcount))
    finally:
        source.drop(engine)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment