- [06/2025] MegaFold code is released.
MegaFold is a cross-platform system to accelerate protein structure prediction models (e.g., AlphaFold3, AlphaFold2).
Why MegaFold?
- Cross-platform support: Supports execution on heterogeneous devices, including NVIDIA GPUs and AMD GPUs, through optimized Triton-based kernels.
- Ease of use: Delivers substantial performance gains with only a few lines of code changed.
- Speed improvement: Accelerates per-iteration training time by up to 1.73x.
- Memory reduction: Reduces peak memory during training by up to 1.23x.
- Sequence length extension: Enables training on 1.35x longer sequence lengths.
We include code for AlphaFold3 training with end-to-end MegaFold integrations and instructions to reproduce our paper results.
# create virtual environment under python==3.13 and activate
conda create -n venv python==3.13.0
conda activate venv
# install torch==2.7.0 with CUDA 11.8 (cu118)
pip install torch==2.7.0 --index-url https://download.pytorch.org/whl/cu118
# install other packages
pip install -r requirements.txt
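You can optionally verify the environment before continuing (a quick sanity check; the exact version strings depend on your install):
import torch
import triton  # Triton ships with the CUDA builds of PyTorch and powers MegaFold's fused kernels

print(torch.__version__)          # expected: 2.7.0+cu118
print(torch.cuda.is_available())  # should be True on a machine with a supported GPU
print(triton.__version__)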
First, download a sample dataset from the Protein Data Bank (PDB).
wget "https://mailmissouri-my.sharepoint.com/:u:/g/personal/acmwhb_umsystem_edu/ESbEXPguyO9Moh3E_J1zkWQBXZ6JxE5bsoKrZXOVwtu1Ow?download=1" -O data/pdb_data/val_mmcifs.tar.gz
tar -xzf data/pdb_data/val_mmcifs.tar.gz -C data/pdb_data
rm data/pdb_data/val_mmcifs.tar.gz
Then, download the required MSA and template data.
# install msa_dir
wget "https://mailmissouri-my.sharepoint.com/:u:/g/personal/acmwhb_umsystem_edu/EbXU1bnlRZxIqUXbAprgHycB3F4GWLy-m-qxvODfJsvFvA?download=1" -O pdb_val_msas
tar -xvzf pdb_val_msas
cp -r scratch/references/af3/pdb_data/* data/pdb_data/
rm pdb_val_msas
rm -r scratch
# install templates_dir
wget "https://umass-my.sharepoint.com/:u:/g/personal/hvla_umass_edu/EUalS7Hq3KBOlGdF2bVVwFABYU_ZidT2nEEi0PwqxaZ_Fw?download=1" -O templates_dir
tar -xvzf templates_dir
cp -r scratch/references/af3/pdb_data/* data/pdb_data/
rm templates_dir
rm -r scratch
Then, download the PDB's Chemical Component Dictionary (CCD) and miscellaneous metadata.
# install CCD data
wget -P ./data/ccd_data/ https://files.wwpdb.org/pub/pdb/data/monomers/components.cif.gz
wget -P ./data/ccd_data/ https://files.wwpdb.org/pub/pdb/data/component-models/complete/chem_comp_model.cif.gz
gunzip data/ccd_data/components.cif.gz
gunzip data/ccd_data/chem_comp_model.cif.gz
# install misc_data
wget "https://mailmissouri-my.sharepoint.com/:u:/g/personal/acmwhb_umsystem_edu/ESb9kUT_ASBEsYRN0KQmqt4BLzJhFunQU86E-GxWGxtGiA?download=1" -O misc_data
tar -xzf misc_data -C data/pdb_data
rm misc_data
Now, download the cache of deterministic features used by the ahead-of-time cache-based data-loading optimization.
# install msa_cache_dir
wget "https://mailmissouri-my.sharepoint.com/:u:/g/personal/acmwhb_umsystem_edu/Ect3VyxyqnZPm-4I6EpzB64B2M6tGctY5OMjIkatr6kYHQ?download=1" -O msa_cache
tar -xvzf msa_cache --wildcards 'caches/pdb_data/cache/msa/val_msas*'
rm msa_cache
# install input_cache_dir
wget "https://mailmissouri-my.sharepoint.com/:u:/g/personal/acmwhb_umsystem_edu/EXQnFYxhepNNku_Df45B1gEBPlhzIH_RtnhUEae4b74SKQ?download=1" -O input_cache
tar -xvzf input_cache
rm input_cache
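As a quick sanity check, you can confirm the downloaded data landed where expected. The paths below are inferred from the archive names above and may differ slightly in your setup:
# Check that the key data and cache directories/files exist after extraction.
# Paths are inferred from the archives above; adjust if your layout differs.
from pathlib import Path

expected = [
    "data/pdb_data/val_mmcifs",
    "data/ccd_data/components.cif",
    "data/ccd_data/chem_comp_model.cif",
    "caches/pdb_data/cache/msa",
]
for p in expected:
    print(p, "OK" if Path(p).exists() else "MISSING")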
To launch AlphaFold3 training with end-to-end MegaFold integrations, run:
python3 train.py --config configs/megafold_interactive.yaml --trainer_name initial_training
A script for submitting batch jobs is available in `scripts`. For example, to launch a job with `nodes=1` and `gpus=2`:
sbatch --nodes=1 --ntasks-per-node=2 --gpus=2 scripts/megafold.sh
If you are interested in running large-scale AlphaFold3 training, the full dataset and its cache are provided below:
# download `omniflow_caches.tar.gz.part_{aa,ab}` and `omniflow_data.tar.gz` from SharePoint (save the downloads below under those names, e.g., with `wget -O`)
wget "https://mailmissouri-my.sharepoint.com/:u:/g/personal/acmwhb_umsystem_edu/Ect3VyxyqnZPm-4I6EpzB64B2M6tGctY5OMjIkatr6kYHQ?download=1"
wget "https://mailmissouri-my.sharepoint.com/:u:/g/personal/acmwhb_umsystem_edu/ERiOg_fC_6BFnr9oKilzeeUBz8O_a2tI0i-TlksYAf8E5g?download=1"
wget "https://mailmissouri-my.sharepoint.com/:u:/g/personal/acmwhb_umsystem_edu/EYQ9oFu5KmFLryp8F1m79BAB2zoUFtLIU-Bx2OWmmKAdtA?download=1"
# then reassemble, extract, and clean up the downloaded archives
cat omniflow_caches.tar.gz.part_* > omniflow_caches.tar.gz
tar -xzf omniflow_caches.tar.gz && rm omniflow_caches.tar.gz
tar -xzf omniflow_data.tar.gz && rm omniflow_data.tar.gz
The following section gives detailed instructions on enabling each of our optimizations.
The file `megafold/inputs.py` includes the data pipeline and the implementation of the ahead-of-time cache-based data-loading optimizations, covering both the deterministic input-features cache and the MSA-features cache.
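Conceptually, the optimization precomputes deterministic features once, ahead of training, and reuses them across epochs instead of rebuilding them in the data loader. The sketch below is illustrative only, not MegaFold's actual API; the function name and cache path are hypothetical:
import os
import torch

CACHE_DIR = "caches/example_feature_cache"  # hypothetical cache location

def load_or_build_features(sample_id, build_fn):
    """Return deterministic features for sample_id, computing them only on first access."""
    path = os.path.join(CACHE_DIR, f"{sample_id}.pt")
    if os.path.exists(path):
        return torch.load(path)      # cache hit: skip expensive featurization
    feats = build_fn(sample_id)      # cache miss: compute once ahead of time
    os.makedirs(CACHE_DIR, exist_ok=True)
    torch.save(feats, path)
    return feats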
The folder `megafold/model/FusedEvoAttention` includes the source code of the FusedEvoAttention kernel.
from megafold.model.FusedEvoAttention.evoattention import TritonEvoformer
`FusedEvoAttention` supports four main types of EvoAttention in AlphaFold models, shown in the examples below. For correct results, adjust your inputs to the suggested shapes before passing them in. Acronyms: `N_seq` is the MSA depth; `N_res` is the input sequence length.
a. Single Attention with Pair Bias
# Q, K, V: [Batch, 1, N_res, Head, Dim]
# mask: [Batch, 1, 1, 1, N_res]
# pair_bias: [Batch, 1, Head, N_res, N_res]
out = TritonEvoformer(Q, K, V, mask, pair_bias)
b. Triangle Attention (around starting node and around ending node)
# Q, K, V: [Batch, N_res, N_res, Head, Dim]
# mask: [Batch, N_res, 1, 1, N_res]
# pair_bias: [Batch, 1, Head, N_res, N_res]
out = TritonEvoformer(Q, K, V, mask, pair_bias)
c. MSA Row-wise Attention
# Q, K, V: [Batch, N_seq, N_res, Head, Dim]
# mask: [Batch, N_seq, 1, 1, N_res]
# pair_bias: [Batch, 1, Head, N_res, N_res]
out = TritonEvoformer(Q, K, V, mask, pair_bias)
d. MSA Column-wise Attention
# Q, K, V: [Batch, N_res, N_seq, Head, Dim]
# mask: [Batch, N_seq, 1, 1, N_res]
out = TritonEvoformer(Q, K, V, mask)
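As a concrete example, here is a minimal end-to-end call for Triangle Attention (shapes follow example b above; the batch size, head count, head dimension, dtype, and CUDA placement are illustrative assumptions):
import torch
from megafold.model.FusedEvoAttention.evoattention import TritonEvoformer

# Illustrative sizes only: batch 1, N_res 128, 4 heads, head dim 32.
B, N_res, H, D = 1, 128, 4, 32
Q = torch.randn(B, N_res, N_res, H, D, device="cuda", dtype=torch.bfloat16)
K = torch.randn_like(Q)
V = torch.randn_like(Q)
mask = torch.ones(B, N_res, 1, 1, N_res, device="cuda", dtype=torch.bool)           # assumed boolean mask
pair_bias = torch.randn(B, 1, H, N_res, N_res, device="cuda", dtype=torch.bfloat16)

out = TritonEvoformer(Q, K, V, mask, pair_bias)  # output is expected to match Q's shape
print(out.shape)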
To achieve peak performance, the kernel's configuration (block sizes, num warps, etc.) should be tuned to your specific hardware and input shapes.
1. Import `TritonEvoformer` from `megafold.model.FusedEvoAttention.unfused_evoattention` (this version starts with untuned kernels).
2. Use it in your model's training or inference script.
3. Run your script with autotuning enabled:
   TRITON_PRINT_AUTOTUNING=1 python your_script.py
4. With autotuning enabled, Triton will explore multiple kernel configurations and print the best configuration for your input.
5. Let the script run for several training iterations and take note of the most frequently selected configuration; it is likely the best one for your target hardware and input shapes (sequence length).
6. Manually write the best configuration into each JIT kernel and comment out the `@triton.autotune` decorator of each kernel. An example of an autotuned kernel for NVIDIA H200 and sequence length 384 is provided in `megafold.model.FusedEvoAttention.evoattention`.
7. Use the modified kernel in your real workloads for best performance.
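To illustrate Step 6 on a generic toy kernel (hypothetical; not one of MegaFold's kernels): pinning a configuration means commenting out the `@triton.autotune` decorator and passing the winning meta-parameters explicitly at launch.
import torch
import triton
import triton.language as tl

# Toy kernel for illustration only.
# @triton.autotune(                                              # commented out after tuning
#     configs=[triton.Config({"BLOCK": 64}, num_warps=4),
#              triton.Config({"BLOCK": 128}, num_warps=8)],
#     key=["n_elements"],
# )
@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    m = offs < n_elements
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=m) * 2.0, mask=m)

x = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 128),)
# Pass the best configuration (here: BLOCK=128, num_warps=8) explicitly at launch.
scale_kernel[grid](x, out, x.numel(), BLOCK=128, num_warps=8)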
The folder `megafold/model/FusedLayernormLinear` includes the source code of the fused LayerNorm-Linear kernel.
from megafold.model.FusedLayernormLinear.fused_layernorm_linear import LayernormLinear
`FusedLayernormLinear` fuses sequential `LayerNorm` and `Linear` layers. You can replace any such occurrences with `LayernormLinear`.
# init
- layernorm = LayerNorm(dim_K)
- linear = Linear(dim_K, dim_N)
+ fused_layernorm_linear = LayernormLinear(dim_K, dim_N)
# model pass
- layernorm_linear_out = linear(layernorm(input))
+ layernorm_linear_out = fused_layernorm_linear(input)
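A minimal drop-in usage sketch (the feature sizes and CUDA placement are illustrative assumptions; for input shapes other than AF3's, see the NOTE below about autotuning):
import torch
from megafold.model.FusedLayernormLinear.fused_layernorm_linear import LayernormLinear

dim_K, dim_N = 128, 256                            # hypothetical feature sizes
fused_layernorm_linear = LayernormLinear(dim_K, dim_N).cuda()

x = torch.randn(1, 384, dim_K, device="cuda")
out = fused_layernorm_linear(x)                    # equivalent to Linear(LayerNorm(x))
print(out.shape)                                   # torch.Size([1, 384, 256])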
- NOTE: `LayernormLinear` relies on tuned configurations (block sizes, num warps, etc.), which we provide for AF3 inputs to the kernel in `helper.py`. If you intend to apply the kernel to other input shapes, you can perform the autotuning procedure (similar to FusedEvoAttention's Step 3) with `untuned_fused_layernorm_linear.py`.
The folder `megafold/model/FusedTransition` includes the source code of the FusedTransition kernel.
from megafold.model.FusedTransition.fused_transition import FusedTransition
`FusedTransition` fuses AF3's Transition layer (original implementation in `benchmarks/transition_speed.py`). You can replace the original Transition with `FusedTransition`.
# init
- transition = Transition(dim=dim, expansion_factor=expansion_factor)
+ transition = FusedTransition(dim=dim, expansion_factor=expansion_factor)
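A minimal usage sketch (the sizes and CUDA placement are illustrative assumptions):
import torch
from megafold.model.FusedTransition.fused_transition import FusedTransition

dim, expansion_factor = 128, 4                     # hypothetical sizes
transition = FusedTransition(dim=dim, expansion_factor=expansion_factor).cuda()

x = torch.randn(1, 384, dim, device="cuda")
out = transition(x)                                # the Transition block preserves the input shape
print(out.shape)                                   # torch.Size([1, 384, 128])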
- NOTE: `FusedTransition` relies on FusedLayernormLinear for its expanding projections. Make sure you read FusedLayernormLinear's usage guide above.
@misc{la2025megafoldsystemleveloptimizationsaccelerating,
title={MegaFold: System-Level Optimizations for Accelerating Protein Structure Prediction Models},
author={Hoa La and Ahan Gupta and Alex Morehead and Jianlin Cheng and Minjia Zhang},
year={2025},
eprint={2506.20686},
archivePrefix={arXiv},
primaryClass={q-bio.BM},
url={https://arxiv.org/abs/2506.20686},
}
- alphafold3-pytorch for the open-source code that MegaFold is built on top of.
- AMD for providing the AMD platforms.