Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion 7 modelopt/torch/export/plugins/mcore_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,12 @@ def save_safetensors_by_layer_index(
meta_filename = filename + ".json"
ckpt_filename = filename + ".safetensors"

# Write safetensors first, then build the per-layer meta JSON from the same dict.
# Order matters: any late mutations to layer_state_dict (e.g. MTP tensors added after
# the dict was first constructed) must be captured by both files. Writing safetensors
# first ensures the JSON is always consistent with what is physically on disk.
save_file(layer_state_dict, save_directory + "/" + ckpt_filename, metadata={"format": "pt"})

weight_map = {}
layer_total_size = 0
for key, val in layer_state_dict.items():
Expand All @@ -318,7 +324,6 @@ def save_safetensors_by_layer_index(
f,
indent=4,
)
save_file(layer_state_dict, save_directory + "/" + ckpt_filename, metadata={"format": "pt"})

# [TODO]: this global barrier needs to be replaced with something safer
torch.distributed.barrier()
Expand Down
Morty Proxy This is a proxified and sanitized view of the page, visit original site.