diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py index c59f668c..80ae4162 100644 --- a/library/adafactor_fused.py +++ b/library/adafactor_fused.py @@ -51,8 +51,6 @@ def copy_kahan_(target: torch.Tensor, source: torch.Tensor, state, update): # Initialize residuals to 0 for first step if state.get('kahan_residuals') is None: state['kahan_residuals'] = torch.zeros_like(source, dtype=torch.int16) - else: - pass # Need this in 32 bit as PyTorch doesn't support mixed 32-bit and 16-bit math operations state['kahan_residuals'] = state['kahan_residuals'].to(source.device).to(dtype=torch.int32)