switch to sleep level=2 and split wake-ups in GRPO and RLOO trainers #4296
Conversation
Thanks! I was not able to profile, because I think vLLM spawns another process, so it's not clear how to profile in this case; otherwise it looks good to me.
Thanks for the review! If helpful, here are the profiling references I used for vLLM’s multiprocessing setup and a small repro.
Profiling Command

```shell
export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_NVTX_SCOPES_FOR_PROFILING=1

nsys profile \
    --wait all \
    --stats true \
    --capture-range cudaProfilerApi \
    --capture-range-end stop \
    --trace-fork-before-exec=true \
    --cuda-graph-trace=node \
    -o profile \
    python vllm_sleep.py
```

vllm_sleep.py

```python
import torch
from vllm import LLM, SamplingParams
from vllm.utils import GiB_bytes
import nvtx
# Reference:
# https://docs.vllm.ai/en/latest/features/sleep_mode.html#sleep-mode
# https://github.com/vllm-project/vllm/blob/main/tests/basic_correctness/test_cumem.py
def main():
    torch.cuda.cudart().cudaProfilerStart()

    model = "Qwen/Qwen2.5-Coder-7B-Instruct"
    llm = LLM(model, enable_sleep_mode=True, tensor_parallel_size=2)

    free, total = torch.cuda.mem_get_info()
    used_bytes_baseline = total - free

    prompt = "How are you?"
    sampling_params = SamplingParams(temperature=0, max_tokens=10)
    with nvtx.annotate("generate", color="green"):
        output = llm.generate(prompt, sampling_params)

    with nvtx.annotate("sleep", color="red"):
        llm.sleep(level=2)  # or level=1
    free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
    used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
    assert used_bytes < 3 * GiB_bytes

    with nvtx.annotate("wake_up(weights)", color="brown"):
        llm.wake_up(tags=["weights"])
    with nvtx.annotate("reload_weights", color="purple"):
        llm.collective_rpc("reload_weights")
    free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
    used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
    assert used_bytes < 4 * GiB_bytes

    with nvtx.annotate("wake_up(kv_cache)", color="brown"):
        llm.wake_up(tags=["kv_cache"])

    with nvtx.annotate("generate", color="green"):
        output2 = llm.generate(prompt, sampling_params)
    assert output[0].outputs[0].text == output2[0].outputs[0].text

    torch.cuda.cudart().cudaProfilerStop()

if __name__ == "__main__":
    main()
```

Nsys Profiler Output

Sleep level=1

```
** CUDA GPU MemOps Summary (by Size) (cuda_gpu_mem_size_sum):

 Total (MB)  Count  Avg (MB)  Med (MB)  Min (MB)  Max (MB)  StdDev (MB)  Operation
 ----------  -----  --------  --------  --------  --------  -----------  ------------------------------
 45,978.971  4,278    10.748     0.003     0.000   545.260      37.442   [CUDA memcpy Host-to-Device]
 15,510.537    272    57.024    16.777     0.000   545.260      77.857   [CUDA memcpy Device-to-Host]
    840.577  1,178     0.714     0.004     0.004     7.340       2.170   [CUDA memcpy Device-to-Device]
     54.916    634     0.087     0.000     0.000    10.486       0.886   [CUDA memset]
```

Sleep level=2

```
** CUDA GPU MemOps Summary (by Size) (cuda_gpu_mem_size_sum):

 Total (MB)  Count  Avg (MB)  Med (MB)  Min (MB)  Max (MB)  StdDev (MB)  Operation
 ----------  -----  --------  --------  --------  --------  -----------  ------------------------------
 30,485.212  4,048     7.531     0.002     0.000   544.997      30.373   [CUDA memcpy Host-to-Device]
    840.577  1,178     0.714     0.004     0.004     7.340       2.170   [CUDA memcpy Device-to-Device]
     54.916    634     0.087     0.000     0.000    10.486       0.886   [CUDA memset]
     16.778     42     0.399     0.000     0.000     8.389       1.808   [CUDA memcpy Device-to-Host]
```

H2D and D2H traffic each drop by roughly 15.5 GB at sleep level=2 vs level=1, broadly in line with the bf16 weights of a roughly 7.6B-parameter model (about 7.6e9 × 2 B ≈ 15.2 GB) no longer being offloaded to and reloaded from CPU RAM.
What does this PR do?
vLLM's level=1 sleep mode for co-located deployments was integrated into TRL's GRPO trainer in PR #3968.
Since the GRPO and RLOO trainers push updated weights to vLLM after every optimization step anyway, level=2 sleep, which discards the weights instead of offloading them to CPU RAM, should be usable here; this PR switches both trainers to level=2 and uses split wake-ups.
See also the vLLM docs on split wake-ups for RLHF weight updates: https://docs.vllm.ai/en/v0.10.2/features/sleep_mode.html#rlhf-weight-updates
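For reference, the split wake-up pattern from the vLLM sleep-mode docs, as exercised in the repro above, looks roughly like this inside a co-located training step. This is a minimal sketch under stated assumptions, not the trainer code from this PR; `optimizer_step` and `sync_weights` are hypothetical callbacks standing in for the trainer's own update and weight-sync logic, and `llm` must have been constructed with `enable_sleep_mode=True`.

```python
from vllm import LLM, SamplingParams


def rollout_and_train_step(llm: LLM, prompts, sampling_params: SamplingParams,
                           optimizer_step, sync_weights):
    """One co-located step: generate, sleep at level=2, train, split wake-up.

    `optimizer_step` and `sync_weights` are hypothetical callbacks for the
    trainer's own loss/update and weight-sync logic.
    """
    # Rollouts with the current policy.
    outputs = llm.generate(prompts, sampling_params)

    # level=2 discards both the KV cache and the weights (level=1 would offload
    # the weights to CPU RAM instead), freeing GPU memory for the training step.
    llm.sleep(level=2)

    optimizer_step(outputs)  # hypothetical: compute rewards/loss and update the policy

    # Split wake-up: restore the weight buffers first, push the fresh weights,
    # and only then re-allocate the KV cache, so the cache is not held while
    # the new weights are being copied in.
    llm.wake_up(tags=["weights"])
    sync_weights(llm)  # hypothetical: e.g. via llm.collective_rpc(...)
    llm.wake_up(tags=["kv_cache"])

    return outputs
```

Compared with a single untagged `llm.wake_up()`, the split form avoids re-allocating the KV cache before the weight sync has completed, which keeps peak memory lower during the update.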
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.