1818
1919from monai .utils import optional_import
2020from monai .utils .enums import StrEnum
21- from huggingface_hub import hf_hub_download
2221
2322LPIPS , _ = optional_import ("lpips" , name = "LPIPS" )
2423torchvision , _ = optional_import ("torchvision" )
2524
2625
2726class PercetualNetworkType (StrEnum ):
28- """Types of neural networks that are supported by perceptua loss.
29- """
27+ """Types of neural networks that are supported by perceptua loss."""
3028
3129 alex = "alex"
3230 vgg = "vgg"
@@ -116,8 +114,7 @@ def __init__(
116114 # If spatial_dims is 3, only MedicalNet supports 3D models, otherwise, spatial_dims=2 and fake_3D must be used.
117115 if spatial_dims == 3 and is_fake_3d is False :
118116 self .perceptual_function = MedicalNetPerceptualSimilarity (
119- net = network_type , verbose = False , channel_wise = channel_wise ,
120- cache_dir = cache_dir
117+ net = network_type , verbose = False , channel_wise = channel_wise , cache_dir = cache_dir
121118 )
122119 elif "radimagenet_" in network_type :
123120 self .perceptual_function = RadImageNetPerceptualSimilarity (net = network_type , verbose = False )
@@ -214,12 +211,17 @@ class MedicalNetPerceptualSimilarity(nn.Module):
214211 """
215212
216213 def __init__ (
217- self , net : str = "medicalnet_resnet_10_23datasets" , verbose : bool = False , channel_wise : bool = False ,
214+ self ,
215+ net : str = "medicalnet_resnet_10_23datasets" ,
216+ verbose : bool = False ,
217+ channel_wise : bool = False ,
218218 cache_dir : str | None = None ,
219219 ) -> None :
220220 super ().__init__ ()
221221 torch .hub ._validate_not_a_forked_repo = lambda a , b , c : True
222- self .model = torch .hub .load ("Project-MONAI/perceptual-models:main" , model = net , verbose = verbose , cache_dir = cache_dir )
222+ self .model = torch .hub .load (
223+ "Project-MONAI/perceptual-models:main" , model = net , verbose = verbose , cache_dir = cache_dir
224+ )
223225 self .eval ()
224226
225227 self .channel_wise = channel_wise
@@ -305,12 +307,9 @@ class RadImageNetPerceptualSimilarity(nn.Module):
305307 verbose: if false, mute messages from torch Hub load function.
306308 """
307309
308- def __init__ (self , net : str = "radimagenet_resnet50" ,
309- verbose : bool = False ,
310- cache_dir : str | None = None ) -> None :
310+ def __init__ (self , net : str = "radimagenet_resnet50" , verbose : bool = False , cache_dir : str | None = None ) -> None :
311311 super ().__init__ ()
312- self .model = torch .hub .load ("Project-MONAI/perceptual-models" , model = net , verbose = verbose ,
313- cache_dir = cache_dir )
312+ self .model = torch .hub .load ("Project-MONAI/perceptual-models" , model = net , verbose = verbose , cache_dir = cache_dir )
314313 self .eval ()
315314
316315 for param in self .parameters ():
0 commit comments