Last active
September 22, 2020 12:55
-
-
Save osmanmesutozcan/0fcc0ab0de6d9445f39e863c668593cb to your computer and use it in GitHub Desktop.
Django ORM manager for (naive) bulk upserting
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.db import models | |
class BulkUpsertManager(models.Manager):
    """Model manager that adds a naive multi-row upsert (``bulk_upsert``).

    The generated statement is ``INSERT ... ON CONFLICT (...) DO UPDATE SET
    col = EXCLUDED.col``, i.e. PostgreSQL syntax (SQLite 3.24+ accepts the
    same form). Rows are written with raw SQL: ``Model.save()`` is not
    called, no signals are sent and ``pre_save`` hooks such as ``auto_now``
    are not applied.
    """

    @staticmethod
    def _column_fields(model):
        """Return the fields of ``model`` that own a column in its table."""
        # ``concrete`` drops reverse relations and generic relations, which have
        # no column; ``include_parents=False`` drops multi-table-inheritance
        # parent fields, whose columns live in the parent's table.
        return [
            field
            for field in model._meta.get_fields(include_parents=False)
            if field.concrete and not field.many_to_many
        ]

    @staticmethod
    def _row_params(obj, fields, connection):
        """Return the database-ready values of ``obj`` for ``fields``, in order."""
        # ``attname`` reads the raw foreign-key id (e.g. ``author_id``) instead
        # of loading the related object, and ``get_db_prep_save`` applies
        # Django's own per-field conversion, so the driver binds real typed
        # values (``None`` -> NULL, datetimes, JSON, ...) rather than strings.
        return [
            field.get_db_prep_save(getattr(obj, field.attname), connection)
            for field in fields
        ]

    def bulk_upsert(self, values, update_on_conflict=None, *, conflict_fields=None):
        """Insert ``values`` with one statement, updating rows that already exist.

        :param values: iterable of instances of this manager's model.
        :param update_on_conflict: names of model fields whose columns are
            overwritten with the incoming values when a row conflicts.
        :param conflict_fields: names of model fields forming the
            ``ON CONFLICT`` target; they must be covered by a unique index or
            constraint. Defaults to the primary key, in which case every
            instance needs a primary-key value.
        :returns: ``[]`` without touching the database when
            ``update_on_conflict`` is empty; ``0`` when ``values`` is empty;
            otherwise the cursor's ``rowcount`` (inserted plus updated rows).
        :raises django.core.exceptions.FieldDoesNotExist: if a name in
            ``update_on_conflict`` or ``conflict_fields`` is not a model field.

        PostgreSQL rejects a statement that would update the same row twice,
        so ``values`` should not contain two objects with the same conflict key.
        """
        from django.db import connections, router

        # Preserved from the original contract: with no columns to update the
        # call is a no-op that returns an empty list.
        if not update_on_conflict:
            return []
        # Materialize so generators work, and so an empty batch never produces
        # an invalid ``VALUES`` clause with no rows.
        values = list(values)
        if not values:
            return 0

        opts = self.model._meta
        # Honour ``db_manager(using)`` and database routers instead of always
        # writing to the default connection.
        connection = connections[self._db or router.db_for_write(self.model)]
        quote = connection.ops.quote_name

        fields = self._column_fields(self.model)
        if conflict_fields:
            target = [opts.get_field(name) for name in conflict_fields]
        else:
            target = [field for field in fields if field.primary_key]
        # ``get_field`` validates caller-supplied names (they end up in the SQL
        # text) and maps them to their real database columns.
        updated = [opts.get_field(name) for name in update_on_conflict]

        columns = ", ".join(quote(field.column) for field in fields)
        row_placeholder = "(" + ", ".join(["%s"] * len(fields)) + ")"
        target_columns = ", ".join(quote(field.column) for field in target)
        assignments = ", ".join(
            f"{quote(field.column)} = EXCLUDED.{quote(field.column)}"
            for field in updated
        )
        sql = (
            f"INSERT INTO {quote(opts.db_table)} ({columns}) "
            f"VALUES {', '.join([row_placeholder] * len(values))} "
            f"ON CONFLICT ({target_columns}) DO UPDATE SET {assignments}"
        )
        # One flat parameter list, row by row, matching the placeholder order.
        params = [
            param
            for obj in values
            for param in self._row_params(obj, fields, connection)
        ]

        with connection.cursor() as cursor:
            cursor.execute(sql, params)
            return cursor.rowcount
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment