diff --git a/.github/workflows/checkpoint_converter.yml b/.github/workflows/checkpoint_converter.yml
index 52bd004602b..abd709f5b9d 100644
--- a/.github/workflows/checkpoint_converter.yml
+++ b/.github/workflows/checkpoint_converter.yml
@@ -63,11 +63,11 @@ jobs:
       - name: Running Huggingface to Megatron dist_ckpt converter (Qwen/Qwen2.5-0.5B)
         run: |
           ray stop --force
-          python scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/Qwen/Qwen2.5-0.5B --output_path checkpoints/Qwen/Qwen2.5-0.5B
+          python scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/Qwen/Qwen2.5-0.5B --output_path checkpoints/Qwen/Qwen2.5-0.5B --test
       - name: Running Huggingface to Megatron dist_ckpt converter (deepseek-ai/deepseek-coder-1.3b-instruct)
         run: |
           ray stop --force
-          python scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/deepseek-ai/deepseek-coder-1.3b-instruct --output_path checkpoints/deepseek-ai/deepseek-coder-1.3b-instruct
+          python scripts/converter_hf_to_mcore.py --hf_model_path=${HOME}/models/deepseek-ai/deepseek-coder-1.3b-instruct --output_path checkpoints/deepseek-ai/deepseek-coder-1.3b-instruct --test
       - name: Clean up
         run: |
           rm -rf checkpoints
diff --git a/scripts/converter_hf_to_mcore.py b/scripts/converter_hf_to_mcore.py
index d363de26f3a..fa134856b50 100644
--- a/scripts/converter_hf_to_mcore.py
+++ b/scripts/converter_hf_to_mcore.py
@@ -23,6 +23,7 @@
 from megatron.core.dist_checkpointing.serialization import StrictHandling
 from megatron.core.models.gpt.gpt_model import ModelType
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
+from megatron.core.dist_checkpointing.mapping import ShardedTensor
 from transformers import AutoConfig
 
 from verl.models.mcore import hf_to_mcore_config
@@ -74,8 +75,12 @@ def test_conversion(megatron_model_provider, tfconfig, output_path, model):
             continue
         dut_data = dut_state_dict[name].data
         if name in ref_state_dict:
-            ref_data = ref_state_dict[name].data
-            assert dut_data.shape == ref_state_dict.shape, f"{name=} {dut_data.shape=} {ref_data.shape=}"
+            ref_data = ref_state_dict[name]
+            if isinstance(ref_data, ShardedTensor):
+                ref_data = ref_data.data.view(ref_data.local_shape)
+            else:
+                ref_data = ref_data.data
+            assert dut_data.shape == ref_data.shape, f"{name=} {dut_data.shape=} {ref_data.shape=}"
             assert (dut_data == ref_data).all(), f"{name} is not equal"
             print(f"{name} is equal")
         else:
@@ -84,7 +89,11 @@ def test_conversion(megatron_model_provider, tfconfig, output_path, model):
         if ref_state_dict[name] is None:
             print(f"[Warning] {name} is none in ref_state_dict")
             continue
-        ref_data = ref_state_dict[name].data
+        ref_data = ref_state_dict[name]
+        if isinstance(ref_data, ShardedTensor):
+            ref_data = ref_data.data.view(ref_data.local_shape)
+        else:
+            ref_data = ref_data.data
         if name in dut_state_dict:
             dut_data = dut_state_dict[name].data
             assert dut_data.shape == ref_data.shape, f"{name=} {dut_data.shape=} {ref_data.shape=}"
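Note: both hunks in test_conversion apply the same normalization, so the pattern can be read in isolation. Below is a minimal sketch of that step factored into a standalone helper (the helper name to_dense is hypothetical, not part of the patch), assuming only what the diff itself relies on: a Megatron ShardedTensor exposes its payload via .data and its per-rank shape via .local_shape.

from megatron.core.dist_checkpointing.mapping import ShardedTensor


def to_dense(value):
    """Return a plain tensor for comparison, for either state-dict entry type."""
    if isinstance(value, ShardedTensor):
        # The ShardedTensor payload may be stored flattened; restore the
        # shape of the local shard before comparing against the DUT tensor.
        return value.data.view(value.local_shape)
    # Plain tensors / parameters: compare their underlying storage directly.
    return value.data


# Usage, mirroring the comparison loops in test_conversion:
#     ref_data = to_dense(ref_state_dict[name])
#     assert dut_data.shape == ref_data.shape, f"{name=} {dut_data.shape=} {ref_data.shape=}"
#     assert (dut_data == ref_data).all(), f"{name} is not equal"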