Closed
Description
Describe your question
Trying to setup "Joining a Session into an External Transaction (such as for test suite)" recipe for pytest
and asyncio
API
Example
sync pytest
recipe works fine
import pytest
from sqlalchemy.orm import Session
from sqlalchemy import event, Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
# a model
class Thing(Base):
__tablename__ = "thing"
id = Column(Integer, primary_key=True)
@pytest.fixture(scope="session")
def engine_fixture():
engine = create_engine("postgresql://postgres:changethis@db/app_test", echo=True)
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
yield engine
Base.metadata.drop_all(engine)
@pytest.fixture
def session(engine_fixture):
conn = engine_fixture.connect()
trans = conn.begin()
session = Session(bind=conn)
def _fixture(session):
session.add_all([Thing(), Thing(), Thing()])
session.commit()
# load fixture data within the scope of the transaction
_fixture(session)
# start the session in a SAVEPOINT...
session.begin_nested()
# then each time that SAVEPOINT ends, reopen it
@event.listens_for(session, "after_transaction_end")
def restart_savepoint(session, transaction):
if transaction.nested and not transaction._parent.nested:
session.begin_nested()
yield session
# same teardown from the docs
session.close()
trans.rollback()
conn.close()
def _test_thing(session, extra_rollback=0):
rows = session.query(Thing).all()
assert len(rows) == 3
for elem in range(extra_rollback):
# run N number of rollbacks
session.add_all([Thing(), Thing(), Thing()])
rows = session.query(Thing).all()
assert len(rows) == 6
session.rollback()
# after rollbacks, still @ 3 rows
rows = session.query(Thing).all()
assert len(rows) == 3
session.add_all([Thing(), Thing()])
session.commit()
rows = session.query(Thing).all()
assert len(rows) == 5
session.add(Thing())
rows = session.query(Thing).all()
assert len(rows) == 6
for elem in range(extra_rollback):
# run N number of rollbacks
session.add_all([Thing(), Thing(), Thing()])
rows = session.query(Thing).all()
if elem > 0:
# b.c. we rolled back that other "thing" too
assert len(rows) == 8
else:
assert len(rows) == 9
session.rollback()
rows = session.query(Thing).all()
if extra_rollback:
assert len(rows) == 5
else:
assert len(rows) == 6
def test_thing_one_pytest(session):
# run zero rollbacks
_test_thing(session, 0)
def test_thing_two_pytest(session):
# run two extra rollbacks
_test_thing(session, 2)
port to asyncio
API, togheter with pytest-asyncio
(0.14.0) fails, due to wrong/malfunctioning teardown
of the first test-case:
import pytest
from sqlalchemy import Column, Integer, create_engine, event
from sqlalchemy.future import select
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
Base = declarative_base()
# a model
class Thing(Base):
__tablename__ = "thing"
id = Column(Integer, primary_key=True)
@pytest.fixture(scope="session", autouse=True)
def meta_migration():
# setup
sync_engine = create_engine(
"postgresql://postgres:changethis@db/app_test", echo=True
)
Base.metadata.drop_all(sync_engine)
Base.metadata.create_all(sync_engine)
yield sync_engine
# teardown
Base.metadata.drop_all(sync_engine)
@pytest.fixture(scope="session")
async def async_engine() -> AsyncEngine:
# setup
engine = create_async_engine(
"postgresql+asyncpg://postgres:changethis@db/app_test", echo=True
)
yield engine
@pytest.fixture(scope="function")
async def session(async_engine):
conn = await async_engine.connect()
trans = await conn.begin()
session = AsyncSession(bind=conn)
async def _fixture(session: AsyncSession):
session.add_all([Thing(), Thing(), Thing()])
await session.commit()
# load fixture data within the scope of the transaction
await _fixture(session)
# start the session in a SAVEPOINT...
await session.begin_nested()
# then each time that SAVEPOINT ends, reopen it
# NOTE: no async listeners yet
@event.listens_for(session.sync_session, "after_transaction_end")
def restart_savepoint(session, transaction):
if transaction.nested and not transaction._parent.nested:
session.begin_nested()
yield session
# same teardown from the docs
await session.close()
await trans.rollback()
await conn.close()
async def _test_thing(session: AsyncSession, extra_rollback=0):
rows = (await session.execute(select(Thing))).all()
assert len(rows) == 3
for elem in range(extra_rollback):
# run N number of rollbacks
session.add_all([Thing(), Thing(), Thing()])
rows = (await session.execute(select(Thing))).all()
assert len(rows) == 6
await session.rollback()
# after rollbacks, still @ 3 rows
rows = (await session.execute(select(Thing))).all()
assert len(rows) == 3
session.add_all([Thing(), Thing()])
await session.commit()
rows = (await session.execute(select(Thing))).all()
assert len(rows) == 5
session.add(Thing())
rows = (await session.execute(select(Thing))).all()
assert len(rows) == 6
for elem in range(extra_rollback):
# run N number of rollbacks
session.add_all([Thing(), Thing(), Thing()])
rows = (await session.execute(select(Thing))).all()
if elem > 0:
# b.c. we rolled back that other "thing" too
assert len(rows) == 8
else:
assert len(rows) == 9
await session.rollback()
rows = (await session.execute(select(Thing))).all()
if extra_rollback:
assert len(rows) == 5
else:
assert len(rows) == 6
@pytest.mark.asyncio
async def test_thing_one_pytest(session):
# run zero rollbacks
await _test_thing(session, 0)
@pytest.mark.asyncio
async def test_thing_two_pytest(session):
# run two extra rollbacks
await _test_thing(session, 2)
Complete stack trace, if applicable
test_async.py::test_thing_two_pytest
---------------------------------------------------------------------------------------------------------- live log setup ----------------------------------------------------------------------------------------------------------
INFO sqlalchemy.engine.Engine:log.py:117 BEGIN (implicit)
INFO sqlalchemy.engine.Engine:log.py:117 BEGIN (implicit)
INFO sqlalchemy.engine.Engine:log.py:117 INSERT INTO thing DEFAULT VALUES RETURNING thing.id
INFO sqlalchemy.engine.Engine:log.py:117 [cached since 0.05257s ago] ()
INFO sqlalchemy.engine.Engine:log.py:117 INSERT INTO thing DEFAULT VALUES RETURNING thing.id
INFO sqlalchemy.engine.Engine:log.py:117 [cached since 0.05624s ago] ()
INFO sqlalchemy.engine.Engine:log.py:117 INSERT INTO thing DEFAULT VALUES RETURNING thing.id
INFO sqlalchemy.engine.Engine:log.py:117 [cached since 0.05979s ago] ()
INFO sqlalchemy.engine.Engine:log.py:117 COMMIT
---------------------------------------------------------------------------------------------------------- live log call -----------------------------------------------------------------------------------------------------------
INFO sqlalchemy.engine.Engine:log.py:117 BEGIN (implicit)
INFO sqlalchemy.engine.Engine:log.py:117 SAVEPOINT sa_savepoint_1
INFO sqlalchemy.engine.Engine:log.py:117 [no key 0.00043s] ()
INFO sqlalchemy.engine.Engine:log.py:117 SELECT thing.id
FROM thing
INFO sqlalchemy.engine.Engine:log.py:117 [cached since 0.05067s ago] ()
FAILED [100%]
-------------------------------------------------------------------------------------------------------- live log teardown ---------------------------------------------------------------------------------------------------------
INFO sqlalchemy.engine.Engine:log.py:117 ROLLBACK
INFO sqlalchemy.engine.Engine:log.py:117 ROLLBACK
INFO sqlalchemy.engine.Engine:log.py:117 BEGIN (implicit)
INFO sqlalchemy.engine.Engine:log.py:117 select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where pg_catalog.pg_table_is_visible(c.oid) and relname=%(name)s
INFO sqlalchemy.engine.Engine:log.py:117 [cached since 0.3944s ago] {'name': 'thing'}
INFO sqlalchemy.engine.Engine:log.py:117
DROP TABLE thing
INFO sqlalchemy.engine.Engine:log.py:117 [no key 0.00040s] {}
INFO sqlalchemy.engine.Engine:log.py:117 COMMIT
============================================================================================================= FAILURES =============================================================================================================
______________________________________________________________________________________________________ test_thing_two_pytest _______________________________________________________________________________________________________
session = <sqlalchemy.ext.asyncio.session.AsyncSession object at 0x7fea00ccea90>
@pytest.mark.asyncio
async def test_thing_two_pytest(session):
# run two extra rollbacks
> await _test_thing(session, 2)
test_async.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
session = <sqlalchemy.ext.asyncio.session.AsyncSession object at 0x7fea00ccea90>, extra_rollback = 2
async def _test_thing(session: AsyncSession, extra_rollback=0):
rows = (await session.execute(select(Thing))).all()
> assert len(rows) == 3
E assert 8 == 3
E +8
E -3
test_async.py:75: AssertionError
Any clues on what I could be doing wrong? Worth mentioning is that restart_savepoint
is forced to use AsyncSession.sync_session
for listeners, as async listeners are not yet implemented, could it be a reason?
Versions
- OS: Debian GNU/Linux 10 (buster)
- Python: Python 3.8.5
- SQLAlchemy: 1.4@b1
- Database: postgres:12
- pytest: 6.1.2
- pytest-asyncio: 0.14.0
- DBAPI: asyncpg 0.21.0
Thanks!