Description
🐛 Describe the bug
I'm attempting to use the FSDP2 API to shard a model, extract its state dict (for later use), and then completely remove the model from memory. Extracting the state dict appears to leave references to the underlying model alive, which causes a memory leak. Below I reuse an existing FSDP2 test, test_fully_shard_del_memory, to demonstrate the issue.
When I add a call to get_model_state_dict to extract the state dict (the block marked DIFF STARTS HERE below), the model continues to occupy GPU memory even after both the model and the state dict are explicitly deleted. This differs from the original test, without that block, where the memory is properly released.
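One way to confirm from the Python side that something still holds the model after del model is to keep a weak reference to it and, if it survives gc.collect(), ask gc.get_referrers which containers point at it. The sketch below is a generic, stdlib-only helper and not part of PyTorch; find_survivor_referrers and ToyModel are hypothetical names, and the dict in the demo only stands in for whatever object actually retains the FSDP module. In the real test, the weak reference would be taken with weakref.ref(model) just before del model.

```python
import gc
import weakref


def find_survivor_referrers(ref: weakref.ref) -> list[str]:
    """Return the type names of objects that still refer to ref's target.

    An empty list means the target was freed. Hypothetical helper, stdlib only.
    """
    gc.collect()  # break reference cycles first, as the test does
    target = ref()
    if target is None:
        return []
    # gc.get_referrers only reports gc-tracked containers (dicts, lists,
    # instances, ...), so untracked holders will not show up here.
    referrers = gc.get_referrers(target)
    names = sorted({type(obj).__name__ for obj in referrers})
    del target, referrers  # drop our own temporary references
    return names


class ToyModel:
    """Hypothetical stand-in for an FSDP-wrapped nn.Module."""


if __name__ == "__main__":
    hidden_cache = {}  # hypothetical: simulates a lingering internal reference
    model = ToyModel()
    model_ref = weakref.ref(model)
    hidden_cache["model"] = model
    del model
    print(find_survivor_referrers(model_ref))  # a 'dict' entry shows the holder
    hidden_cache.clear()
    print(find_survivor_referrers(model_ref))  # [] once nothing holds it
```

The helper reports only type names, so it never keeps the target alive itself; from there, the actual referrer objects can be inspected interactively.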
This matters especially when we want to iteratively load a model, run some computation, offload it to CPU, and reload it when needed. If this procedure is repeated, GPU memory usage keeps growing with every cycle until it runs out.
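Because the leak compounds with every cycle, a useful check is to run the cycle several times and verify that memory settles back to its post-warm-up level. The sketch below is a hypothetical, stdlib-only helper (grows_across_iterations is an invented name) that measures Python heap allocations with tracemalloc; for the GPU case in this report, the measurement would instead come from torch.cuda.memory_allocated(), taken after gc.collect() and torch.cuda.empty_cache() as in the test. The step functions in the test are toy stand-ins, not the FSDP workflow itself.

```python
import gc
import tracemalloc
from typing import Callable


def grows_across_iterations(step: Callable[[], None], iterations: int = 5,
                            tolerance_bytes: int = 10_000) -> bool:
    """Return True if memory traced by tracemalloc keeps growing across calls.

    Each call to step() plays the role of one load -> compute -> offload cycle;
    after every cycle, memory should settle back to the level of the first one.
    """
    was_tracing = tracemalloc.is_tracing()
    if not was_tracing:
        tracemalloc.start()
    try:
        step()  # warm-up: one-time caches are allowed on the first cycle
        gc.collect()
        baseline, _ = tracemalloc.get_traced_memory()
        for _ in range(iterations):
            step()
            gc.collect()
        current, _ = tracemalloc.get_traced_memory()
    finally:
        # Leave tracing as we found it if the caller had already started it.
        if not was_tracing:
            tracemalloc.stop()
    return current - baseline > tolerance_bytes
```

Comparing against the level after a warm-up cycle, rather than against the very first measurement, avoids flagging one-time allocations as a leak.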
Below is the code to reproduce the behavior. As written, the test fails; it passes if you comment out the block between DIFF STARTS HERE and DIFF ENDS HERE.
import gc
import os
from datetime import timedelta

import torch
from torch.distributed import init_process_group
from torch.distributed.checkpoint.state_dict import get_model_state_dict, StateDictOptions
from torch.distributed.fsdp import fully_shard
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
    TransformerBlock,
)


class TestFullyShardMemory(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(2, torch.cuda.device_count())

    def _get_peak_active_memory_mb(self) -> int:
        mem_stats = torch.cuda.memory_stats()
        return round(mem_stats["active_bytes.all.peak"] / 1e6)

    def _get_curr_active_memory_mb(self) -> int:
        mem_stats = torch.cuda.memory_stats()
        return round(mem_stats["active_bytes.all.current"] / 1e6)

    def test_fully_shard_del_memory(self):
        base_mem_mb = self._get_peak_active_memory_mb()
        vocab_size = 32
        model_args = ModelArgs(
            vocab_size=vocab_size, n_layers=3, dim=768, n_heads=12, weight_tying=False
        )
        model = Transformer(model_args)
        # Initializing the model on CPU should not change the GPU memory usage
        post_model_init_mem_mb = self._get_peak_active_memory_mb()
        self.assertEqual(base_mem_mb, post_model_init_mem_mb)

        # Shard each transformer block, then the root module
        for module in model.modules():
            if isinstance(module, TransformerBlock):
                fully_shard(module)
        fully_shard(model)
        unsharded_numel = sum(p.numel() for p in model.parameters())
        sharded_numel = unsharded_numel // self.world_size
        buffer_mb = 4
        mem_mb = self._get_curr_active_memory_mb()
        expected_mb = sharded_numel * 4 / 1e6 + buffer_mb
        self.assertLessEqual(mem_mb - base_mem_mb, expected_mb)

        ### DIFF STARTS HERE ###
        sdo = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True)
        state_dict = get_model_state_dict(model, options=sdo)
        del state_dict
        ### DIFF ENDS HERE ###

        # Deleting the model should free all of the FSDP-managed GPU memory
        del model
        # Manually call garbage collection since there are ref cycles in FSDP
        gc.collect()
        torch.cuda.empty_cache()
        mem_mb = self._get_curr_active_memory_mb()
        print(f"Mem MB: {mem_mb}")
        print(f"Base Mem MB: {base_mem_mb}")
        self.assertEqual(mem_mb, base_mem_mb)


if __name__ == "__main__":
    init_process_group(backend="nccl", timeout=timedelta(hours=24))
    dst_rank = int(os.environ['RANK'])
    dst_local_rank = int(os.environ['LOCAL_RANK'])
    dst_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{dst_local_rank}'
    run_tests()
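The first memory assertion in the test encodes a simple budget: after fully_shard, each rank should hold only its 1/world_size share of the parameters at 4 bytes per fp32 element, plus 4 MB of slack (buffer_mb). The helper below reproduces that arithmetic so the bound is easy to check by hand; expected_sharded_mb is a hypothetical name, and the parameter count in the example is made up rather than the actual size of the test's Transformer.

```python
def expected_sharded_mb(unsharded_numel: int, world_size: int,
                        bytes_per_elem: int = 4, buffer_mb: float = 4.0) -> float:
    """Per-rank memory bound (in MB) that the test allows after fully_shard.

    Each rank keeps roughly 1/world_size of the parameters; fp32 parameters
    take 4 bytes per element, and buffer_mb is slack for other allocations.
    """
    sharded_numel = unsharded_numel // world_size
    return sharded_numel * bytes_per_elem / 1e6 + buffer_mb


# Hypothetical size, not the Transformer from the test:
print(expected_sharded_mb(10_000_000, world_size=2))  # 24.0
```

The final check in the test is stricter: after del model and gc.collect(), current active memory must equal the baseline exactly, which is the check that, per the description above, fails once the DIFF block is enabled.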
Versions
PyTorch version: 2.6.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-163-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB
Nvidia driver version: 535.54.03
cuDNN version: Probably one of the following:
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn.so.8.9.2
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.9.2
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.9.2
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.9.2
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.9.2
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.9.2
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.9.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 128
On-line CPU(s) list: 0-127
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
NUMA node(s): 2
Vendor ID: AuthenticAMD
CPU family: 25
Model: 1
Model name: AMD EPYC 7543 32-Core Processor
Stepping: 1
Frequency boost: enabled
CPU MHz: 1499.953
CPU max MHz: 2800.0000
CPU min MHz: 1500.0000
BogoMIPS: 5600.18
Virtualization: AMD-V
L1d cache: 2 MiB
L1i cache: 2 MiB
L2 cache: 32 MiB
L3 cache: 512 MiB
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.6.0
[pip3] torchaudio==2.5.1
[pip3] triton==3.2.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.5.147 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.6.2 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.21.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.4.127 pypi_0 pypi
[conda] torch 2.6.0 pypi_0 pypi
[conda] torchaudio 2.5.1 pypi_0 pypi
[conda] triton 3.2.0 pypi_0 pypi
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @zhaojuanmao @mrshenli @rohan-varma @chauhang @mori360 @kwen2501 @c-p-i-o