Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Joining a Session into an External Transaction (async API) #5811

Copy link
Copy link
Closed
@PhillCli

Description

@PhillCli
Issue body actions

Describe your question

I am trying to set up the "Joining a Session into an External Transaction (such as for test suites)" recipe for pytest with the asyncio API.

Example
The sync pytest recipe works fine:

import pytest
from sqlalchemy.orm import Session
from sqlalchemy import event, Column, Integer, create_engine
from sqlalchemy.ext.declarative import declarative_base

# Declarative base; its ``metadata`` is what the fixtures below create and drop.
Base = declarative_base()

# a model
class Thing(Base):
    """Minimal mapped model for the tests: table ``thing`` with only an integer primary key."""

    __tablename__ = "thing"

    id = Column(Integer, primary_key=True)


@pytest.fixture(scope="session")
def engine_fixture():
    """Provide one Engine for the whole test session, with a freshly created schema.

    All tables known to ``Base.metadata`` are dropped and recreated before the
    first test, so leftovers from an earlier run do not leak in. They are
    dropped again after the last test. The Engine's connection pool is always
    disposed at the end, even when schema setup or teardown raises, so pooled
    DBAPI connections are not left open.
    """
    engine = create_engine("postgresql://postgres:changethis@db/app_test", echo=True)
    try:
        Base.metadata.drop_all(engine)
        Base.metadata.create_all(engine)

        yield engine

        Base.metadata.drop_all(engine)
    finally:
        # release every pooled connection held by this engine
        engine.dispose()


@pytest.fixture
def session(engine_fixture):
    """Yield a Session joined into an outer transaction that is rolled back afterwards.

    This follows the "Joining a Session into an External Transaction" recipe:

    - The Session is bound to a Connection on which a transaction has already
      been begun.
    - Three ``Thing`` rows are loaded as fixture data.
    - The Session is then kept inside a SAVEPOINT that is reopened every time
      it ends.

    At teardown the Session is closed and the outer transaction is rolled
    back. Teardown also runs if the setup steps above raise. The Connection is
    closed by its ``with`` block in every case.
    """
    with engine_fixture.connect() as conn:
        trans = conn.begin()
        session = Session(bind=conn)
        try:
            # load fixture data within the scope of the outer transaction
            session.add_all([Thing(), Thing(), Thing()])
            session.commit()

            # start the session in a SAVEPOINT...
            session.begin_nested()

            # ...then each time that SAVEPOINT ends, reopen it
            @event.listens_for(session, "after_transaction_end")
            def restart_savepoint(session, transaction):
                # only react to the SAVEPOINT sitting directly on top of a
                # non-nested transaction
                if transaction.nested and not transaction._parent.nested:
                    session.begin_nested()

            yield session
        finally:
            # same teardown order as the docs recipe; conn.close() follows
            # when the ``with`` block exits
            session.close()
            trans.rollback()


def _test_thing(session, extra_rollback=0):
    """Check commit/rollback behaviour of *session* against the ``thing`` table.

    Expects exactly 3 ``Thing`` rows to be visible on entry (the fixture data).
    Each add-then-rollback round must discard its additions, while ``commit()``
    must persist rows for the rest of the test.

    :param session: Session produced by the ``session`` fixture.
    :param extra_rollback: number of add-then-rollback rounds run in each of
        the two rollback loops (0 skips both loops).
    """

    def count_things():
        # pending objects are included via autoflush; see the 6-row check
        # right after add_all() below
        return len(session.query(Thing).all())

    assert count_things() == 3

    for _ in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        assert count_things() == 6
        session.rollback()

    # after rollbacks, still @ 3 rows
    assert count_things() == 3

    session.add_all([Thing(), Thing()])
    session.commit()
    assert count_things() == 5

    # added but NOT committed: the first rollback below discards it too
    session.add(Thing())
    assert count_things() == 6

    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        # first round: 5 committed + 1 pending + 3 new; later rounds no
        # longer see the pending row, b.c. we rolled back that other "thing" too
        assert count_things() == (8 if elem > 0 else 9)
        session.rollback()

    assert count_things() == (5 if extra_rollback else 6)


def test_thing_one_pytest(session):
    """Drive the shared scenario with no extra rollback rounds."""
    _test_thing(session, extra_rollback=0)


def test_thing_two_pytest(session):
    """Drive the shared scenario with two extra rollback rounds per loop."""
    _test_thing(session, extra_rollback=2)

Porting it to the asyncio API, together with pytest-asyncio (0.14.0), fails because the teardown of the first test case does not work correctly:

from typing import AsyncIterator

import pytest
from sqlalchemy import Column, Integer, create_engine, event
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.future import select

# Declarative base; its ``metadata`` is created/dropped by ``meta_migration`` below.
Base = declarative_base()

# a model
class Thing(Base):
    """Minimal mapped model for the tests: table ``thing`` with only an integer primary key."""

    __tablename__ = "thing"

    id = Column(Integer, primary_key=True)


@pytest.fixture(scope="session", autouse=True)
def meta_migration():
    """Create the schema once per test session using a plain (sync) Engine.

    This runs automatically for every test session. All tables in
    ``Base.metadata`` are dropped and recreated before the first test and
    dropped after the last one. The sync Engine is yielded for callers that
    want it, and its pool is always disposed at the end, even if schema setup
    or teardown raises.
    """
    sync_engine = create_engine(
        "postgresql://postgres:changethis@db/app_test", echo=True
    )
    try:
        # setup: start from an empty schema
        Base.metadata.drop_all(sync_engine)
        Base.metadata.create_all(sync_engine)

        yield sync_engine

        # teardown
        Base.metadata.drop_all(sync_engine)
    finally:
        # release pooled connections instead of leaving them open
        sync_engine.dispose()


@pytest.fixture(scope="session")
async def async_engine() -> AsyncIterator[AsyncEngine]:
    """Provide one AsyncEngine (asyncpg driver) for the whole test session.

    Only the engine is created here; the schema itself is created by the sync
    ``meta_migration`` fixture. The engine's pool is disposed at teardown, so
    asyncpg connections are closed explicitly rather than left open.

    NOTE(review): pytest-asyncio runs async fixtures on its ``event_loop``
    fixture, which is function-scoped by default. A session-scoped async
    fixture like this one usually needs a session-scoped ``event_loop``
    override, so that the engine is created and disposed on the same loop.
    Confirm this for the pytest-asyncio version in use.
    """
    engine = create_async_engine(
        "postgresql+asyncpg://postgres:changethis@db/app_test", echo=True
    )
    try:
        yield engine
    finally:
        await engine.dispose()


@pytest.fixture(scope="function")
async def session(async_engine):
    """Yield an AsyncSession joined into an outer transaction that is rolled back afterwards.

    This is an asyncio port of the "Joining a Session into an External
    Transaction" recipe:

    - The AsyncSession is bound to an AsyncConnection on which a transaction
      has already been begun.
    - Three ``Thing`` rows are loaded as fixture data.
    - The session is then kept in a SAVEPOINT that is reopened whenever it
      ends.

    At teardown the session is closed and the outer transaction is rolled
    back. Teardown also runs if the setup steps raise. The connection is
    closed by its ``async with`` block in every case.

    NOTE(review): in the run reported in this issue (SQLAlchemy 1.4.0b1), the
    setup log shows a real ``COMMIT`` after the fixture INSERTs. The second
    test then saw 8 rows instead of 3, i.e. rows written earlier survived the
    teardown rollback. This suggests the AsyncSession did not join the outer
    transaction on that release (the issue is labelled as a bug). Retest on a
    newer 1.4 release.
    """
    async with async_engine.connect() as conn:
        trans = await conn.begin()
        session = AsyncSession(bind=conn)
        try:
            # load fixture data within the scope of the outer transaction
            session.add_all([Thing(), Thing(), Thing()])
            await session.commit()

            # start the session in a SAVEPOINT...
            await session.begin_nested()

            # ...then each time that SAVEPOINT ends, reopen it.
            # NOTE: no async listeners yet, so the listener is attached to
            # the underlying sync Session and calls its sync begin_nested()
            @event.listens_for(session.sync_session, "after_transaction_end")
            def restart_savepoint(session, transaction):
                # only react to the SAVEPOINT sitting directly on top of a
                # non-nested transaction
                if transaction.nested and not transaction._parent.nested:
                    session.begin_nested()

            yield session
        finally:
            # same teardown order as the docs recipe; conn.close() follows
            # when the ``async with`` block exits
            await session.close()
            await trans.rollback()


async def _test_thing(session: AsyncSession, extra_rollback=0):
    """Check commit/rollback behaviour of an AsyncSession against the ``thing`` table.

    Expects exactly 3 ``Thing`` rows to be visible on entry (the fixture data).
    Each add-then-rollback round must discard its additions, while ``commit()``
    must persist rows for the rest of the test.

    :param session: AsyncSession produced by the ``session`` fixture.
    :param extra_rollback: number of add-then-rollback rounds run in each of
        the two rollback loops (0 skips both loops).
    """

    async def count_things() -> int:
        # pending objects are included via autoflush; see the 6-row check
        # right after add_all() below
        result = await session.execute(select(Thing))
        return len(result.all())

    assert await count_things() == 3

    for _ in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        assert await count_things() == 6
        await session.rollback()

    # after rollbacks, still @ 3 rows
    assert await count_things() == 3

    session.add_all([Thing(), Thing()])
    await session.commit()
    assert await count_things() == 5

    # added but NOT committed: the first rollback below discards it too
    session.add(Thing())
    assert await count_things() == 6

    for elem in range(extra_rollback):
        # run N number of rollbacks
        session.add_all([Thing(), Thing(), Thing()])
        # first round: 5 committed + 1 pending + 3 new; later rounds no
        # longer see the pending row, b.c. we rolled back that other "thing" too
        assert await count_things() == (8 if elem > 0 else 9)
        await session.rollback()

    assert await count_things() == (5 if extra_rollback else 6)


@pytest.mark.asyncio
async def test_thing_one_pytest(session):
    """Drive the shared async scenario with no extra rollback rounds."""
    await _test_thing(session, extra_rollback=0)


@pytest.mark.asyncio
async def test_thing_two_pytest(session):
    """Drive the shared async scenario with two extra rollback rounds per loop."""
    await _test_thing(session, extra_rollback=2)

Complete stack trace, if applicable

test_async.py::test_thing_two_pytest 
---------------------------------------------------------------------------------------------------------- live log setup ----------------------------------------------------------------------------------------------------------
INFO     sqlalchemy.engine.Engine:log.py:117 BEGIN (implicit)
INFO     sqlalchemy.engine.Engine:log.py:117 BEGIN (implicit)
INFO     sqlalchemy.engine.Engine:log.py:117 INSERT INTO thing DEFAULT VALUES RETURNING thing.id
INFO     sqlalchemy.engine.Engine:log.py:117 [cached since 0.05257s ago] ()
INFO     sqlalchemy.engine.Engine:log.py:117 INSERT INTO thing DEFAULT VALUES RETURNING thing.id
INFO     sqlalchemy.engine.Engine:log.py:117 [cached since 0.05624s ago] ()
INFO     sqlalchemy.engine.Engine:log.py:117 INSERT INTO thing DEFAULT VALUES RETURNING thing.id
INFO     sqlalchemy.engine.Engine:log.py:117 [cached since 0.05979s ago] ()
INFO     sqlalchemy.engine.Engine:log.py:117 COMMIT
---------------------------------------------------------------------------------------------------------- live log call -----------------------------------------------------------------------------------------------------------
INFO     sqlalchemy.engine.Engine:log.py:117 BEGIN (implicit)
INFO     sqlalchemy.engine.Engine:log.py:117 SAVEPOINT sa_savepoint_1
INFO     sqlalchemy.engine.Engine:log.py:117 [no key 0.00043s] ()
INFO     sqlalchemy.engine.Engine:log.py:117 SELECT thing.id 
FROM thing
INFO     sqlalchemy.engine.Engine:log.py:117 [cached since 0.05067s ago] ()
FAILED                                                                                                                                                                                                                       [100%]
-------------------------------------------------------------------------------------------------------- live log teardown ---------------------------------------------------------------------------------------------------------
INFO     sqlalchemy.engine.Engine:log.py:117 ROLLBACK
INFO     sqlalchemy.engine.Engine:log.py:117 ROLLBACK
INFO     sqlalchemy.engine.Engine:log.py:117 BEGIN (implicit)
INFO     sqlalchemy.engine.Engine:log.py:117 select relname from pg_class c join pg_namespace n on n.oid=c.relnamespace where pg_catalog.pg_table_is_visible(c.oid) and relname=%(name)s
INFO     sqlalchemy.engine.Engine:log.py:117 [cached since 0.3944s ago] {'name': 'thing'}
INFO     sqlalchemy.engine.Engine:log.py:117 
DROP TABLE thing
INFO     sqlalchemy.engine.Engine:log.py:117 [no key 0.00040s] {}
INFO     sqlalchemy.engine.Engine:log.py:117 COMMIT


============================================================================================================= FAILURES =============================================================================================================
______________________________________________________________________________________________________ test_thing_two_pytest _______________________________________________________________________________________________________

session = <sqlalchemy.ext.asyncio.session.AsyncSession object at 0x7fea00ccea90>

    @pytest.mark.asyncio
    async def test_thing_two_pytest(session):
        # run two extra rollbacks
>       await _test_thing(session, 2)

test_async.py:126: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

session = <sqlalchemy.ext.asyncio.session.AsyncSession object at 0x7fea00ccea90>, extra_rollback = 2

    async def _test_thing(session: AsyncSession, extra_rollback=0):
    
        rows = (await session.execute(select(Thing))).all()
>       assert len(rows) == 3
E       assert 8 == 3
E         +8
E         -3

test_async.py:75: AssertionError

Any clues about what I could be doing wrong? Worth mentioning: the restart_savepoint listener has to be registered on AsyncSession.sync_session, because async event listeners are not implemented yet. Could that be the reason?

Versions

  • OS: Debian GNU/Linux 10 (buster)
  • Python: Python 3.8.5
  • SQLAlchemy: 1.4@b1
  • Database: postgres:12
  • pytest: 6.1.2
  • pytest-asyncio: 0.14.0
  • DBAPI: asyncpg 0.21.0

Thanks!

Metadata

Metadata

Assignees

No one assigned

    Labels

    asyncio, bug (Something isn't working), great mcve (An issue with a great mcve), orm

    Type

    No type

    Projects

    No projects

    Milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions

      Morty Proxy This is a proxified and sanitized view of the page, visit original site.