This example shows how to use Dependency Injector with FastAPI and SQLAlchemy.
The source code is available on GitHub.
Thanks to @ShvetsovYura for providing the initial example: FastAPI_DI_SqlAlchemy.
The application has the following structure:
./
├── webapp/
│ ├── __init__.py
│ ├── application.py
│ ├── containers.py
│ ├── database.py
│ ├── endpoints.py
│ ├── models.py
│ ├── repositories.py
│ ├── services.py
│ └── tests.py
├── config.yml
├── docker-compose.yml
├── Dockerfile
└── requirements.txt
The application factory creates the container, wires it with the endpoints module, creates the FastAPI app, and sets up the routes.
The application factory also creates the database tables if they do not exist yet.
Listing of webapp/application.py:
"""Application module."""
from fastapi import FastAPI
from .containers import Container
from . import endpoints
def create_app() -> FastAPI:
    """Build the FastAPI application.

    Instantiates the DI container, creates the database tables, attaches the
    container to the app (the tests reach it through ``app.container`` to
    override providers) and registers the endpoints router.
    """
    container = Container()
    container.db().create_database()

    application = FastAPI()
    application.container = container
    application.include_router(endpoints.router)
    return application


# Module-level application instance, imported e.g. by the tests.
app = create_app()
Module endpoints contains example endpoints. The endpoints depend on the user service, which is injected using the Wiring feature. See webapp/endpoints.py:
"""Endpoints module."""
from typing import Annotated
from fastapi import APIRouter, Depends, Response, status
from dependency_injector.wiring import Provide, inject
from .containers import Container
from .repositories import NotFoundError
from .services import UserService
# Router holding this module's endpoints; the application factory mounts it
# with ``app.include_router``.
router = APIRouter()
@router.get("/users")
@inject
def get_list(
    user_service: Annotated[UserService, Depends(Provide[Container.user_service])],
):
    """Return every user known to the user service."""
    users = user_service.get_users()
    return users
@router.get("/users/{user_id}")
@inject
def get_by_id(
    user_id: int,
    user_service: Annotated[UserService, Depends(Provide[Container.user_service])],
):
    """Return a single user, or an empty 404 response when it does not exist."""
    try:
        user = user_service.get_user_by_id(user_id)
    except NotFoundError:
        return Response(status_code=status.HTTP_404_NOT_FOUND)
    return user
@router.post("/users", status_code=status.HTTP_201_CREATED)
@inject
def add(
    user_service: Annotated[UserService, Depends(Provide[Container.user_service])],
):
    """Create a new user through the service and return it with HTTP 201."""
    created = user_service.create_user()
    return created
@router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
@inject
def remove(
    user_id: int,
    user_service: Annotated[UserService, Depends(Provide[Container.user_service])],
) -> Response:
    """Delete a user: empty 204 response on success, empty 404 if it is missing."""
    try:
        user_service.delete_user_by_id(user_id)
    except NotFoundError:
        code = status.HTTP_404_NOT_FOUND
    else:
        code = status.HTTP_204_NO_CONTENT
    return Response(status_code=code)
@router.get("/status")
def get_status():
    """Return a static ``{"status": "OK"}`` payload."""
    payload = {"status": "OK"}
    return payload
The declarative container wires the example user service, the user repository, and the utility database class. See webapp/containers.py:
"""Containers module."""
from dependency_injector import containers, providers
from .database import Database
from .repositories import UserRepository
from .services import UserService
class Container(containers.DeclarativeContainer):
    """Application container.

    Wires the ``endpoints`` module and declares the providers for the
    database, the user repository and the user service.
    """

    wiring_config = containers.WiringConfiguration(modules=[".endpoints"])

    # Settings come from config.yml; ``config.db.url`` feeds the database.
    config = providers.Configuration(yaml_files=["config.yml"])

    # Singleton provider: the same Database instance is reused.
    db = providers.Singleton(Database, db_url=config.db.url)

    # A fresh repository per injection, given the bound ``Database.session``
    # method as its session factory.
    user_repository = providers.Factory(UserRepository, session_factory=db.provided.session)

    user_service = providers.Factory(UserService, user_repository=user_repository)
Module services contains the example user service. See webapp/services.py:
"""Services module."""
from uuid import uuid4
from typing import Iterator
from .repositories import UserRepository
from .models import User
class UserService:
    """User-related operations, delegating persistence to a ``UserRepository``."""

    def __init__(self, user_repository: UserRepository) -> None:
        self._repository: UserRepository = user_repository

    def get_users(self) -> list[User]:
        """Return all users.

        The repository materializes the rows into a ``list`` (``Query.all()``),
        so the result is a list, not a lazy iterator.
        """
        return self._repository.get_all()

    def get_user_by_id(self, user_id: int) -> User:
        """Return the user with ``user_id``.

        Raises:
            UserNotFoundError: propagated from the repository when no such user exists.
        """
        return self._repository.get_by_id(user_id)

    def create_user(self) -> User:
        """Create a user with a random ``<uuid4>@email.com`` address and password ``"pwd"``."""
        uid = uuid4()
        return self._repository.add(email=f"{uid}@email.com", password="pwd")

    def delete_user_by_id(self, user_id: int) -> None:
        """Delete the user with ``user_id``.

        Raises:
            UserNotFoundError: propagated from the repository when no such user exists.
        """
        # Returns None as annotated instead of forwarding the repository's result.
        self._repository.delete_by_id(user_id)
Module repositories contains the example user repository. See webapp/repositories.py:
"""Repositories module."""
from contextlib import AbstractContextManager
from typing import Callable, Iterator
from sqlalchemy.orm import Session
from .models import User
class UserRepository:
    """Persistence operations for ``User`` rows.

    Args:
        session_factory: callable, invoked with no arguments, that returns a
            context manager yielding a SQLAlchemy ``Session``. Every method
            opens its own session through it.
    """

    def __init__(self, session_factory: Callable[..., AbstractContextManager[Session]]) -> None:
        self.session_factory = session_factory

    def get_all(self) -> list[User]:
        """Return all users as a list (``Query.all()`` materializes the result)."""
        with self.session_factory() as session:
            return session.query(User).all()

    def get_by_id(self, user_id: int) -> User:
        """Return the user with ``user_id``.

        Raises:
            UserNotFoundError: if no such user exists.
        """
        with self.session_factory() as session:
            user = session.query(User).filter(User.id == user_id).first()
            if user is None:
                raise UserNotFoundError(user_id)
            return user

    def add(self, email: str, password: str, is_active: bool = True) -> User:
        """Insert a new user, commit, and return it.

        ``password`` is stored unchanged in the ``hashed_password`` column; no
        hashing happens here.
        """
        with self.session_factory() as session:
            user = User(email=email, hashed_password=password, is_active=is_active)
            session.add(user)
            session.commit()
            # Reload database-generated values (e.g. the primary key) while the
            # session is still open, before the instance becomes detached.
            session.refresh(user)
            return user

    def delete_by_id(self, user_id: int) -> None:
        """Delete the user with ``user_id`` and commit.

        Raises:
            UserNotFoundError: if no such user exists.
        """
        with self.session_factory() as session:
            entity = session.query(User).filter(User.id == user_id).first()
            if entity is None:
                raise UserNotFoundError(user_id)
            session.delete(entity)
            session.commit()
class NotFoundError(Exception):
    """Raised when an entity looked up by id does not exist.

    Subclasses set ``entity_name`` to name the missing entity in the message.
    The offending id is kept on ``entity_id``.
    """

    # Default label so the base class itself can be raised; without a value,
    # ``self.entity_name`` in ``__init__`` would raise AttributeError.
    entity_name: str = "Entity"

    def __init__(self, entity_id):
        self.entity_id = entity_id
        super().__init__(f"{self.entity_name} not found, id: {entity_id}")


class UserNotFoundError(NotFoundError):
    """Raised when no user with the requested id exists."""

    entity_name: str = "User"
Module models contains the example SQLAlchemy user model. See webapp/models.py:
"""Models module."""
from sqlalchemy import Column, String, Boolean, Integer
from .database import Base
class User(Base):
    """ORM model mapped to the ``users`` table."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    email = Column(String, unique=True)
    hashed_password = Column(String)
    is_active = Column(Boolean, default=True)

    def __repr__(self):
        return (
            f"<User(id={self.id}, "
            f'email="{self.email}", '
            f'hashed_password="{self.hashed_password}", '
            f"is_active={self.is_active})>"
        )
Module database defines the declarative base and a utility class that holds the engine and the session factory. See webapp/database.py:
"""Database module."""
import logging
from contextlib import AbstractContextManager, contextmanager
from typing import Callable, Iterator

from sqlalchemy import create_engine, orm
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
# Module-level logger; Database.session uses it to log exceptions that trigger a rollback.
logger = logging.getLogger(__name__)
# Declarative base shared by the ORM models; Database.create_database creates
# the tables registered on its metadata.
Base = declarative_base()
class Database:
    """Holds the SQLAlchemy engine and hands out sessions.

    Args:
        db_url: database URL passed to ``sqlalchemy.create_engine``.
        echo: forwarded to ``create_engine``; when true the engine logs the SQL
            it emits. Defaults to ``True``, the previously hard-coded value.
    """

    def __init__(self, db_url: str, echo: bool = True) -> None:
        self._engine = create_engine(db_url, echo=echo)
        # scoped_session wraps the sessionmaker in a registry: calling it
        # returns the registry's current Session (thread-local by default in
        # SQLAlchemy) instead of always building a new one.
        self._session_factory = orm.scoped_session(
            orm.sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=self._engine,
            ),
        )

    def create_database(self) -> None:
        """Create every table registered on ``Base.metadata`` that does not exist yet."""
        Base.metadata.create_all(self._engine)

    @contextmanager
    def session(self) -> Iterator[Session]:
        """Yield a session for one unit of work.

        If the ``with`` body raises, the exception is logged, the session is
        rolled back, and the exception is re-raised. The session is closed in
        every case. Committing is left to the caller.

        The decorated function is a generator yielding a ``Session``, hence the
        ``Iterator[Session]`` annotation. ``db.session()`` itself evaluates to an
        ``AbstractContextManager[Session]``.
        """
        session: Session = self._session_factory()
        try:
            yield session
        except Exception:
            logger.exception("Session rollback because of exception")
            session.rollback()
            raise
        finally:
            session.close()
The tests use the Provider overriding feature to replace the repository with a mock. See webapp/tests.py:
"""Tests module."""
from unittest import mock
import pytest
from fastapi.testclient import TestClient
from .repositories import UserRepository, UserNotFoundError
from .models import User
from .application import app
@pytest.fixture
def client():
    """Provide a ``TestClient`` bound to the module-level application."""
    test_client = TestClient(app)
    yield test_client
def test_get_list(client):
    """GET /users serializes every user returned by the (mocked) repository."""
    repo = mock.Mock(spec=UserRepository)
    repo.get_all.return_value = [
        User(id=1, email="test1@email.com", hashed_password="pwd", is_active=True),
        User(id=2, email="test2@email.com", hashed_password="pwd", is_active=False),
    ]
    expected = [
        {"id": 1, "email": "test1@email.com", "hashed_password": "pwd", "is_active": True},
        {"id": 2, "email": "test2@email.com", "hashed_password": "pwd", "is_active": False},
    ]

    with app.container.user_repository.override(repo):
        response = client.get("/users")

    assert response.status_code == 200
    assert response.json() == expected
def test_get_by_id(client):
    """GET /users/{id} returns the user supplied by the repository."""
    repo = mock.Mock(spec=UserRepository)
    repo.get_by_id.return_value = User(
        id=1, email="xyz@email.com", hashed_password="pwd", is_active=True,
    )
    expected = {"id": 1, "email": "xyz@email.com", "hashed_password": "pwd", "is_active": True}

    with app.container.user_repository.override(repo):
        response = client.get("/users/1")

    assert response.status_code == 200
    assert response.json() == expected
    repo.get_by_id.assert_called_once_with(1)
def test_get_by_id_404(client):
    """A UserNotFoundError raised by the repository becomes HTTP 404."""
    repo = mock.Mock(spec=UserRepository)
    repo.get_by_id.side_effect = UserNotFoundError(1)

    with app.container.user_repository.override(repo):
        response = client.get("/users/1")

    assert response.status_code == 404
@mock.patch("webapp.services.uuid4", return_value="xyz")
def test_add(_, client):
    """POST /users derives the e-mail from the patched uuid4 and answers 201."""
    repo = mock.Mock(spec=UserRepository)
    repo.add.return_value = User(
        id=1, email="xyz@email.com", hashed_password="pwd", is_active=True,
    )
    expected = {"id": 1, "email": "xyz@email.com", "hashed_password": "pwd", "is_active": True}

    with app.container.user_repository.override(repo):
        response = client.post("/users")

    assert response.status_code == 201
    assert response.json() == expected
    repo.add.assert_called_once_with(email="xyz@email.com", password="pwd")
def test_remove(client):
    """DELETE /users/{id} forwards the id to the repository and answers 204."""
    repo = mock.Mock(spec=UserRepository)

    with app.container.user_repository.override(repo):
        response = client.delete("/users/1")

    assert response.status_code == 204
    repo.delete_by_id.assert_called_once_with(1)
def test_remove_404(client):
    """Deleting a missing user answers 404."""
    repo = mock.Mock(spec=UserRepository)
    repo.delete_by_id.side_effect = UserNotFoundError(1)

    with app.container.user_repository.override(repo):
        response = client.delete("/users/1")

    assert response.status_code == 404
def test_status(client):
    """GET /status returns the static OK payload."""
    response = client.get("/status")
    assert response.status_code == 200
    assert response.json() == {"status": "OK"}
The source code is available on GitHub.
Sponsor the project on GitHub.