Forked from gordthompson/postgresql_df_upsert.py
Last active
January 11, 2024 13:22
-
-
Save jackiep00/50eced6d1b63ac37e6841e1307007ef8 to your computer and use it in GitHub Desktop.
Build a PostgreSQL INSERT … ON CONFLICT statement and upsert a DataFrame
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# version 1.2 - 2022-10-04
import pandas as pd | |
import sqlalchemy as sa | |
import string, random | |
def _quote_identifier(name):
    """Return *name* as a double-quoted PostgreSQL identifier.

    Embedded double quotes are doubled, which is how PostgreSQL escapes them
    inside a quoted identifier.
    """
    return '"' + str(name).replace('"', '""') + '"'


def df_upsert(data_frame, table_name, engine, schema=None, match_columns=None):
    """
    Perform an "upsert" on a PostgreSQL table from a DataFrame.

    Constructs an INSERT … ON CONFLICT statement, uploads the DataFrame to a
    temporary table, and then executes the INSERT. All of this happens inside
    a single transaction opened with ``engine.begin()``.

    Parameters
    ----------
    data_frame : pandas.DataFrame
        The DataFrame to be upserted.
    table_name : str
        The name of the target table.
    engine : sqlalchemy.engine.Engine
        The SQLAlchemy Engine to use.
    schema : str, optional
        The name of the schema containing the target table.
    match_columns : list of str, optional
        A list of the column name(s) on which to match. If omitted, the
        primary key columns of the target table will be used. A single
        column name given as a plain string is treated as a one-item list.

    Raises
    ------
    ValueError
        If no match columns were supplied and the target table reports no
        primary key columns, so no ON CONFLICT target can be built.
    """
    table_spec = _quote_identifier(table_name)
    if schema:
        table_spec = _quote_identifier(schema) + "." + table_spec

    df_columns = list(data_frame.columns)

    # A bare string would otherwise be iterated character by character.
    if isinstance(match_columns, str):
        match_columns = [match_columns]
    if not match_columns:
        insp = sa.inspect(engine)
        match_columns = insp.get_pk_constraint(table_name, schema=schema)[
            "constrained_columns"
        ]
    if not match_columns:
        raise ValueError(
            f"no match columns given and table {table_spec} has no primary key"
        )
    columns_to_update = [col for col in df_columns if col not in match_columns]

    # Random suffix so that calls sharing a pooled connection do not collide.
    temp_table_name = "temp_table_" + "".join(
        random.choice(string.ascii_lowercase) for _ in range(10)
    )

    insert_col_list = ", ".join(_quote_identifier(col) for col in df_columns)
    match_col_list = ", ".join(_quote_identifier(col) for col in match_columns)

    stmt = f"INSERT INTO {table_spec} ({insert_col_list})\n"
    stmt += f"SELECT {insert_col_list} FROM {temp_table_name}\n"
    if columns_to_update:
        stmt += f"ON CONFLICT ({match_col_list}) DO UPDATE SET\n"
        stmt += ", ".join(
            f"{_quote_identifier(col)} = EXCLUDED.{_quote_identifier(col)}"
            for col in columns_to_update
        )
    else:
        # Every DataFrame column is a match column, so there is nothing to
        # update; an empty "DO UPDATE SET" would be a syntax error.
        stmt += f"ON CONFLICT ({match_col_list}) DO NOTHING"

    with engine.begin() as conn:
        conn.exec_driver_sql(f"DROP TABLE IF EXISTS {temp_table_name}")
        # "WHERE false" copies the target's column layout without any rows.
        conn.exec_driver_sql(
            f"CREATE TEMPORARY TABLE {temp_table_name} AS SELECT * FROM {table_spec} WHERE false"
        )
        data_frame.to_sql(temp_table_name, conn, if_exists="append", index=False)
        conn.exec_driver_sql(stmt)
        # Drop the staging table now rather than leaving it on the connection
        # until the session ends.
        conn.exec_driver_sql(f"DROP TABLE IF EXISTS {temp_table_name}")
if __name__ == "__main__":
    # Demonstration of df_upsert(), adapted from
    # https://stackoverflow.com/a/62379384/2144390
    engine = sa.create_engine("postgresql://scott:tiger@192.168.0.199/test")

    # Build the example environment: a fresh main_table holding one row.
    setup_statements = (
        "DROP TABLE IF EXISTS main_table",
        "CREATE TABLE main_table (id int primary key, txt varchar(50), status varchar(50))",
        "INSERT INTO main_table (id, txt, status) VALUES (1, 'row 1 old text', 'original')",
    )
    with engine.begin() as connection:
        for sql in setup_statements:
            connection.exec_driver_sql(sql)
    # main_table now contains:
    # [(1, 'row 1 old text', 'original')]

    # The rows to upsert: id 2 is new, id 1 already exists in main_table.
    df = pd.DataFrame(
        [(2, "new row 2 text", "upserted"), (1, "row 1 new text", "upserted")],
        columns=["id", "txt", "status"],
    )
    df_upsert(df, "main_table", engine)

    """The INSERT statement generated for this example:
    INSERT INTO "main_table" ("id", "txt", "status")
    SELECT "id", "txt", "status" FROM temp_table
    ON CONFLICT ("id") DO UPDATE SET
    "txt" = EXCLUDED."txt", "status" = EXCLUDED."status"
    """

    # Read the table back and print it to show the outcome.
    with engine.begin() as connection:
        rows = connection.exec_driver_sql("SELECT * FROM main_table").all()
        print(rows)
    # [(2, 'new row 2 text', 'upserted'), (1, 'row 1 new text', 'upserted')]
Thanks for the script! When I tried to use it, line #44 referenced an undefined variable, insert_col_list.
In the original script, it was defined as:
insert_col_list = ", ".join([f'"{col_name}"' for col_name in df_columns])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Awesome!
I think data_frame.to_sql can be faster if you pass method="multi"; it might be worth using or documenting.