Last active
September 11, 2016 03:04
-
-
Save markgajdosik/a04188e3515987bf5fb45d73d132cf60 to your computer and use it in GitHub Desktop.
Update-returning support for Django 1.8. Based on: https://github.com/kanu/django-update-returning
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Implements a manager with support for the PostgreSQL "UPDATE ... RETURNING"
clause. The columns listed after "RETURNING" may be chosen with the
`values_list()` method; if none are specified, the update statement returns
full records.
Queries executed through `UpdateReturningManager` or
`UpdateReturningDefaultManager` return a generator for easy iteration over
the returned entries.
Based on https://github.com/kanu/django-update-returning, with modifications
to provide Django 1.8 support.
""" | |
from django.db import connections, transaction | |
from django.db.models.manager import Manager | |
from django.db.models.query import QuerySet, ValuesListQuerySet, ValuesQuerySet | |
from django.db.models.query_utils import deferred_class_factory | |
from django.db.models.sql import UpdateQuery | |
from django.db.models.sql.compiler import SQLUpdateCompiler | |
from django.db.models.sql.constants import MULTI | |
class SQLUpdateReturningCompiler(SQLUpdateCompiler):
    """
    Based on built-in `SQLUpdateCompiler`. Adds a "RETURNING" clause to the
    standard SQL "UPDATE" statement. Columns in the "RETURNING" clause are
    controlled via `values_list()` (and the other queryset methods that
    shape the selected columns).
    """
    # Marked "(required)" upstream: the inherited `execute_sql` reads
    # `self.col_count` when iterating MULTI results; None keeps rows whole.
    col_count = None

    def as_sql(self):
        """
        Build the "UPDATE ... RETURNING ..." statement.

        :return (tuple): ``(sql, params)``. If the base update compiler yields
            no SQL (nothing to update), the empty SQL is returned unchanged
            instead of a bare " RETURNING ..." fragment, so the empty-query
            handling in `execute_sql` still applies.
        """
        # Get SQL and parameters for a typical "UPDATE" statement.
        sql, params = super(SQLUpdateReturningCompiler, self).as_sql()
        sql = sql.rstrip()
        if not sql:
            # Appending "RETURNING" here would turn "no query" into
            # invalid SQL that gets sent to the database.
            return sql, params
        # Add the "RETURNING" clause.
        returning = ', '.join(self.get_returning_columns())
        return '%s RETURNING %s' % (sql, returning), params

    def get_returning_columns(self):
        """
        :return (list): Compiled SQL of each selected column (as chosen via
            `values_list()` etc.), used as the "RETURNING" column list.
        """
        # NOTE(review): assumes Django 1.8's `get_select()` shape, where
        # element [0] is a list of (expression, (sql, params), alias).
        return [c[1][0] for c in self.get_select()[0]]

    def execute_sql(self, result_type):
        """
        Execute the statement and return the raw cursor results.

        `super(SQLUpdateCompiler, self)` deliberately starts the MRO lookup
        *after* `SQLUpdateCompiler`, skipping its `execute_sql` (which
        reports an affected-row count) in favour of the generic
        `SQLCompiler.execute_sql`, which returns the fetched rows.
        """
        return super(SQLUpdateCompiler, self).execute_sql(result_type)
class UpdateReturningQuery(UpdateQuery):
    """ "UPDATE" query compiled with PostgreSQL "RETURNING" support. """
    compiler_class = SQLUpdateReturningCompiler

    def get_compiler(self, using=None, connection=None):
        """
        Instantiate `compiler_class` for this query.

        At least one of ``using`` (a database alias) or ``connection`` must be
        given. A truthy alias wins and is resolved through ``connections``;
        otherwise the supplied ``connection`` is used as-is.
        """
        missing_both = using is None and connection is None
        if missing_both:
            raise ValueError('Need either using or connection')
        db_connection = connections[using] if using else connection
        return self.compiler_class(self, db_connection, using)
class UpdateReturningMethods(object):
    """
    Mixin that extends querysets with methods returning the rows changed by
    an SQL "UPDATE".
    """
    # Class-level default; `update_returning` sets it to True on the instance.
    _for_write = None

    def _clone(self, klass=None, setup=False, **kwargs):
        """ Changing a given class to the matching "UpdateReturning" one. """
        # Keyed by class name so chained calls such as `values_list()` keep
        # returning querysets that still have the update-returning methods.
        overwrites = {'QuerySet': UpdateReturningQuerySet,
                      'ValuesQuerySet': UpdateReturningValuesQuerySet,
                      'ValuesListQuerySet': UpdateReturningValuesListQuerySet}
        if klass and klass.__name__ in overwrites:
            klass = overwrites[klass.__name__]
        return super(UpdateReturningMethods, self)._clone(
            klass, setup, **kwargs)

    def update_returning(self, **kwargs):
        """
        Run an "UPDATE ... RETURNING" and iterate over the updated rows.

        This is a generator: the UPDATE is only executed when iteration
        starts. Use `update_returning_list()` to run it immediately.

        The type of each yielded item is determined by the preceding queryset
        methods, just as for normal querysets. The methods that change the
        item type are "only", "defer", "values_list" and "values"; if none of
        these is used, the items are full model instances. For example::

            Model.objects.values_list('id', flat=True).update_returning(
                published=True)

        yields the ids of the changed objects.

        The statement runs inside ``transaction.atomic(savepoint=False)``, so
        it joins an enclosing atomic block if there is one. All returned rows
        are fetched before that block exits, so once iteration has started,
        abandoning the generator cannot roll back the update.

        :param kwargs: Field/value pairs to set, as for `QuerySet.update()`.
        :raises AssertionError: If the queryset has been sliced. The UPDATE
            would otherwise ignore the slice and change every matching row.
        """
        assert self.query.can_filter(), \
            "Cannot update a query once a slice has been taken."
        self._for_write = True
        query = self.query.clone(UpdateReturningQuery)
        query.add_update_values(kwargs)
        # `transaction.commit()` (used previously) raises inside an atomic
        # block; `atomic(savepoint=False)` works both inside and outside one.
        with transaction.atomic(using=self.db, savepoint=False):
            cursor = query.get_compiler(self.db).execute_sql(MULTI)
            # MULTI results arrive in chunks (lists of rows). Drain them
            # while the transaction is still open.
            rows = [row for chunk in cursor for row in chunk]
        self._result_cache = None
        result_factory = self._returning_update_result_factory()
        for row in rows:
            yield result_factory(row)

    def update_returning_list(self, **kwargs):
        """ Eager variant of `update_returning()`: returns a list. """
        return list(self.update_returning(**kwargs))

    def _returning_update_result_factory(self):
        """ Default row mapper: rows are yielded unchanged. """
        return lambda x: x
class UpdateReturningQuerySet(UpdateReturningMethods, QuerySet):
    """ Queryset whose `update_returning()` yields model instances. """

    def _returning_update_result_factory(self):
        """
        Return a mapper converting each returned row into a model instance.

        If "only()" or "defer()" restricted the loaded fields, rows are
        mapped by attribute name onto a deferred model class. Otherwise the
        full model is instantiated positionally from the row.

        :return (callable): ``mapper(row)`` returning an instance whose
            ``_state.db`` is ``self.db`` and ``_state.adding`` is False.
        """
        only_load = self.query.get_loaded_field_names()
        # Only concrete fields are backed by a column, so only they can be
        # part of a returned row. Using `fields` here could put column-less
        # attnames into `init_list` and misalign the zip below.
        fields = self.model._meta.concrete_fields
        load_fields = []
        if only_load:
            for field in fields:
                # NOTE(review): assumes `only_load` is keyed by the concrete
                # model owning each field (self.model, or a parent for
                # inherited fields), matching the deprecated
                # `get_fields_with_model()` lookup this replaces. Confirm.
                owner = field.model._meta.concrete_model
                try:
                    if field.name in only_load[owner]:
                        # Add a field that has been explicitly included.
                        load_fields.append(field.name)
                except KeyError:
                    # Model wasn't explicitly listed in the only_load table,
                    # so all fields from this model are loaded.
                    load_fields.append(field.name)

        skip = None
        init_list = []
        model_cls = self.model
        if load_fields:
            skip = set()
            for field in fields:
                if field.name in load_fields:
                    init_list.append(field.attname)
                else:
                    skip.add(field.attname)
            model_cls = deferred_class_factory(self.model, skip)

        assert self._for_write, "_for_write must be True"
        db = self.db
        if skip:
            def factory(row):
                return model_cls(**dict(zip(init_list, row)))
        else:
            model = self.model

            def factory(row):
                return model(*row)

        def mapper(row):
            obj = factory(row)
            # Mark the instance as loaded from the database it came from.
            obj._state.db = db
            obj._state.adding = False
            return obj
        return mapper
class UpdateReturningValuesQuerySet(UpdateReturningMethods, ValuesQuerySet):
    """ `values()` queryset whose `update_returning()` yields dicts. """

    def _returning_update_result_factory(self):
        """ Return a mapper building a ``{field_name: value}`` dict per row. """
        names = self.field_names

        def to_dict(row):
            return dict(zip(names, row))

        return to_dict
class UpdateReturningValuesListQuerySet(
        UpdateReturningMethods, ValuesListQuerySet):
    """
    `values_list()` queryset whose `update_returning()` yields tuples or,
    with ``flat=True`` and exactly one field, the bare column values.
    """
    flat = None

    def _returning_update_result_factory(self):
        """ Return ``tuple`` or a first-column extractor, per `flat`. """
        single_flat = self.flat and len(self._fields) == 1
        if not single_flat:
            return tuple
        return lambda row: row[0]
class UpdateReturningManager(Manager):
    """ A manager that uses the `UpdateReturningQuerySet`. """

    def get_queryset(self):
        """
        Return an `UpdateReturningQuerySet` for this manager's model.

        The manager's database alias and router hints are forwarded, as
        Django's own `BaseManager.get_queryset` does. Dropping the hints
        would ignore any hints given via `db_manager(hints=...)`.
        """
        return UpdateReturningQuerySet(
            self.model, using=self._db, hints=self._hints)

    def update_returning(self, **kwargs):
        """ Shortcut for ``get_queryset().update_returning(**kwargs)``. """
        return self.get_queryset().update_returning(**kwargs)

    def update_returning_list(self, **kwargs):
        """ Shortcut for ``get_queryset().update_returning_list(**kwargs)``. """
        return self.get_queryset().update_returning_list(**kwargs)
class UpdateReturningDefaultManager(UpdateReturningManager):
    """
    Variant of `UpdateReturningManager` that is also used when accessing
    related objects, so `UpdateReturningQuerySet` applies there as well.
    """
    # Django manager flag: opt in to use for related-object access.
    use_for_related_fields = True
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment