Last active
January 17, 2022 14:44
-
-
Save laroo/3eaf054c2b3c4daa6de494531c2b0585 to your computer and use it in GitHub Desktop.
Python/Flask: BulkModelSave & PaginatedQuery: Context managers that will speedup insert/update & iterating a lot of rows
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import, division, print_function, unicode_literals | |
from builtins import object # Python3 Compatibility | |
import logging | |
import math | |
from typing import Text, Dict, Optional | |
from datetime import datetime | |
from decimal import Decimal | |
import sqlalchemy | |
from sqlalchemy import and_, func | |
from tiqets.extensions import db, db_bind | |
from tiqets.model.sa360 import SA360Account, SA360Campaign, SA360AdGroup, SA360Ad | |
from tiqets.model.channels import SalesChannel | |
from tiqets.model.product import Product, ProductVariant | |
from tiqets.pricing.variant import get_variant_price_indication | |
log = logging.getLogger(__name__)


class BulkModelSave(object):
    """
    Bulk saves SQLAlchemy objects in a handy context manager.

    Because separate INSERTs/UPDATEs are slow, objects passed to ``save()`` are buffered
    and written with a single ``db.session.bulk_save_objects()`` call (followed by a commit)
    whenever ``auto_flush`` objects are buffered, and once more for the remainder when the
    ``with`` block exits.

    Usage:
        with BulkModelSave(auto_flush=500) as bulk_model_save:
            for my_model in MyModel.query.filter():
                my_model.field = new_value
                bulk_model_save.save(my_model)
    """

    def __init__(self, auto_flush=1000):
        # type: (int) -> None
        """
        :param auto_flush: number of buffered objects that triggers an automatic flush.
        """
        self._auto_flush = auto_flush
        self._db_objects = []  # type: list

    def __enter__(self):
        # Start every ``with`` block with an empty buffer; anything buffered before
        # entering is discarded.
        self._db_objects = []
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Write out the remaining (partial) batch. Returning None means an exception
        # raised inside the ``with`` block still propagates to the caller.
        self.flush()

    @property
    def num_objects(self):
        # type: () -> int
        """Number of objects currently buffered and not yet flushed."""
        return len(self._db_objects)

    def save(self, db_object):
        """Buffer ``db_object``; flush automatically once ``auto_flush`` objects are buffered."""
        self._db_objects.append(db_object)
        if self.num_objects >= self._auto_flush:
            self.flush()

    def flush(self):
        """
        Bulk save and commit all buffered objects on the 'master' bind, then empty the buffer.

        Does nothing when the buffer is empty. If saving or committing raises, the session
        is rolled back (within the same ``db_bind('master')`` context) and the error is
        re-raised. The failed batch is dropped from the buffer. Otherwise ``__exit__`` would
        retry the same batch on a session that still needs a rollback, raising a second
        error that hides the original one.
        """
        if not self._db_objects:
            return
        # Detach the batch first so a failure cannot leave it queued for another attempt.
        pending, self._db_objects = self._db_objects, []
        with db_bind('master'):
            try:
                db.session.bulk_save_objects(pending)
                db.session.commit()
            except Exception:
                # Leave the session usable for the caller, then surface the original error.
                db.session.rollback()
                raise
class PaginatedQuery(object):
    """
    Breaks a Query into multiple pages based on a given unique/PK column.

    Why: looping over a lot of records is slow, even when using the SQLAlchemy iterator.
    A simple solution would be Flask pagination (or your own) using OFFSET & LIMIT. However,
    OFFSET becomes increasingly slower, because the DB has to scan through all the previous
    rows each time to get to the requested row.

    This class takes a different approach. One query determines the PK value at which every
    page starts. Each page is then fetched with a plain range filter on the PK column.

    For more info (and inspiration) see: https://github.com/sqlalchemy/sqlalchemy/wiki/WindowedRangeQuery

    Using it as a context manager is optional: ``__enter__`` returns the instance and
    ``__exit__`` does nothing.

    Examples:
        Loop per page and loop over each record:
        ```
        with PaginatedQuery(SA360AdGroup.query.filter(), SA360AdGroup.id, 1000) as paginated_query:
            progress_bar = Bar('Writing ad groups', max=paginated_query.num_pages)
            for page_queryset in paginated_query.get_query_per_page():
                next(progress_bar)
                for ad_group in page_queryset:
                    apply_some_action(ad_group)
        ```

        Directly loop over each record (pagination used internally):
        ```
        with PaginatedQuery(SA360AdGroup.query.filter(), SA360AdGroup.id, 1000) as paginated_query:
            progress_bar = Bar('Writing ad groups', max=paginated_query.num_records)
            for ad_group in paginated_query.get_results():
                next(progress_bar)
                apply_some_action(ad_group)
        ```
    """

    def __init__(self, query, pk_column, items_per_page=1000):
        # type: (sqlalchemy.orm.query.Query, sqlalchemy.Column, int) -> None
        """
        :param query: the query to paginate; its filters are kept for every page.
        :param pk_column: unique (e.g. primary key) column to page on; pages are ordered by it.
        :param items_per_page: number of rows per page (the last page may hold fewer).
        """
        self._query = query
        self._pk_column = pk_column
        self._items_per_page = items_per_page
        # Lazily computed and cached by the properties below.
        self._num_records = None  # type: Optional[int]
        self._num_pages = None  # type: Optional[int]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

    @property
    def items_per_page(self):
        # type: () -> int
        """Number of rows per page, as passed to the constructor."""
        return self._items_per_page

    @property
    def num_records(self):
        # type: () -> int
        """Total number of rows of the query (one ``COUNT`` query, cached after the first call)."""
        if self._num_records is None:
            self._num_records = self._query.count()
        return self._num_records

    @property
    def num_pages(self):
        # type: () -> int
        """Number of pages, i.e. ceil(num_records / items_per_page); cached after the first call."""
        if self._num_pages is None:
            # Exact integer ceiling division; float division + math.ceil loses precision
            # for very large counts.
            self._num_pages = -(-self.num_records // self.items_per_page)
        return self._num_pages

    def get_where_statement_per_page(self):
        """
        This is where the magic happens! Yield a series of WHERE clauses against the
        unique/PK column that break the query into pages.

        Parameters (example):
            * self._query: SA360AdGroup.query.filter(SA360AdGroup.status == 'Active')
            * self._pk_column: SA360AdGroup.id
            * self._items_per_page: 1000

        Will result in this query:
        ```
        SELECT
            anon_1.marketing_sa360_adgroup_id AS current_start_id,
            lead(anon_1.marketing_sa360_adgroup_id) OVER (ORDER BY anon_1.marketing_sa360_adgroup_id) AS next_start_id
        FROM (
            SELECT
                marketing.sa360_adgroup.id AS marketing_sa360_adgroup_id,
                row_number() OVER (ORDER BY marketing.sa360_adgroup.id) AS row_num
            FROM marketing.sa360_adgroup
            WHERE marketing.sa360_adgroup.status = 'Active'
        ) AS anon_1
        WHERE row_num % 1000 = 1
        ORDER BY anon_1.marketing_sa360_adgroup_id
        ```

        Each result row is one page, with its first PK value and the first PK value of the
        next page. The outer WHERE keeps only the first row of every page, and ``lead()``
        is evaluated over those remaining rows:
        ```
        current_start_id  next_start_id
        974               72088
        72088             87621
        87621             149834
        149834            170988
        170988            228123
        228123            234235
        234235            248975
        [...]
        718407            720956
        720956            [NULL]
        ```

        Each row is turned into a clause usable as ``self._query.filter(where_clause)``.
        For the first row of the example this is equivalent to:
        ```
        SA360AdGroup.query.filter(SA360AdGroup.status == 'Active', SA360AdGroup.id >= 974, SA360AdGroup.id < 72088)
        ```
        The last page has no ``next_start_id`` and only gets a lower bound.

        When ``items_per_page`` is 1 or less, no row filter is applied, so every row
        becomes its own page.

        NOTE: ``Query.from_self()`` is deprecated as of SQLAlchemy 1.4 and removed in 2.0.
        """
        q = (
            self._query.with_entities(
                self._pk_column, func.row_number().over(order_by=self._pk_column).label('row_num')
            )
            .from_self(
                self._pk_column.label('current_start_id'),
                func.lead(self._pk_column).over(order_by=self._pk_column).label('next_start_id'),
            )
            .order_by(self._pk_column)
        )
        if self.items_per_page > 1:
            # '%%' survives the Python formatting as a literal SQL modulo operator.
            q = q.filter(sqlalchemy.text("row_num %% %d = 1" % self.items_per_page))

        intervals = q.all()
        for interval in intervals:
            # Compare against None explicitly: only the last page has a NULL next_start_id,
            # and a falsy PK value (e.g. 0) is still a valid upper bound.
            if interval.next_start_id is not None:
                yield and_(self._pk_column >= interval.current_start_id, self._pk_column < interval.next_start_id)
            else:
                yield and_(self._pk_column >= interval.current_start_id)

    def get_query_per_page(self):
        """Yield one query per page: the original query restricted to the page and ordered by the PK column."""
        for where_clause in self.get_where_statement_per_page():
            yield self._query.filter(where_clause).order_by(self._pk_column)

    def get_results(self):
        """Yield every row of the query, fetching it page by page."""
        for page_queryset in self.get_query_per_page():
            for row in page_queryset:
                yield row
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment