Commit c7c2063

revert
1 parent c814fca commit c7c2063

1 file changed, +5 -16 lines changed

deepspeed/runtime/engine.py

5 additions & 16 deletions

@@ -21,7 +21,7 @@
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 from deepspeed.runtime.zero.utils import is_zero_supported_optimizer
 from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
-from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
+from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer, FP16_FUSED_SUPPORTED_OPTIMIZERS, is_fp16_fused_supported_optimizer
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
     ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
@@ -397,9 +397,6 @@ def zero_gather_fp16_weights_on_model_save(self):
     def fp16_enabled(self):
         return self._config.fp16_enabled
 
-    def precision(self):
-        return self._config.precision
-
     def amp_enabled(self):
         return self._config.amp_enabled
 
@@ -572,18 +569,14 @@ def is_replicated(p):
 
         for p in self.module.parameters():
             if torch.is_tensor(p) and is_replicated(p):
-                if self.precision() == torch.bfloat16:
-                    p = p.float()
                 dist.broadcast(p,
                                self.broadcast_src_rank,
                                group=self.data_parallel_group)
-                if self.precision() == torch.bfloat16:
-                    p = p.bfloat16()
 
     def _configure_distributed_model(self, model):
         self.module = model
         if self.fp16_enabled():
-            self.module.to(self.precision())
+            self.module.half()
 
         if not self.dont_change_device:
             self.module.to(self.device)
@@ -721,8 +714,7 @@ def _configure_fp16_optimizer(self, optimizer):
         initial_dynamic_scale = self.initial_dynamic_scale()
         dynamic_loss_args = self.dynamic_loss_scale_args()
         clip_grad = self.gradient_clipping()
-        if isinstance(optimizer,
-                      FusedAdam) or self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
+        if is_fp16_fused_supported_optimizer(optimizer):
             if self.dynamic_loss_scale():
                 log_dist('Creating fp16 optimizer with dynamic loss scale', ranks=[0])
                 timers = self.timers if self.wall_clock_breakdown() else None
@@ -780,8 +772,7 @@ def _configure_zero_optimizer(self, optimizer):
                 max_elements_per_comm=self.zero_reduce_bucket_size(),
                 dp_process_group=self.data_parallel_group,
                 elastic_checkpoint=self.zero_elastic_checkpoint(),
-                mpu=self.mpu,
-                precision=self.precision())
+                mpu=self.mpu)
         elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
             optimizer = FP16_DeepSpeedZeroOptimizer(
                 optimizer,
@@ -800,8 +791,7 @@ def _configure_zero_optimizer(self, optimizer):
                 mpu=self.mpu,
                 postscale_gradients=self.postscale_gradients(),
                 gradient_predivide_factor=self.gradient_predivide_factor(),
-                gradient_accumulation_steps=self.gradient_accumulation_steps(),
-                precision=self.precision())
+                gradient_accumulation_steps=self.gradient_accumulation_steps())
         elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
             print("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
             from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
@@ -989,7 +979,6 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
 
         # Communicate only at gradient accumulation boundaries
        elif self.is_gradient_accumulation_boundary():
-            # TODO: communication in fp16 / fp32
             if self.zero_optimization_stage(
             ) == ZERO_OPTIMIZATION_OPTIMIZER_STATES and self.zero_reduce_scatter():
                 self.optimizer.reduce_scatter_gradients(
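
Note on the _configure_fp16_optimizer hunk: the revert restores the is_fp16_fused_supported_optimizer(optimizer) check in place of the explicit isinstance(optimizer, FusedAdam) or ONEBIT_ADAM_OPTIMIZER test. As a rough sketch only (the real helper lives in deepspeed/runtime/fp16/fused_optimizer.py and may be implemented differently), such a membership check could look like this, assuming FP16_FUSED_SUPPORTED_OPTIMIZERS from the restored import is a collection of optimizer classes:

# Sketch only: assumes FP16_FUSED_SUPPORTED_OPTIMIZERS is a tuple/list of
# optimizer classes (e.g. FusedAdam); the actual DeepSpeed helper may differ.
def is_fp16_fused_supported_optimizer(optimizer):
    return any(isinstance(optimizer, supported_cls)
               for supported_cls in FP16_FUSED_SUPPORTED_OPTIMIZERS)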
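
In the _configure_distributed_model hunk, self.module.to(self.precision()) goes back to self.module.half(): with fp16 enabled the engine casts the module to torch.float16 unconditionally, and the configurable precision()/bfloat16 path is removed. A standalone illustration of that cast in plain PyTorch (no DeepSpeed engine involved):

import torch

# Restored behavior: .half() casts parameters to torch.float16 in place.
module = torch.nn.Linear(4, 4)
module.half()
assert next(module.parameters()).dtype == torch.float16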

0 commit comments
