-
-
Save zzzeek/8443477 to your computer and use it in GitHub Desktop.
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base

# Declarative base class that the mapped model in this script derives from.
Base = declarative_base()
# a model
class Thing(Base):
    """Minimal mapped model: a ``thing`` table with an integer primary key."""

    __tablename__ = 'thing'

    id = Column(Integer, primary_key=True)
# a database with a schema
engine = create_engine("postgresql://scott:tiger@localhost/test", echo=True)
# Drop any existing tables first, then create them again from the metadata.
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)

from unittest import TestCase
import unittest

from sqlalchemy.orm import Session
from sqlalchemy import event
class MyTests(TestCase):
    """Run every test inside an external transaction that is rolled back.

    ``setUp`` opens a connection, begins a transaction on it and binds a
    Session to that connection; ``tearDown`` rolls the transaction back.
    The Session works inside a SAVEPOINT that is reopened each time it ends,
    and the tests exercise ``commit()`` and ``rollback()`` on that Session.
    """

    def setUp(self):
        """Open the connection, outer transaction, Session and SAVEPOINT."""
        # same setup from the docs
        self.conn = engine.connect()
        self.trans = self.conn.begin()
        self.session = Session(bind=self.conn)

        # load fixture data within the scope of the transaction
        self._fixture()

        # start the session in a SAVEPOINT...
        self.session.begin_nested()

        # then each time that SAVEPOINT ends, reopen it
        @event.listens_for(self.session, "after_transaction_end")
        def restart_savepoint(session, transaction):
            # NOTE: relies on the private ``_parent`` attribute of the
            # session transaction; only a nested transaction whose parent is
            # not itself nested triggers a new SAVEPOINT.
            if transaction.nested and not transaction._parent.nested:
                session.begin_nested()

    def tearDown(self):
        """Close the Session, roll back the outer transaction, disconnect."""
        # same teardown from the docs
        self.session.close()
        self.trans.rollback()
        self.conn.close()

    def _fixture(self):
        """Add three ``Thing`` rows and commit them through the Session."""
        self.session.add_all([
            Thing(), Thing(), Thing()
        ])
        self.session.commit()

    def _assert_thing_count(self, expected):
        """Query all ``Thing`` rows and assert that ``expected`` come back."""
        rows = self.session.query(Thing).all()
        self.assertEqual(len(rows), expected)

    def test_thing_one(self):
        # run zero rollbacks
        self._test_thing(0)

    def test_thing_two(self):
        # run two extra rollbacks
        self._test_thing(2)

    def test_thing_five(self):
        # run five extra rollbacks
        self._test_thing(5)

    def _test_thing(self, extra_rollback=0):
        """Check row counts across a mix of adds, commits and rollbacks.

        :param extra_rollback: number of add-then-rollback rounds to run
            both before and after the commit in the middle of the scenario.
        """
        session = self.session

        self._assert_thing_count(3)

        for _ in range(extra_rollback):
            # run N number of rollbacks
            session.add_all([Thing(), Thing(), Thing()])
            self._assert_thing_count(6)
            session.rollback()

        # after rollbacks, still @ 3 rows
        self._assert_thing_count(3)

        session.add_all([Thing(), Thing()])
        session.commit()
        self._assert_thing_count(5)

        session.add(Thing())
        self._assert_thing_count(6)

        for elem in range(extra_rollback):
            # run N number of rollbacks
            session.add_all([Thing(), Thing(), Thing()])
            if elem > 0:
                # b.c. we rolled back that other "thing" too
                self._assert_thing_count(8)
            else:
                self._assert_thing_count(9)
            session.rollback()

        if extra_rollback:
            self._assert_thing_count(5)
        else:
            self._assert_thing_count(6)
# Run the test cases when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
@zzzeek how would you do this with the 2.0 API, given that connections are now begun implicitly?
The official version of this recipe is in the docs at https://docs.sqlalchemy.org/en/14/orm/session_transaction.html#joining-a-session-into-an-external-transaction-such-as-for-test-suites
Connections begin implicitly only if they aren't manually begun first; both options are available. The 2.0 docs have a breakdown of the new calling styles: https://docs.sqlalchemy.org/en/20/core/connections.html#using-transactions
@zzzeek thanks for clarifying!
The issue is, I have my sessionmaker in a database.py file, like this:
# database.py
from sqlalchemy import create_engine, MetaData
from sqlalchemy.orm import (
declarative_base,
scoped_session,
sessionmaker
)
from app.config import DEBUG, DATABASE_URL
# Name templates for indexes and constraints, keyed by constraint type.
_NAMING_CONVENTION = {
    "ix": "ix_%(column_0_label)s",
    "uq": "uq_%(table_name)s_%(column_0_name)s",
    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    "pk": "pk_%(table_name)s",
}

metadata = MetaData(naming_convention=_NAMING_CONVENTION)
Base = declarative_base(metadata=metadata)

engine = create_engine(url=DATABASE_URL, future=True, echo=DEBUG)

# Scoped session registry built on a sessionmaker bound to the engine.
_session_factory = sessionmaker(bind=engine)
db_session = scoped_session(_session_factory)
And I have setup my pytest fixtures like this:
@fixture(scope="session")
def db_engine() -> Iterator[Engine]:
    """
    Creates the schema, yields the engine, then drops the schema.

    Before yielding, all tables in ``Base.metadata`` are created and
    ``stamp`` is called with revision ``"head"``. After the yield the
    tables are dropped and ``stamp`` is called with ``revision=None`` and
    ``purge=True``.

    :return: The database engine.
    """
    alembic_config = Config("alembic.ini")

    # Set-up: create the tables and stamp the revision.
    Base.metadata.create_all(bind=engine)
    stamp(alembic_config, revision="head")

    yield engine

    # Tear-down: drop the tables and purge the stamp.
    Base.metadata.drop_all(bind=engine)
    stamp(alembic_config, revision=None, purge=True)
@fixture(scope="session")
def db_connection(db_engine: Engine) -> Iterator[Connection]:
    """
    Opens one connection on the test engine and closes it after the yield.

    :return: The database connection.
    """
    conn = db_engine.connect()

    yield conn

    conn.close()
@fixture(autouse=True)
def db_transaction(db_connection: Connection) -> Iterator[Session]:
    """
    Sets up a database transaction for each test case.

    Begins a transaction on the given connection, binds the scoped session
    to that connection and yields the session. After the test, the session
    is closed and removed from the registry and the transaction is rolled
    back.

    :return: The session bound to the open transaction.
    """
    transaction = db_connection.begin()
    session = db_session(bind=db_connection)

    yield session

    session.close()
    # Remove the session from the scoped_session registry. Without this the
    # next test's db_session(bind=...) call finds a session already present
    # and cannot accept the new ``bind`` argument.
    db_session.remove()
    transaction.rollback()
I have a couple of problems with this:
- I get an error when this line:
transaction = db_connection.begin()
is called, probably because the connection has already begun at that point.
- I have my session inside a transaction, but how do I supply that to the rest of the project (views, services..)?
I could think of patching the session present in database.py, but is there a better option, like changing the bind of the sessionmaker?
I see nothing wrong with the code, and I'm not able to spot any place where the connection would be implicitly beginning, so you should be able to call begin() with no problem.
you are already using the global scoped session you have in your fixture, so while that fixture is in effect, that's the session that all of your application will get when they refer to the db_session global.
It looks like you are missing a db_session.remove(); otherwise it works fine. Here's a demo, try running this:
from typing import Iterator
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.orm import Session, scoped_session, sessionmaker
from pytest import fixture
# SQLite engine with SQL echoing enabled, used by the fixtures in this demo.
engine = create_engine("sqlite://", echo=True, future=True)

# Scoped session registry built on a sessionmaker bound to the engine.
_session_factory = sessionmaker(bind=engine)
db_session = scoped_session(_session_factory)
@fixture(scope="session")
def db_engine() -> Iterator[Engine]:
    """
    Yields the module-level engine; there is no set-up or tear-down work.

    :return: The database engine.
    """
    yield engine
@fixture(scope="session")
def db_connection(db_engine: Engine) -> Iterator[Connection]:
    """
    Opens one connection on the test engine and closes it after the yield.

    :return: The database connection.
    """
    conn = db_engine.connect()

    yield conn

    conn.close()
@fixture(autouse=True)
def db_transaction(db_connection: Connection) -> Iterator[Session]:
    """
    Wraps each test case in a transaction that is rolled back afterwards.

    The scoped session is bound to the connection that owns the
    transaction and is handed to the test.

    :return: The session bound to the open transaction.
    """
    outer_transaction = db_connection.begin()
    assert outer_transaction is not None

    scoped = db_session(bind=db_connection)

    yield scoped

    # Tear-down: close the session, remove it from the registry, then roll
    # back the transaction.
    scoped.close()
    db_session.remove()
    outer_transaction.rollback()
def test_in_a_transaction(db_transaction):
    """Execute ``select 1`` through the fixture session and close the result."""
    select_one = text("select 1")
    outcome = db_transaction.execute(select_one)
    outcome.close()
def test_also_in_a_transaction(db_transaction):
    """Second test that executes ``select 1`` through the fixture session."""
    select_one = text("select 1")
    outcome = db_transaction.execute(select_one)
    outcome.close()
Nicely done!