2121from deepspeed .runtime .zero .partition_parameters import ZeroParamStatus
2222from deepspeed .runtime .zero .utils import is_zero_supported_optimizer
2323from deepspeed .runtime .activation_checkpointing import checkpointing as activation_checkpointing
24- from deepspeed .runtime .fp16 .fused_optimizer import FP16_Optimizer
24+ from deepspeed .runtime .fp16 .fused_optimizer import FP16_Optimizer , FP16_FUSED_SUPPORTED_OPTIMIZERS , is_fp16_fused_supported_optimizer
2525from deepspeed .runtime .fp16 .unfused_optimizer import FP16_UnfusedOptimizer
2626from deepspeed .runtime .config import DeepSpeedConfig , DEEPSPEED_OPTIMIZERS , \
2727 ADAM_OPTIMIZER , ADAMW_OPTIMIZER , LAMB_OPTIMIZER , ONEBIT_ADAM_OPTIMIZER , ONEBIT_LAMB_OPTIMIZER , \
@@ -397,9 +397,6 @@ def zero_gather_fp16_weights_on_model_save(self):
397397 def fp16_enabled (self ):
398398 return self ._config .fp16_enabled
399399
400- def precision (self ):
401- return self ._config .precision
402-
403400 def amp_enabled (self ):
404401 return self ._config .amp_enabled
405402
@@ -572,18 +569,14 @@ def is_replicated(p):
572569
573570 for p in self .module .parameters ():
574571 if torch .is_tensor (p ) and is_replicated (p ):
575- if self .precision () == torch .bfloat16 :
576- p = p .float ()
577572 dist .broadcast (p ,
578573 self .broadcast_src_rank ,
579574 group = self .data_parallel_group )
580- if self .precision () == torch .bfloat16 :
581- p = p .bfloat16 ()
582575
583576 def _configure_distributed_model (self , model ):
584577 self .module = model
585578 if self .fp16_enabled ():
586- self .module .to ( self . precision () )
579+ self .module .half ( )
587580
588581 if not self .dont_change_device :
589582 self .module .to (self .device )
@@ -721,8 +714,7 @@ def _configure_fp16_optimizer(self, optimizer):
721714 initial_dynamic_scale = self .initial_dynamic_scale ()
722715 dynamic_loss_args = self .dynamic_loss_scale_args ()
723716 clip_grad = self .gradient_clipping ()
724- if isinstance (optimizer ,
725- FusedAdam ) or self .optimizer_name () == ONEBIT_ADAM_OPTIMIZER :
717+ if is_fp16_fused_supported_optimizer (optimizer ):
726718 if self .dynamic_loss_scale ():
727719 log_dist ('Creating fp16 optimizer with dynamic loss scale' , ranks = [0 ])
728720 timers = self .timers if self .wall_clock_breakdown () else None
@@ -780,8 +772,7 @@ def _configure_zero_optimizer(self, optimizer):
780772 max_elements_per_comm = self .zero_reduce_bucket_size (),
781773 dp_process_group = self .data_parallel_group ,
782774 elastic_checkpoint = self .zero_elastic_checkpoint (),
783- mpu = self .mpu ,
784- precision = self .precision ())
775+ mpu = self .mpu )
785776 elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS :
786777 optimizer = FP16_DeepSpeedZeroOptimizer (
787778 optimizer ,
@@ -800,8 +791,7 @@ def _configure_zero_optimizer(self, optimizer):
800791 mpu = self .mpu ,
801792 postscale_gradients = self .postscale_gradients (),
802793 gradient_predivide_factor = self .gradient_predivide_factor (),
803- gradient_accumulation_steps = self .gradient_accumulation_steps (),
804- precision = self .precision ())
794+ gradient_accumulation_steps = self .gradient_accumulation_steps ())
805795 elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS :
806796 print ("Initializing ZeRO Stage 3" ) if dist .get_rank () == 0 else None
807797 from deepspeed .runtime .zero .stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
@@ -989,7 +979,6 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
989979
990980 # Communicate only at gradient accumulation boundaries
991981 elif self .is_gradient_accumulation_boundary ():
992- # TODO: communication in fp16 / fp32
993982 if self .zero_optimization_stage (
994983 ) == ZERO_OPTIMIZATION_OPTIMIZER_STATES and self .zero_reduce_scatter ():
995984 self .optimizer .reduce_scatter_gradients (
0 commit comments