Last active
February 27, 2024 12:45
-
-
Save Evgenus/5d9b279c396d414cfcd61814c8417058 to your computer and use it in GitHub Desktop.
Proper SQLAlchemy transactions example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from contextlib import contextmanager | |
import threading | |
from thread import get_ident | |
from sqlalchemy import * | |
from sqlalchemy.orm import * | |
from sqlalchemy.ext.declarative import declarative_base | |
from flask import Flask | |
from flask_sqlalchemy import SQLAlchemy | |
################################################################################ | |
def scopefunc():
    """Return the key used to pick the scoped session behind ``db.session``.

    The key combines the current ``is_inside_txn()`` flag with the calling
    thread's identifier, so the same thread maps to one key while it is
    inside a ``transaction()`` block and to another key outside of it.
    """
    inside_flag = is_inside_txn()
    thread_id = get_ident()
    return "{}.{}".format(inside_flag, thread_id)
# Flask-SQLAlchemy handle. Sessions are scoped with scopefunc() rather than
# the extension's default scope function.
db = SQLAlchemy(session_options={"scopefunc": scopefunc})
# Attribute name under which the per-thread nesting stack is stored.
context_name = "db_or_conection_instance_id"
# Per-thread storage: each thread tracks its own transaction nesting.
txn_context = threading.local()
def push_txn_context():
    """Record that the calling thread entered one more transaction level.

    The per-thread stack is created lazily on first use; each level is
    represented by a ``True`` marker.
    """
    try:
        depth_markers = getattr(txn_context, context_name)
    except AttributeError:
        # First transaction ever opened on this thread.
        depth_markers = []
        setattr(txn_context, context_name, depth_markers)
    depth_markers.append(True)
def pop_txn_context():
    """Drop the innermost transaction-level marker of the calling thread.

    Raises:
        AttributeError: if nothing was ever pushed on this thread.
        IndexError: if the stack is already empty.
    """
    depth_markers = getattr(txn_context, context_name)
    depth_markers.pop()
def is_inside_txn():
    """Return True when the calling thread has at least one open level."""
    depth_markers = getattr(txn_context, context_name, [])
    return bool(depth_markers)
@contextmanager
def transaction():
    """Run the managed body inside a (possibly nested) database transaction.

    The outermost ``transaction()`` on a thread closes ``db.session`` before
    the body runs and again afterwards. An inner ``transaction()`` opens a
    nested transaction with ``db.session.begin(nested=True)`` instead.

    On normal exit the session is committed. If the body (or the commit)
    raises an ``Exception``, the session is rolled back and the exception is
    re-raised. The nesting marker pushed on entry is always popped, even when
    the setup or the final ``close()`` fails, so a failure cannot leave the
    thread permanently flagged as "inside a transaction".
    """
    outer_transaction = not is_inside_txn()
    # Push first: scopefunc() keys the scoped session on is_inside_txn(), so
    # every db.session access below must already see the "inside" state.
    push_txn_context()
    try:
        if outer_transaction:
            db.session.close()
        else:
            db.session.begin(nested=True)
    except BaseException:
        # Setup failed before the body started; undo the marker so the
        # thread-local nesting stack stays balanced.
        pop_txn_context()
        raise
    try:
        yield
        db.session.commit()
    except Exception:
        db.session.rollback()
        raise
    finally:
        try:
            if outer_transaction:
                db.session.close()
        finally:
            # Pop last (db.session must still resolve to the in-transaction
            # scope while closing) and pop even if close() itself raised.
            pop_txn_context()
def transactional(func):
    """Decorate ``func`` so every call runs inside ``transaction()``.

    Args:
        func: the callable to wrap.

    Returns:
        A wrapper that forwards all arguments to ``func`` and returns its
        result. The wrapper keeps ``func``'s name and docstring.
    """
    # Local import so this block does not depend on the file's import list.
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        with transaction():
            return func(*args, **kwargs)
    return wrapper
################################################################################ | |
# Plain declarative base. NOTE(review): the SomeObject model in this script
# derives from db.Model, so this base looks unused here - confirm before
# removing.
Base = declarative_base()
def print_what_is_in_session(message): | |
values = [obj.value for obj in db.session.query(SomeObject).all()] | |
print message, values | |
class SomeObject(db.Model):
    """Minimal model used by the demo; rows live in table ``a``."""

    __tablename__ = 'a'

    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # Label identifying which demo function created the row.
    value = Column(String(128))
@transactional | |
def A(): | |
db.session.add(SomeObject(value="a1")) | |
print "calling function B from the scope of A" | |
try: | |
B() | |
except Exception: | |
pass | |
# here everythin created inside B (but not in A) should be removed from DB | |
print_what_is_in_session("after B rolled-back: ") | |
db.session.add(SomeObject(value="a2")) | |
@transactional
def B():
    """Add "b1", call C(), then add "b2".

    C() raises in this demo, so the "b2" line is not reached.
    """
    before_c = SomeObject(value="b1")
    db.session.add(before_c)
    C()
    after_c = SomeObject(value="b2")
    db.session.add(after_c)
@transactional
def C():
    """Add "c", print the session contents, then fail with ValueError(0)."""
    created = SomeObject(value="c")
    db.session.add(created)
    # At this point every object created so far is visible in the session.
    print_what_is_in_session("before exception: ")
    raise ValueError(0)
################################################################################ | |
def scenario(): | |
# calling transactional function with maybe nested transactions | |
print "calling function A from outer scope" | |
db.session.add(SomeObject(value="i")) | |
db.session.flush() | |
A() | |
print_what_is_in_session("after calling A: ") | |
print "calling function B from outer scope" | |
try: | |
B() | |
except Exception: | |
pass | |
# As long as B was rolled back and catched in A only things created | |
# inside A was stored into DB | |
print_what_is_in_session("at the end: ") | |
if __name__ == "__main__":
    # Demo entry point: build a throwaway Flask app, create the tables, run
    # the scenario and drop the tables again.
    app = Flask("test_app")
    app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql://root:root@localhost/test'
    app.config['SQLALCHEMY_ECHO'] = True
    db.init_app(app)
    with app.app_context():
        db.create_all()
        try:
            scenario()
        finally:
            # Always clean up, even when the scenario fails, so a broken run
            # does not leave the demo table (and stale rows) behind.
            db.session.close()
            db.drop_all()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from contextlib import contextmanager | |
import threading | |
from sqlalchemy import * | |
from sqlalchemy.orm import * | |
from sqlalchemy.ext.declarative import declarative_base | |
################################################################################ | |
# Attribute name under which the per-thread session stack is stored.
context_name = "db_or_conection_instance_id"
# Per-thread storage: each thread tracks its own stack of active sessions.
txn_context = threading.local()
def push_txn_context(session):
    """Push ``session`` as the innermost active session of this thread.

    The per-thread stack is created lazily on first use.
    """
    try:
        session_stack = getattr(txn_context, context_name)
    except AttributeError:
        # First transaction ever opened on this thread.
        session_stack = []
        setattr(txn_context, context_name, session_stack)
    session_stack.append(session)
def pop_txn_context():
    """Remove the innermost entry from this thread's session stack.

    Raises:
        AttributeError: if nothing was ever pushed on this thread.
        IndexError: if the stack is already empty.
    """
    session_stack = getattr(txn_context, context_name)
    session_stack.pop()
def is_inside_txn():
    """Return True when this thread's session stack is non-empty."""
    session_stack = getattr(txn_context, context_name, [])
    return bool(session_stack)
# This could be reduced to some property of a thread-local object
# (like db.session).
def current_session():
    """Return the innermost active session of this thread.

    Returns None when the thread is not inside any ``transaction()``.
    """
    session_stack = getattr(txn_context, context_name, None)
    if not session_stack:
        return None
    return session_stack[-1]
@contextmanager
def transaction():
    """Run the managed body inside a (possibly nested) transaction.

    The outermost call on a thread creates a new ``Session`` bound to the
    module-level ``engine``. An inner call reuses the thread's current
    session and opens a nested transaction on it with ``begin_nested()``.
    The session in use is pushed on the per-thread stack so that
    ``current_session()`` returns it while the body runs.

    On normal exit the session is committed. If the body (or the commit)
    raises an ``Exception``, the session is rolled back and the exception
    is re-raised. The session is closed once the outermost level exits.
    """
    if not is_inside_txn():
        # Outermost level: start with a brand-new session.
        session = Session(engine)
    else:
        session = current_session()
        session.begin_nested()
    push_txn_context(session)
    try:
        yield
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        pop_txn_context()
        still_nested = is_inside_txn()
        if not still_nested:
            # That was the outermost level; release the session.
            session.close()
def transactional(func):
    """Decorate ``func`` so every call runs inside ``transaction()``.

    Args:
        func: the callable to wrap.

    Returns:
        A wrapper that forwards all arguments to ``func`` and returns its
        result. The wrapper keeps ``func``'s name and docstring.
    """
    # Local import so this block does not depend on the file's import list.
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        with transaction():
            return func(*args, **kwargs)
    return wrapper
################################################################################ | |
# Declarative base that SomeObject derives from; its metadata is what the
# script uses to create and drop the tables.
Base = declarative_base()
def print_what_is_in_session(message, session): | |
values = [obj.value for obj in session.query(SomeObject).all()] | |
print message, values | |
class SomeObject(Base):
    """Minimal model used by the demo; rows live in table ``a``."""

    __tablename__ = 'a'

    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # Label identifying which demo function created the row.
    value = Column(String(128))
@transactional
def A():
    """Add "a1", call B() ignoring its failure, then add "a2"."""
    txn_session = current_session()
    first = SomeObject(value="a1")
    txn_session.add(first)
    try:
        B()
    except Exception:
        # Deliberately ignored: the demo shows that B's failure only rolls
        # back B's own nested transaction.
        pass
    # Here everything created inside B (but not in A) should be removed
    # from the DB.
    print_what_is_in_session("after B rolled-back: ", txn_session)
    second = SomeObject(value="a2")
    txn_session.add(second)
@transactional
def B():
    """Add "b1", call C(), then add "b2".

    C() raises in this demo, so the "b2" line is not reached.
    """
    txn_session = current_session()
    before_c = SomeObject(value="b1")
    txn_session.add(before_c)
    C()
    after_c = SomeObject(value="b2")
    txn_session.add(after_c)
@transactional
def C():
    """Add "c", print the session contents, then fail with ValueError(0)."""
    txn_session = current_session()
    created = SomeObject(value="c")
    txn_session.add(created)
    # At this point every object created so far is visible in the session.
    print_what_is_in_session("before exception: ", txn_session)
    raise ValueError(0)
################################################################################ | |
if __name__ == "__main__": | |
engine = create_engine('mysql://root:root@localhost/test', echo=True) | |
Base.metadata.create_all(engine) | |
connection = engine.connect() | |
# Existing external session created by something (like Flask-SQLAlchemy) | |
session = Session(engine) | |
# calling transactional function with maybe nested transactions | |
print "calling function A from outer scope" | |
A() | |
print_what_is_in_session("after calling A: ", session) | |
print "calling function B from outer scope" | |
try: | |
B() | |
except Exception: | |
pass | |
# As long as B was rolled back and catched in A only things created | |
# inside A was stored into DB | |
print_what_is_in_session("at the end: ", session) | |
session.close() | |
Base.metadata.drop_all(engine) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment