diffusers-rocm-parallel

Multi-GPU tensor / context parallel diffusion on AMD ROCm — with the patch that makes it actually work.

Companion repo: For the single-GPU AMD stack (5 torchao + diffusers backport patches that bring FLUX.1-dev to NVIDIA-class latency on one RX 7800 XT at 72.5 s / 6.4 GB), see flux-amd-rocm. This repo is the multi-GPU extension: true Megatron-style tensor parallelism (QKV + FFN sharded) AND context parallelism (ring attention / Ulysses) on AMD ROCm.

Diffusers 0.37 introduced native context parallelism (ring attention, Ulysses) via model.enable_parallelism(). It works out of the box on CUDA. On AMD ROCm (torch 2.9+), it crashes on the first denoising step with:

RuntimeError: The size of tensor a (24) must match the size of tensor b (128)
at non-singleton dimension 3

This repo is the fix, the benches that prove it, and a set of plug-and-play launchers for common multi-GPU diffusion workloads on AMD cards.


TL;DR

# 1. Install ROCm PyTorch (must be from the AMD wheel index, not PyPI)
python3 -m venv ~/rocm-venv && source ~/rocm-venv/bin/activate
pip install --pre torch==2.9.1 --index-url https://download.pytorch.org/whl/rocm7.1

# 2. Clone + install deps
git clone https://github.com/Dev-next-gen/diffusers-rocm-parallel
cd diffusers-rocm-parallel
pip install -r requirements.txt

# 3. Run (set PYTHON so the launchers use your ROCm venv)
export PYTHON=~/rocm-venv/bin/python3

# 4× RX 7800 XT, FLUX.1-dev bf16, Megatron-style tensor parallelism
./examples/4gpu_flux_tp.sh

# 2× RX 7800 XT, FLUX.1-dev, ring attention
./examples/2gpu_flux_ring.sh

You get:

  • 4-GPU tensor parallelism: FLUX.1-dev bf16 in ~51 s, 11.18 GB per GPU (all 4 active simultaneously) — no quantization, full bf16 precision
  • Ring attention: same single-GPU VRAM envelope (6.4 GB) spread across N cards
  • No xfuser, no custom transformer wrappers beyond the sharding logic
  • Compatible with any FLUX.1-dev fine-tune (same architecture)

See QUICKSTART.md for a step-by-step reproduction.


The bug

Diffusers' _templated_context_parallel_attention ring merge step expects the attention log-sum-exp (LSE) tensor to be 4-dimensional — shape [B, H, S, 1]. Older torch versions returned LSE as 3D, so diffusers has:

# diffusers/models/attention_dispatch.py
if is_torch_version("<", "2.9.0"):
    lse = lse.unsqueeze(-1)

The assumption is that on torch 2.9+, native SDPA already returns LSE as 4D. This is true on CUDA. It is NOT true on ROCm: torch.ops.aten._scaled_dot_product_flash_attention on ROCm 7.1 / AOTriton still returns LSE as [B, H, S]. Without the unsqueeze, the ring merge:

out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)

tries to broadcast a 3D LSE against 4D out, and fails.
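
A quick shape probe shows the mismatch (a hedged sketch, not the reproducer from docs/bug.md; the shapes here are illustrative):

import torch

q = k = v = torch.randn(1, 24, 1024, 128, device="cuda", dtype=torch.bfloat16)
out, lse, *_ = torch.ops.aten._scaled_dot_product_flash_attention(q, k, v)
print(out.shape)  # torch.Size([1, 24, 1024, 128])
print(lse.shape)  # torch.Size([1, 24, 1024]) on ROCm 7.1 / AOTriton: 3D, cannot broadcast against 4D out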

The patch wraps _native_flash_attention_forward_op so that LSE is always 4D, regardless of backend. See docs/bug.md for the full write-up and reproducer.
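
A minimal sketch of the patch idea (the real implementation is patches/diffusers_rocm_lse_shape.py; this sketch assumes the op lives in diffusers.models.attention_dispatch and returns an (out, lse) pair, both internal details that may differ between diffusers versions):

from diffusers.models import attention_dispatch

_original_op = attention_dispatch._native_flash_attention_forward_op

def _lse_always_4d(*args, **kwargs):
    out, lse = _original_op(*args, **kwargs)
    # ROCm / AOTriton returns LSE as [B, H, S]; the ring merge expects [B, H, S, 1].
    if lse is not None and lse.ndim == out.ndim - 1:
        lse = lse.unsqueeze(-1)
    return out, lse

attention_dispatch._native_flash_attention_forward_op = _lse_always_4d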


Benchmarks

Measured on RX 7800 XT (gfx1101, 16 GB), ROCm 7.1, torch 2.9.1, diffusers 0.37.1. FLUX.1-dev 1024² × 28 steps.

Tensor Parallelism (Megatron-style QKV + FFN sharding) — bf16, no quantization

Config | Latency | Step time | Peak VRAM / GPU | Total VRAM | All GPUs active?
1× 7800 XT, single GPU bf16 | ~144 s | ~5.1 s | ~24 GB | 24 GB | n/a (single GPU)
4× 7800 XT, tp=4 bf16 | 51.5 s | 1.84 s | 11.18 GB | 44.74 GB | ✅ yes

How it works: QKV projections are column-parallel (each rank holds 6/24 heads), FFN projections are column/row-parallel. All 4 ranks run the full forward pass simultaneously; dist.all_reduce at each RowwiseLinear synchronises partial sums. AdaLN norm linears are replicated (small, output must be full-dim on every rank). See bench/flux_tensor_parallel.py for the full implementation.

Key constraint: FLUX.1-dev has 24 attention heads and inner_dim=3072 (24 × 128). Both are exactly divisible by 4 (6 heads/rank, 768 dims/rank) but NOT by 5 — tp=4 is the natural world size for this architecture.

Context Parallelism (ring attention) — int8 + group_offload

Config | Latency | Peak VRAM / GPU | Total VRAM
1× 7800 XT baseline (reference, int8) | 72.5 s | 6.39 GB | 6.39 GB
2× 7800 XT, ring_degree=2 | 102.9 s | 6.39 GB | 12.78 GB
4× 7800 XT, ring_degree=4 | pending | pending | pending

Reading these numbers: ring attention with group_offload does NOT speed up 1024² generation on this hardware — PCIe KV-gather communication dominates. The win is that it works at all on AMD (previously impossible), and VRAM per GPU stays flat, so you could fit a larger model or resolution.


What's in this repo

File | Purpose
bench/flux_tensor_parallel.py | 4-GPU Megatron-style TP — FLUX.1-dev bf16, all ranks active simultaneously
bench/flux_ring_attention.py | Ring attention bench — FLUX.1-dev + ring attention + group_offload
bench/flux_device_map_balanced.py | Weight-sharded pipeline via device_map="balanced" (single process, sequential)
examples/_common.sh | Shared ROCm env vars sourced by all launchers
examples/4gpu_flux_tp.sh | Launcher for TP-4
examples/2gpu_flux_ring.sh | Launcher for ring attention, 2 GPUs
examples/4gpu_flux_ring.sh | Launcher for ring attention, 4 GPUs
examples/5gpu_flux_ring.sh | Launcher for ring attention, 5 GPUs — 1008² (not yet validated)
examples/5gpu_flux_device_map.sh | device_map="balanced" across 5 GPUs — sequential, not recommended for latency
patches/diffusers_rocm_lse_shape.py | Monkey-patch fixing the LSE shape bug for ring attention on ROCm
QUICKSTART.md | Step-by-step reproduction guide
docs/bug.md | LSE shape bug write-up, reproducer, proposed upstream fix
docs/tp_bugs.md | The 2 silent TP bugs found during development + root-cause + fix
docs/performance.md | When TP / CP helps vs hurts on consumer AMD
tests/test_lse_shape.py | Regression test for the LSE patch

Requirements

Component | Version
ROCm | 7.1+
PyTorch | 2.9.1+rocm7.1.1
diffusers | 0.37+
torchao | 0.13 – 0.14.1 (for int8 benches)
GPUs | ≥2 RDNA3 (gfx1100 / gfx1101) or CDNA2/3

For the full torchao + group_offload stack on AMD (5 other patches), see the companion repo flux-amd-rocm.


Tensor Parallelism — how it works

The TP implementation in bench/flux_tensor_parallel.py is a from-scratch Megatron-style sharding applied as a post-load monkey-patch. No custom model class, no framework dependency beyond PyTorch distributed.

                   rank 0          rank 1          rank 2          rank 3
to_q/to_k/to_v    out[0:768]      out[768:1536]   out[1536:2304]  out[2304:3072]
to_out[0]         in[0:768]       in[768:1536]    in[1536:2304]   in[2304:3072]
                                     ← all_reduce →
ff.net[0].proj    out[0:3072]     out[3072:6144]  out[6144:9216]  out[9216:12288]
ff.net[2]         in[0:3072]      in[3072:6144]   in[6144:9216]   in[9216:12288]
                                     ← all_reduce →

out[a:b] = rank holds those output-dimension rows of the weight matrix (ColwiseParallel).
in[a:b] = rank holds those input-dimension columns of the weight matrix (RowwiseParallel), followed by dist.all_reduce.

  • ColwiseParallel (_Col): rank i holds output rows [i*s:(i+1)*s]; bias is also sliced. Output is sharded along the output dimension.
  • RowwiseParallel (_Row): rank i holds input columns [i*s:(i+1)*s]; dist.all_reduce after local matmul produces the full replicated output.
  • AdaLN (norm linears): replicated across all ranks — their output (shift/scale/gate) must be full-dim everywhere. Post-all_reduce activations are also full-dim, so the element-wise multiply works.
  • Head patching: attn.heads is set to 24 // 4 = 6 per rank so that unflatten(-1, (heads, -1)) gives the correct local shape (B, S, 6, 128).
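
For illustration, here is a minimal sketch of the two sharding primitives (hedged: this is not the repo's code and it omits head patching and the non-contiguous proj_out split; see bench/flux_tensor_parallel.py for the real implementation; shard_colwise and RowwiseLinear are illustrative names, and torch.distributed is assumed to be initialised):

import torch
import torch.distributed as dist
import torch.nn as nn

def shard_colwise(linear: nn.Linear, rank: int, world: int) -> nn.Linear:
    # _Col: this rank keeps output rows [rank*s:(rank+1)*s] of the weight (and bias).
    s = linear.out_features // world
    sharded = nn.Linear(linear.in_features, s, bias=linear.bias is not None)
    sharded.weight.data.copy_(linear.weight.data[rank * s:(rank + 1) * s])
    if linear.bias is not None:
        sharded.bias.data.copy_(linear.bias.data[rank * s:(rank + 1) * s])
    return sharded

class RowwiseLinear(nn.Module):
    # _Row: this rank keeps input columns [rank*s:(rank+1)*s]; all_reduce sums the partials.
    def __init__(self, linear: nn.Linear, rank: int, world: int):
        super().__init__()
        s = linear.in_features // world
        self.weight = nn.Parameter(linear.weight.data[:, rank * s:(rank + 1) * s].clone())
        self.bias = linear.bias if rank == 0 else None  # bias must be added exactly once

    def forward(self, x):
        y = torch.nn.functional.linear(x, self.weight, self.bias)
        dist.all_reduce(y)  # sum of partial products -> full replicated output on every rank
        return y

Chaining a column-parallel to_q/to_k/to_v with a row-parallel to_out[0] is what lets each rank attend over only its 6 heads while the all_reduce reassembles the full 3072-dim activation.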

The load sequence minimises peak VRAM: encode text (T5, 10 GB) on rank 0 alone → broadcast embeddings → free T5 → all ranks load transformer in parallel → apply TP → each rank retains only its ~11 GB shard.

Implementation bugs found (and fixed)

Two silent correctness bugs were discovered while implementing the TP sharding. Both cause the model to produce pure noise without any error or crash — the only diagnostic is visual inspection of the output.

Bug | Root cause | Symptom
TP-1 | proj_out in single blocks slices contiguous columns instead of the correct non-contiguous split into attn columns followed by mlp columns | Garbage output from all 38 single-stream blocks
TP-2 | dist.broadcast(tensor.cuda(), src=0) writes into a temporary, so non-rank-0 GPUs keep timestep=0.0 for all denoising steps | All non-rank-0 ranks compute with the wrong time embedding → incoherent all_reduce

Full root-cause analysis and minimal reproducers: docs/tp_bugs.md
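
To make TP-2 concrete, a hedged sketch of the pattern (not the repo's exact code; timestep here is just a stand-in tensor):

import torch
import torch.distributed as dist

timestep = torch.zeros(1)  # non-rank-0 ranks start with a placeholder value

# Buggy: .cuda() on a CPU tensor returns a new tensor; broadcast fills that
# temporary in place and the result is discarded, so this rank still sees 0.0.
dist.broadcast(timestep.cuda(), src=0)

# Fixed: move to the GPU first, keep the reference, then broadcast in place.
timestep = timestep.cuda()
dist.broadcast(timestep, src=0)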


Upstream status

The LSE shape fix will be filed as a PR against huggingface/diffusers. The proper fix is to test lse.ndim < out.ndim (or the active backend), not the torch version — the torch version check conflates CUDA and ROCm backend behaviour.

Until then, this monkey-patch is drop-in.


License

MIT.

FLUX.1-dev weights are released by Black Forest Labs under their own non-commercial license. This repo does not redistribute any model weights.


Credits

  • Sayak Paul and the Hugging Face / PyTorch / TorchAO teams for the upstream diffusers parallelism work
  • The ROCm and AOTriton teams at AMD
  • Leo Camus — Megatron-style TP implementation, LSE shape backport, AMD-specific patches and reference benchmarks
