-
-
Save cdpath/968547c1b765e7bf9caf17850b821285 to your computer and use it in GitHub Desktop.
SQLAlchemy (PostgreSQL) upsert supporting delayed ORM insertion (foreign keys resolved from related objects) and duplicate removal within a single query
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def upsert(session, model, rows):
    """Insert *rows* into *model*'s table, updating rows whose primary key exists.

    Builds a PostgreSQL ``INSERT ... ON CONFLICT (<primary key>) DO UPDATE``
    statement that overwrites every non-primary-key column with the incoming
    ("excluded") value, then executes it once with all prepared rows.

    Each row is preprocessed first:

    * Foreign keys are filled in from related ORM objects. For every
      foreign-key column, a row may carry the related object under the
      *referenced table's name*. The referenced column's attribute is read
      from that object and stored under the foreign-key column name. A
      foreign-key value supplied directly is kept when no related object is
      given. Presumably this lets callers pass objects whose key values are
      only assigned later (e.g. on flush) — confirm with callers.
    * Rows that repeat an already-kept, NULL-free value for any
      ``UniqueConstraint`` are dropped (the first occurrence wins).

    Args:
        session: SQLAlchemy session used to execute the statement.
        model: Declarative ORM class whose ``__table__`` is the target.
        rows: Iterable of mappings from column name (or referenced table
            name, for related objects) to value. The caller's mappings are
            copied, not modified.

    Raises:
        ValueError: if the table has no non-primary-key columns, so there is
            nothing for ``ON CONFLICT DO UPDATE`` to set.
    """
    table = model.__table__
    stmt = postgresql.insert(table)
    primary_keys = [key.name for key in inspect(table).primary_key]
    # ``stmt.excluded`` refers to the row proposed for insertion; on conflict
    # every non-PK column is overwritten with that incoming value.
    update_dict = {c.name: c for c in stmt.excluded if not c.primary_key}
    if not update_dict:
        raise ValueError("insert_or_update resulted in an empty update_dict")
    stmt = stmt.on_conflict_do_update(index_elements=primary_keys,
                                      set_=update_dict)

    # FK column name -> referenced Column (only the first FK of each column).
    foreign_keys = {
        col.name: list(col.foreign_keys)[0].column
        for col in table.columns
        if col.foreign_keys
    }
    # Row keys under which related ORM objects may be supplied.
    related_tables = {ref.table.name for ref in foreign_keys.values()}
    # NOTE(review): PrimaryKeyConstraint is not a UniqueConstraint, so
    # duplicate primary keys inside one batch are NOT removed here. PostgreSQL
    # rejects an ON CONFLICT DO UPDATE that would affect the same row twice
    # in one statement — confirm callers never pass duplicate PKs.
    unique_constraints = [
        c for c in table.constraints if isinstance(c, UniqueConstraint)
    ]
    seen = set()

    def resolve_foreign_keys(row):
        """Return a copy of *row* with FK columns filled from related objects."""
        row = dict(row)  # never mutate the caller's mapping
        # Pop each related object exactly once, so that several FK columns
        # referencing the same table all read from the same object.
        related = {name: row.pop(name, None) for name in related_tables}
        for col_name, ref in foreign_keys.items():
            obj = related[ref.table.name]
            if obj is not None:
                row[col_name] = getattr(obj, ref.name)
            else:
                # Keep an explicitly supplied FK value. Otherwise default to
                # None so that every row carries the same set of keys.
                row.setdefault(col_name, None)
        return row

    def is_duplicate(row):
        """Return True if *row* repeats a unique-constraint value already kept."""
        keys = []
        for const in unique_constraints:
            values = tuple(row[col.name] for col in const.columns)
            # PostgreSQL treats NULLs as distinct in unique constraints
            # (unless declared NULLS NOT DISTINCT), so these never collide.
            if any(value is None for value in values):
                continue
            key = (const, values)
            if key in seen:
                return True
            keys.append(key)
        # Record keys only for rows that are kept, so a dropped row cannot
        # make a later, non-conflicting row look like a duplicate.
        seen.update(keys)
        return False

    prepared = []
    for row in rows:
        row = resolve_foreign_keys(row)
        if not is_duplicate(row):
            prepared.append(row)

    if not prepared:
        # An empty parameter list is not a no-op: SQLAlchemy would execute
        # the statement once with no parameters instead of inserting nothing.
        return
    session.execute(stmt, prepared)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment