Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit d948baa

Browse filesBrowse files
Add support for pytorch 2.6+ (#1956)
Fixes # . ### Description Add support for pytorch 2.6+ by fix `torch.load` deprecated error ### Checks <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [ ] Avoid including large-size files in the PR. - [ ] Clean up long text outputs from code cells in the notebook. - [ ] For security purposes, please check the contents and remove any sensitive info such as user names and private key. - [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder - [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>` --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0f21835 commit d948baa
Copy full SHA for d948baa

File tree

100 files changed

+702
-782
lines changed
Filter options

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Dismiss banner

100 files changed

+702
-782
lines changed

‎2d_classification/mednist_tutorial.ipynb

Copy file name to clipboardExpand all lines: 2d_classification/mednist_tutorial.ipynb
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@
575575
"metadata": {},
576576
"outputs": [],
577577
"source": [
578-
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\")))\n",
578+
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\"), weights_only=True))\n",
579579
"model.eval()\n",
580580
"y_true = []\n",
581581
"y_pred = []\n",

‎2d_segmentation/torch/unet_evaluation_array.py

Copy file name to clipboardExpand all lines: 2d_segmentation/torch/unet_evaluation_array.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def main(tempdir):
5858
num_res_units=2,
5959
).to(device)
6060

61-
model.load_state_dict(torch.load("best_metric_model_segmentation2d_array.pth"))
61+
model.load_state_dict(torch.load("best_metric_model_segmentation2d_array.pth", weights_only=True))
6262
model.eval()
6363
with torch.no_grad():
6464
for val_data in val_loader:

‎2d_segmentation/torch/unet_evaluation_dict.py

Copy file name to clipboardExpand all lines: 2d_segmentation/torch/unet_evaluation_dict.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def main(tempdir):
7272
num_res_units=2,
7373
).to(device)
7474

75-
model.load_state_dict(torch.load("best_metric_model_segmentation2d_dict.pth"))
75+
model.load_state_dict(torch.load("best_metric_model_segmentation2d_dict.pth", weights_only=True))
7676

7777
model.eval()
7878
with torch.no_grad():

‎3d_classification/torch/densenet_evaluation_array.py

Copy file name to clipboardExpand all lines: 3d_classification/torch/densenet_evaluation_array.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def main():
5757
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5858
model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
5959

60-
model.load_state_dict(torch.load("best_metric_model_classification3d_array.pth"))
60+
model.load_state_dict(torch.load("best_metric_model_classification3d_array.pth", weights_only=True))
6161
model.eval()
6262
with torch.no_grad():
6363
num_correct = 0.0

‎3d_classification/torch/densenet_evaluation_dict.py

Copy file name to clipboardExpand all lines: 3d_classification/torch/densenet_evaluation_dict.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def main():
6363
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6464
model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
6565

66-
model.load_state_dict(torch.load("best_metric_model_classification3d_dict.pth"))
66+
model.load_state_dict(torch.load("best_metric_model_classification3d_dict.pth", weights_only=True))
6767
model.eval()
6868
with torch.no_grad():
6969
num_correct = 0.0

‎3d_registration/learn2reg_nlst_paired_lung_ct.ipynb

Copy file name to clipboardExpand all lines: 3d_registration/learn2reg_nlst_paired_lung_ct.ipynb
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,7 @@
872872
"source": [
873873
"# Automatic mixed precision (AMP) for faster training\n",
874874
"amp_enabled = True\n",
875-
"scaler = torch.cuda.amp.GradScaler()\n",
875+
"scaler = torch.GradScaler(\"cuda\")\n",
876876
"\n",
877877
"# Tensorboard\n",
878878
"if do_save:\n",
@@ -1127,7 +1127,7 @@
11271127
" )\n",
11281128
" # load model weights\n",
11291129
" filename_best_model = glob.glob(os.path.join(dir_load, \"segresnet_kpt_loss_best_tre*\"))[0]\n",
1130-
" model.load_state_dict(torch.load(filename_best_model))\n",
1130+
" model.load_state_dict(torch.load(filename_best_model, weights_only=True))\n",
11311131
" # to GPU\n",
11321132
" model.to(device)\n",
11331133
"\n",
@@ -1139,7 +1139,7 @@
11391139
"# Forward pass\n",
11401140
"model.eval()\n",
11411141
"with torch.no_grad():\n",
1142-
" with torch.cuda.amp.autocast(enabled=amp_enabled):\n",
1142+
" with torch.autocast(\"cuda\", enabled=amp_enabled):\n",
11431143
" ddf_image, ddf_keypoints, pred_image, pred_label = forward(\n",
11441144
" check_data[\"fixed_image\"].to(device),\n",
11451145
" check_data[\"moving_image\"].to(device),\n",

‎3d_registration/learn2reg_oasis_unpaired_brain_mr.ipynb

Copy file name to clipboardExpand all lines: 3d_registration/learn2reg_oasis_unpaired_brain_mr.ipynb
+5-5Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@
610610
"source": [
611611
"# Automatic mixed precision (AMP) for faster training\n",
612612
"amp_enabled = True\n",
613-
"scaler = torch.cuda.amp.GradScaler()\n",
613+
"scaler = torch.GradScaler(\"cuda\")\n",
614614
"\n",
615615
"# Tensorboard\n",
616616
"if do_save:\n",
@@ -646,7 +646,7 @@
646646
"\n",
647647
" # Forward pass and loss\n",
648648
" optimizer.zero_grad()\n",
649-
" with torch.cuda.amp.autocast(enabled=amp_enabled):\n",
649+
" with torch.autocast(\"cuda\", enabled=amp_enabled):\n",
650650
" ddf_image, pred_image, pred_label_one_hot = forward(\n",
651651
" fixed_image, moving_image, moving_label, model, warp_layer, num_classes=4\n",
652652
" )\n",
@@ -694,7 +694,7 @@
694694
" # moving_label_35 = batch_data[\"moving_label_35\"].to(device)\n",
695695
" n_steps += 1\n",
696696
" # Infer\n",
697-
" with torch.cuda.amp.autocast(enabled=amp_enabled):\n",
697+
" with torch.autocast(\"cuda\", enabled=amp_enabled):\n",
698698
" ddf_image, pred_image, pred_label_one_hot = forward(\n",
699699
" fixed_image, moving_image, moving_label_4, model, warp_layer, num_classes=4\n",
700700
" )\n",
@@ -840,7 +840,7 @@
840840
" model = VoxelMorph()\n",
841841
" # load model weights\n",
842842
" filename_best_model = glob.glob(os.path.join(dir_load, \"voxelmorph_loss_best_dice_*\"))[0]\n",
843-
" model.load_state_dict(torch.load(filename_best_model))\n",
843+
" model.load_state_dict(torch.load(filename_best_model, weights_only=True))\n",
844844
" # to GPU\n",
845845
" model.to(device)\n",
846846
"\n",
@@ -860,7 +860,7 @@
860860
"# Forward pass\n",
861861
"model.eval()\n",
862862
"with torch.no_grad():\n",
863-
" with torch.cuda.amp.autocast(enabled=amp_enabled):\n",
863+
" with torch.autocast(\"cuda\", enabled=amp_enabled):\n",
864864
" ddf_image, pred_image, pred_label_one_hot = forward(\n",
865865
" fixed_image, moving_image, moving_label_35, model, warp_layer, num_classes=35\n",
866866
" )"

‎3d_registration/paired_lung_ct.ipynb

Copy file name to clipboardExpand all lines: 3d_registration/paired_lung_ct.ipynb
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,7 @@
860860
"resource = \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/pair_lung_ct.pth\"\n",
861861
"dst = f\"{root_dir}/pretrained_weight.pth\"\n",
862862
"download_url(resource, dst)\n",
863-
"model.load_state_dict(torch.load(dst))"
863+
"model.load_state_dict(torch.load(dst, weights_only=True))"
864864
]
865865
},
866866
{

‎3d_segmentation/brats_segmentation_3d.ipynb

Copy file name to clipboardExpand all lines: 3d_segmentation/brats_segmentation_3d.ipynb
+7-7Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -473,14 +473,14 @@
473473
" )\n",
474474
"\n",
475475
" if VAL_AMP:\n",
476-
" with torch.cuda.amp.autocast():\n",
476+
" with torch.autocast(\"cuda\"):\n",
477477
" return _compute(input)\n",
478478
" else:\n",
479479
" return _compute(input)\n",
480480
"\n",
481481
"\n",
482482
"# use amp to accelerate training\n",
483-
"scaler = torch.cuda.amp.GradScaler()\n",
483+
"scaler = torch.GradScaler(\"cuda\")\n",
484484
"# enable cuDNN benchmark\n",
485485
"torch.backends.cudnn.benchmark = True"
486486
]
@@ -526,7 +526,7 @@
526526
" batch_data[\"label\"].to(device),\n",
527527
" )\n",
528528
" optimizer.zero_grad()\n",
529-
" with torch.cuda.amp.autocast():\n",
529+
" with torch.autocast(\"cuda\"):\n",
530530
" outputs = model(inputs)\n",
531531
" loss = loss_function(outputs, labels)\n",
532532
" scaler.scale(loss).backward()\n",
@@ -733,7 +733,7 @@
733733
}
734734
],
735735
"source": [
736-
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\")))\n",
736+
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\"), weights_only=True))\n",
737737
"model.eval()\n",
738738
"with torch.no_grad():\n",
739739
" # select one image to evaluate and visualize the model output\n",
@@ -835,7 +835,7 @@
835835
}
836836
],
837837
"source": [
838-
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\")))\n",
838+
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\"), weights_only=True))\n",
839839
"model.eval()\n",
840840
"\n",
841841
"with torch.no_grad():\n",
@@ -924,7 +924,7 @@
924924
" )\n",
925925
"\n",
926926
" if VAL_AMP:\n",
927-
" with torch.cuda.amp.autocast():\n",
927+
" with torch.autocast(\"cuda\"):\n",
928928
" return _compute(input)\n",
929929
" else:\n",
930930
" return _compute(input)"
@@ -977,7 +977,7 @@
977977
"source": [
978978
"onnx_model_path = os.path.join(root_dir, \"best_metric_model.onnx\")\n",
979979
"ort_session = onnxruntime.InferenceSession(onnx_model_path)\n",
980-
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\")))\n",
980+
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\"), weights_only=True))\n",
981981
"model.eval()\n",
982982
"\n",
983983
"with torch.no_grad():\n",

‎3d_segmentation/challenge_baseline/run_net.py

Copy file name to clipboardExpand all lines: 3d_segmentation/challenge_baseline/run_net.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def infer(data_folder=".", model_folder="runs", prediction_folder="output"):
219219

220220
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
221221
net = get_net().to(device)
222-
net.load_state_dict(torch.load(ckpt, map_location=device))
222+
net.load_state_dict(torch.load(ckpt, map_location=device, weights_only=True))
223223
net.eval()
224224

225225
image_folder = os.path.abspath(data_folder)

‎3d_segmentation/spleen_segmentation_3d.ipynb

Copy file name to clipboardExpand all lines: 3d_segmentation/spleen_segmentation_3d.ipynb
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,7 @@
640640
}
641641
],
642642
"source": [
643-
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\")))\n",
643+
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\"), weights_only=True))\n",
644644
"model.eval()\n",
645645
"with torch.no_grad():\n",
646646
" for i, val_data in enumerate(val_loader):\n",
@@ -730,7 +730,7 @@
730730
}
731731
],
732732
"source": [
733-
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\")))\n",
733+
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\"), weights_only=True))\n",
734734
"model.eval()\n",
735735
"\n",
736736
"with torch.no_grad():\n",
@@ -827,7 +827,7 @@
827827
"metadata": {},
828828
"outputs": [],
829829
"source": [
830-
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\")))\n",
830+
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\"), weights_only=True))\n",
831831
"model.eval()\n",
832832
"\n",
833833
"with torch.no_grad():\n",

‎3d_segmentation/spleen_segmentation_3d_visualization_basic.ipynb

Copy file name to clipboardExpand all lines: 3d_segmentation/spleen_segmentation_3d_visualization_basic.ipynb
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -823,7 +823,7 @@
823823
}
824824
],
825825
"source": [
826-
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\")))\n",
826+
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\"), weights_only=True))\n",
827827
"model.eval()\n",
828828
"\n",
829829
"with torch.no_grad():\n",

‎3d_segmentation/swin_unetr_brats21_segmentation_3d.ipynb

Copy file name to clipboardExpand all lines: 3d_segmentation/swin_unetr_brats21_segmentation_3d.ipynb
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -885,7 +885,7 @@
885885
"metadata": {},
886886
"outputs": [],
887887
"source": [
888-
"model.load_state_dict(torch.load(os.path.join(root_dir, \"model.pt\"))[\"state_dict\"])\n",
888+
"model.load_state_dict(torch.load(os.path.join(root_dir, \"model.pt\"), weights_only=True)[\"state_dict\"])\n",
889889
"model.to(device)\n",
890890
"model.eval()\n",
891891
"\n",

‎3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb

Copy file name to clipboardExpand all lines: 3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb
+6-6Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@
472472
"metadata": {},
473473
"outputs": [],
474474
"source": [
475-
"weight = torch.load(\"./model_swinvit.pt\")\n",
475+
"weight = torch.load(\"./model_swinvit.pt\", weights_only=True)\n",
476476
"model.load_from(weights=weight)\n",
477477
"print(\"Using pretrained self-supervied Swin UNETR backbone weights !\")"
478478
]
@@ -493,7 +493,7 @@
493493
"torch.backends.cudnn.benchmark = True\n",
494494
"loss_function = DiceCELoss(to_onehot_y=True, softmax=True)\n",
495495
"optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)\n",
496-
"scaler = torch.cuda.amp.GradScaler()"
496+
"scaler = torch.GradScaler(\"cuda\")"
497497
]
498498
},
499499
{
@@ -516,7 +516,7 @@
516516
" with torch.no_grad():\n",
517517
" for batch in epoch_iterator_val:\n",
518518
" val_inputs, val_labels = (batch[\"image\"].cuda(), batch[\"label\"].cuda())\n",
519-
" with torch.cuda.amp.autocast():\n",
519+
" with torch.autocast(\"cuda\"):\n",
520520
" val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)\n",
521521
" val_labels_list = decollate_batch(val_labels)\n",
522522
" val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list]\n",
@@ -537,7 +537,7 @@
537537
" for step, batch in enumerate(epoch_iterator):\n",
538538
" step += 1\n",
539539
" x, y = (batch[\"image\"].cuda(), batch[\"label\"].cuda())\n",
540-
" with torch.cuda.amp.autocast():\n",
540+
" with torch.autocast(\"cuda\"):\n",
541541
" logit_map = model(x)\n",
542542
" loss = loss_function(logit_map, y)\n",
543543
" scaler.scale(loss).backward()\n",
@@ -590,7 +590,7 @@
590590
"metric_values = []\n",
591591
"while global_step < max_iterations:\n",
592592
" global_step, dice_val_best, global_step_best = train(global_step, train_loader, dice_val_best, global_step_best)\n",
593-
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\")))"
593+
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\"), weights_only=True))"
594594
]
595595
},
596596
{
@@ -679,7 +679,7 @@
679679
],
680680
"source": [
681681
"case_num = 4\n",
682-
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\")))\n",
682+
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\"), weights_only=True))\n",
683683
"model.eval()\n",
684684
"with torch.no_grad():\n",
685685
" img_name = os.path.split(val_ds[case_num][\"image\"].meta[\"filename_or_obj\"])[1]\n",

‎3d_segmentation/torch/unet_evaluation_array.py

Copy file name to clipboardExpand all lines: 3d_segmentation/torch/unet_evaluation_array.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def main(tempdir):
6363
num_res_units=2,
6464
).to(device)
6565

66-
model.load_state_dict(torch.load("best_metric_model_segmentation3d_array.pth"))
66+
model.load_state_dict(torch.load("best_metric_model_segmentation3d_array.pth", weights_only=True))
6767
model.eval()
6868
with torch.no_grad():
6969
for val_data in val_loader:

‎3d_segmentation/torch/unet_evaluation_dict.py

Copy file name to clipboardExpand all lines: 3d_segmentation/torch/unet_evaluation_dict.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def main(tempdir):
8181
num_res_units=2,
8282
).to(devices[0])
8383

84-
model.load_state_dict(torch.load("best_metric_model_segmentation3d_dict.pth"))
84+
model.load_state_dict(torch.load("best_metric_model_segmentation3d_dict.pth", weights_only=True))
8585

8686
# if we have multiple GPUs, set data parallel to execute sliding window inference
8787
if len(devices) > 1:

‎3d_segmentation/torch/unet_inference_dict.py

Copy file name to clipboardExpand all lines: 3d_segmentation/torch/unet_inference_dict.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def main(tempdir):
9191
strides=(2, 2, 2, 2),
9292
num_res_units=2,
9393
).to(device)
94-
net.load_state_dict(torch.load("best_metric_model_segmentation3d_dict.pth"))
94+
net.load_state_dict(torch.load("best_metric_model_segmentation3d_dict.pth", weights_only=True))
9595

9696
net.eval()
9797
with torch.no_grad():

‎3d_segmentation/unetr_btcv_segmentation_3d.ipynb

Copy file name to clipboardExpand all lines: 3d_segmentation/unetr_btcv_segmentation_3d.ipynb
+2-2Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@
680680
"metric_values = []\n",
681681
"while global_step < max_iterations:\n",
682682
" global_step, dice_val_best, global_step_best = train(global_step, train_loader, dice_val_best, global_step_best)\n",
683-
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\")))"
683+
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\"), weights_only=True))"
684684
]
685685
},
686686
{
@@ -769,7 +769,7 @@
769769
],
770770
"source": [
771771
"case_num = 4\n",
772-
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\")))\n",
772+
"model.load_state_dict(torch.load(os.path.join(root_dir, \"best_metric_model.pth\"), weights_only=True))\n",
773773
"model.eval()\n",
774774
"with torch.no_grad():\n",
775775
" img_name = os.path.split(val_ds[case_num][\"image\"].meta[\"filename_or_obj\"])[1]\n",

‎acceleration/TensorRT_inference_acceleration.ipynb

Copy file name to clipboardExpand all lines: acceleration/TensorRT_inference_acceleration.ipynb
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@
284284
"device = workflow.device\n",
285285
"spatial_shape = (1, 3, 736, 480)\n",
286286
"model = workflow.network_def\n",
287-
"model.load_state_dict(torch.load(model_weight))\n",
287+
"model.load_state_dict(torch.load(model_weight, weights_only=True))\n",
288288
"model.to(device)\n",
289289
"model.eval()\n",
290290
"\n",

‎acceleration/automatic_mixed_precision.ipynb

Copy file name to clipboardExpand all lines: acceleration/automatic_mixed_precision.ipynb
+3-3Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@
289289
" ).to(device)\n",
290290
" loss_function = DiceLoss(to_onehot_y=True, softmax=True)\n",
291291
" optimizer = torch.optim.Adam(model.parameters(), 1e-4)\n",
292-
" scaler = torch.cuda.amp.GradScaler() if amp else None\n",
292+
" scaler = torch.GradScaler(\"cuda\") if amp else None\n",
293293
"\n",
294294
" post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])\n",
295295
" post_label = Compose([AsDiscrete(to_onehot=2)])\n",
@@ -321,7 +321,7 @@
321321
" )\n",
322322
" optimizer.zero_grad()\n",
323323
" if amp and scaler is not None:\n",
324-
" with torch.cuda.amp.autocast():\n",
324+
" with torch.autocast(\"cuda\"):\n",
325325
" outputs = model(inputs)\n",
326326
" loss = loss_function(outputs, labels)\n",
327327
" scaler.scale(loss).backward()\n",
@@ -353,7 +353,7 @@
353353
" roi_size = (160, 160, 128)\n",
354354
" sw_batch_size = 4\n",
355355
" if amp:\n",
356-
" with torch.cuda.amp.autocast():\n",
356+
" with torch.autocast(\"cuda\"):\n",
357357
" val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)\n",
358358
" else:\n",
359359
" val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)\n",

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.