support fused_back_pass for prodigy-plus-schedule-free

This commit is contained in:
some_ai
2025-01-05 23:26:22 -04:00
parent e89653975d
commit 8cee727a99
5 changed files with 817 additions and 56 deletions

View File

@@ -381,6 +381,8 @@ def train(args):
optimizer_train_fn = lambda: None # dummy function
optimizer_eval_fn = lambda: None # dummy function
else:
if args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree" and args.fused_backward_pass:
args.optimizer_args.append("fused_back_pass=True")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
@@ -473,25 +475,27 @@ def train(args):
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
for parameter, param_name in zip(param_group["params"], param_name_group):
if parameter.requires_grad:
def create_grad_hook(p_name, p_group):
def grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, p_group)
tensor.grad = None
return grad_hook
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
if args.optimizer_type == "AdaFactor":
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
for parameter, param_name in zip(param_group["params"], param_name_group):
if parameter.requires_grad:
def create_grad_hook(p_name, p_group):
def grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, p_group)
tensor.grad = None
return grad_hook
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
elif args.optimizer_type == "ProdigyPlusScheduleFree":
# ProdigyPlus uses its internal fused_back_pass mechanism, pass for now
pass
else:
logger.warning(
f"Fused backward pass is not supported for optimizer type: {args.optimizer_type}. Ignoring."
)
elif args.blockwise_fused_optimizers:
# prepare for additional optimizers and lr schedulers

View File

@@ -0,0 +1,745 @@
import math
import torch
from statistics import harmonic_mean
from .prodigy_plus_schedulefree import ProdigyPlusScheduleFree
class CoreOptimiser(torch.optim.Optimizer):
def __init__(self, params, lr=1.0,
betas=(0.9, 0.99), beta3=None,
weight_decay=0.0,
weight_decay_by_lr=True,
use_bias_correction=False,
d0=1e-6, d_coef=1.0,
prodigy_steps=0,
eps=1e-8,
split_groups=True,
split_groups_mean=True,
factored=True,
fused_back_pass=False,
use_stableadamw=True,
use_muon_pp=False,
use_cautious=False,
use_adopt=False,
stochastic_rounding=True):
if not 0.0 < d0:
raise ValueError("Invalid d0 value: {}".format(d0))
if not 0.0 < lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if eps is not None and not 0.0 < eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if beta3 is not None and not 0.0 <= beta3 < 1.0:
raise ValueError("Invalid beta3 parameter: {}".format(beta3))
if beta3 is None:
beta3 = betas[1] ** 0.5
defaults = dict(lr=lr, betas=betas, beta3=beta3,
eps=eps,
weight_decay=weight_decay,
weight_decay_by_lr=weight_decay_by_lr,
d=d0, d0=d0, d_coef=d_coef,
k=1, train_mode=True,
weight_sum=0,
prodigy_steps=prodigy_steps,
use_bias_correction=use_bias_correction,
d_numerator=0.0,
d_denom=0,
factored=factored,
use_stableadamw=use_stableadamw,
use_muon_pp=use_muon_pp,
use_cautious=use_cautious,
use_adopt=use_adopt,
stochastic_rounding=stochastic_rounding)
super().__init__(params, defaults)
self.d0 = d0
if split_groups and len(self.param_groups) == 1:
print(f"[{self.__class__.__name__}] Optimiser contains single param_group -- 'split_groups' has been disabled.")
split_groups = False
self.split_groups = split_groups
self.split_groups_mean = split_groups_mean
# Properties for fused backward pass.
self.groups_to_process = None
self.shared_d = None
self.fused_back_pass = fused_back_pass
# Use tensors to keep everything on device during parameter loop.
for group in (self.param_groups if self.split_groups else self.param_groups[:1]):
p = group['params'][0]
group['running_d_numerator'] = torch.tensor(0.0, dtype=torch.float32, device=p.device)
group['running_d_denom'] = torch.tensor(0.0, dtype=torch.float32, device=p.device)
@torch.no_grad()
def eval(self):
pass
@torch.no_grad()
def train(self):
pass
@property
def supports_memory_efficient_fp16(self):
return False
@property
def supports_flat_params(self):
return True
def supports_fused_back_pass(self):
return True
@torch.no_grad()
def get_sliced_tensor(self, tensor, slice_p=11):
return tensor.ravel()[::slice_p]
@torch.no_grad()
def get_running_values_for_group(self, group):
if not self.split_groups:
group = self.param_groups[0]
return group['running_d_numerator'], group['running_d_denom']
@torch.no_grad()
def get_d_mean(self):
if self.split_groups and self.split_groups_mean:
return harmonic_mean(group['d'] for group in self.param_groups)
return None
@torch.no_grad()
def get_d_max(self, group):
if self.split_groups:
return max(group['d'] for group in self.param_groups)
return group['d']
# From: https://github.com/KellerJordan/Muon/blob/master/muon.py
@torch.no_grad()
def newton_schulz_(self, G, steps=5, eps=1e-7):
# Inline reshaping step within the method itself.
X = G.view(G.size(0), -1)
a, b, c = (3.4445, -4.7750, 2.0315)
X = X.to(dtype=torch.bfloat16, copy=True)
X /= (X.norm() + eps) # ensure top singular value <= 1
if G.size(0) > G.size(1):
X = X.T
for _ in range(steps):
A = X @ X.T
B = b * A + c * A @ A
X = a * X + B @ X
if G.size(0) > G.size(1):
X = X.T
G.copy_(X.view_as(G))
del X
return G
# Implementation by Nerogar. From: https://github.com/pytorch/pytorch/issues/120376#issuecomment-1974828905
def copy_stochastic_(self, target, source):
# create a random 16 bit integer
result = torch.randint_like(
source,
dtype=torch.int32,
low=0,
high=(1 << 16),
)
# add the random number to the lower 16 bit of the mantissa
result.add_(source.view(dtype=torch.int32))
# mask off the lower 16 bit of the mantissa
result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
# copy the higher 16 bit into the target tensor
target.copy_(result.view(dtype=torch.float32))
# Modified Adafactor factorisation implementation by Ross Wightman
# https://github.com/huggingface/pytorch-image-models/pull/2320
@torch.no_grad()
def factored_dims(self,
shape,
factored,
min_dim_size_to_factor):
r"""Whether to use a factored second moment estimator.
This function returns a tuple with the two largest axes to reduce over.
If all dimensions have size < min_dim_size_to_factor, return None.
Args:
shape: an input shape
factored: whether to use factored second-moment estimator for > 2d vars.
min_dim_size_to_factor: only factor accumulator if all array dimensions are greater than this size.
Returns:
None or a tuple of ints
"""
if not factored or len(shape) < 2:
return None
if all(dim < min_dim_size_to_factor for dim in shape):
return None
sorted_dims = sorted(((x, i) for i, x in enumerate(shape)))
return int(sorted_dims[-2][1]), int(sorted_dims[-1][1])
@torch.no_grad()
def initialise_state(self, p, group):
raise Exception("Not implemented!")
@torch.no_grad()
def initialise_state_internal(self, p, group):
state = self.state[p]
needs_init = len(state) == 0
if needs_init:
grad = p.grad
dtype = torch.bfloat16 if p.dtype == torch.float32 else p.dtype
sliced_data = self.get_sliced_tensor(p)
# NOTE: We don't initialise z/exp_avg here -- subclass needs to do that.
state['muon'] = group['use_muon_pp'] and len(grad.shape) >= 2 and grad.size(0) < 10000
if not state['muon']:
factored_dims = self.factored_dims(
grad.shape,
factored=group['factored'],
min_dim_size_to_factor=32
)
if factored_dims is not None:
dc, dr = factored_dims
row_shape = list(p.grad.shape)
row_shape[dr] = 1
col_shape = list(p.grad.shape)
col_shape[dc] = 1
reduce_dc = dc - 1 if dc > dr else dc
# Store reduction variables so we don't have to recalculate each step.
# Always store second moment low ranks in fp32 to avoid precision issues. Memory difference
# between bf16/fp16 and fp32 is negligible here.
state["exp_avg_sq"] = [torch.zeros(row_shape, dtype=torch.float32, device=p.device).detach(),
torch.zeros(col_shape, dtype=torch.float32, device=p.device).detach(),
dr, dc, reduce_dc]
else:
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format).detach()
# If the initial weights are zero, don't bother storing them.
if p.any() > 0:
state['p0'] = sliced_data.to(dtype=dtype, memory_format=torch.preserve_format, copy=True).detach()
else:
state['p0'] = torch.tensor(0.0, dtype=dtype, device=p.device)
state['s'] = torch.zeros_like(sliced_data, memory_format=torch.preserve_format, dtype=dtype).detach()
return state, needs_init
@torch.no_grad()
def update_d_and_reset(self, group):
k = group['k']
prodigy_steps = group['prodigy_steps']
if prodigy_steps > 0 and k >= prodigy_steps:
return
d = group['d']
d0 = group['d0']
d_coef = group['d_coef']
beta3 = group['beta3']
running_d_numerator, running_d_denom = self.get_running_values_for_group(group)
d_numerator = group['d_numerator']
d_numerator *= beta3
d_numerator_item = running_d_numerator.item()
d_denom_item = running_d_denom.item()
# Prevent the accumulation of negative values in the numerator in early training.
# We still allow negative updates once progress starts being made, as this is
# important for regulating the adaptive stepsize.
if d_numerator_item > 0 or d > d0:
# Force Prodigy to be extremely confident before increasing the LR when
# gradient and weights are aligned.
d_numerator = 0 if d_numerator_item < 0 else d_numerator + d_numerator_item
if d_denom_item > 0:
d = max(math.atan2(d_coef * d_numerator, d_denom_item), d)
group['d'] = d
group['d_numerator'] = d_numerator
# Used for logging purposes only.
group['d_denom'] = d_denom_item
running_d_numerator.zero_()
running_d_denom.zero_()
def on_start_step(self, group):
if self.groups_to_process is None:
# Optimiser hasn't run yet, so initialise.
self.groups_to_process = {i: len(group['params']) for i, group in enumerate(self.param_groups)}
elif len(self.groups_to_process) == 0:
# Start of new optimiser run, so grab updated d.
self.groups_to_process = {i: len(group['params']) for i, group in enumerate(self.param_groups)}
if not self.split_groups:
# When groups aren't split, calculate d for the first group,
# then copy to all other groups.
self.update_d_and_reset(group)
for g in self.param_groups:
g['d'] = group['d']
self.shared_d = self.get_d_mean()
def on_end_step(self, group):
group_index = self.param_groups.index(group)
# Decrement params processed so far.
self.groups_to_process[group_index] -= 1
# End of param loop for group, update calculations.
if self.groups_to_process[group_index] == 0:
k = group['k']
prodigy_steps = group['prodigy_steps']
if prodigy_steps > 0 and k == prodigy_steps:
print(f"[{self.__class__.__name__}] Prodigy stepsize adaptation disabled after {k} steps for param_group {group_index}.")
self.groups_to_process.pop(group_index)
if self.split_groups: # When groups are split, calculate per-group d.
self.update_d_and_reset(group)
group['k'] = k + 1
return True
return False
def get_dlr(self, group):
lr = group['lr']
dlr = (self.shared_d if self.split_groups and self.shared_d else group['d']) * lr
dlr = max(dlr, group['d0'])
return dlr
def update_prodigy(self, state, group, grad, data):
k = group['k']
prodigy_steps = group['prodigy_steps']
if prodigy_steps <= 0 or k < prodigy_steps:
d, d0 = group['d'], group['d0']
beta3 = group['beta3']
sliced_grad = self.get_sliced_tensor(grad)
sliced_data = self.get_sliced_tensor(data).float()
# Rescale Prodigy updates to compensate for the lost elements due to slicing.
# This is mostly for logging purposes and shouldn't change the functionality.
slice_scale = grad.numel() / sliced_grad.numel()
running_d_numerator, running_d_denom = self.get_running_values_for_group(group)
s = state['s']
x0_minus = state['p0'] - sliced_data
running_d_numerator.add_(torch.dot(sliced_grad, x0_minus), alpha=(d / d0) * d * slice_scale)
s.mul_(beta3).add_(sliced_grad, alpha=(d / d0) * d * slice_scale)
running_d_denom.add_(s.abs().sum())
del x0_minus
elif 's' in state: # Free the memory used by Prodigy, as we no longer need it.
del state['s']
del state['p0']
def update_(self, num, denom, group):
d = group['d']
eps = group['eps']
# Deviate from reference Prodigy -- rather than apply d to the EMAs,
# apply directly before calculating the update to simplify the rest of the optimiser.
num.mul_(d)
denom.mul_(d * d)
if eps is None:
# Adam-atan2. Use atan2 rather than epsilon and division
# for parameter updates (https://arxiv.org/abs/2407.05872).
# Has the nice property of "clipping" the gradient as well.
update = num.atan2_(denom)
else:
update = num.div_(denom.add_(d * eps))
return update
def get_denom(self, state):
exp_avg_sq = state['exp_avg_sq']
# Adam EMA updates
if isinstance(exp_avg_sq, list):
row_var, col_var, _, _, reduce_dc = exp_avg_sq
row_col_mean = row_var.mean(dim=reduce_dc, keepdim=True).add_(1e-30)
row_factor = row_var.div(row_col_mean).sqrt_()
col_factor = col_var.sqrt()
denom = row_factor * col_factor
else:
denom = exp_avg_sq.sqrt()
return denom
def update_first_moment(self, state, group, grad):
exp_avg = state['exp_avg']
beta1, _ = group['betas']
return exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
def update_second_moment(self, state, group, grad, beta2, return_denom=True, denom_before_update=False):
exp_avg_sq = state['exp_avg_sq']
denom = None
if return_denom and denom_before_update:
denom = self.get_denom(state)
# Adafactor / PaLM beta2 decay. Clip beta2 as per Scaling ViT paper.
if group['use_bias_correction']:
k = group['k']
debias = 1 - k ** -0.8
beta2 = min(beta2, 1 - ((1 - debias) / (1 - debias ** k)))
# Adam EMA updates
if isinstance(exp_avg_sq, list):
row_var, col_var, dr, dc, _ = exp_avg_sq
row_var.lerp_(
grad.norm(dim=dr, keepdim=True).square_().div_(grad.shape[dr]),
weight=1 - beta2
)
col_var.lerp_(
grad.norm(dim=dc, keepdim=True).square_().div_(grad.shape[dc]),
weight=1 - beta2
)
else:
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
if return_denom and denom is None:
denom = self.get_denom(state)
return denom
def rms_(self, tensor, rms_min):
rms = tensor.norm().div(tensor.numel() ** 0.5).clamp_min(rms_min)
return tensor.div_(rms)
# "Cautious Optimizer (C-Optim): Improving Training with One Line of Code"
# https://github.com/kyleliang919/c-optim
# Modified version by Ross Wightman: https://x.com/wightmanr/status/1862226848475955442
def cautious_(self, update, grad):
mask = (grad * update > 0).to(dtype=grad.dtype)
mask.div_(mask.mean().clamp_min(1e-3))
update.mul_(mask)
del mask
return update
def cautious_mask(self, update, grad):
mask = (grad * update > 0).to(dtype=grad.dtype)
mask.div_(mask.mean().clamp_min(1e-3))
return mask
@torch.no_grad()
def step_param(self, p, group):
raise Exception("Not implemented!")
@torch.no_grad()
def step_parameter(self, p, group, i):
self.step_param(p, group)
@torch.no_grad()
def step(self, closure=None):
if self.fused_back_pass:
return
"""Performs a single optimisation step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for param_group in self.param_groups:
for p in param_group["params"]:
self.step_param(p, param_group)
return loss
class ProdigyPlusScheduleFree(CoreOptimiser):
r"""
An optimiser based on Prodigy that includes schedule-free logic. Has additional improvements in the form of optional StableAdamW
gradient scaling and Adam-atan2 updates, per parameter group adaptation, lower memory utilisation, fused back pass support and
tweaks to mitigate uncontrolled LR growth.
Based on code from:
https://github.com/facebookresearch/schedule_free
https://github.com/konstmish/prodigy
Incorporates improvements from these pull requests (credit to https://github.com/dxqbYD and https://github.com/sangoi-exe):
https://github.com/konstmish/prodigy/pull/23
https://github.com/konstmish/prodigy/pull/22
https://github.com/konstmish/prodigy/pull/20
As with the reference implementation of schedule-free, a constant scheduler should be used, along with the appropriate
calls to `train()` and `eval()`. See the schedule-free documentation for more details: https://github.com/facebookresearch/schedule_free
If you do use another scheduler, linear or cosine is preferred, as a restarting scheduler can confuse Prodigy's adaptation logic.
Leave `lr` set to 1 unless you encounter instability. Do not use with gradient clipping, as this can hamper the
ability for the optimiser to predict stepsizes. Gradient clipping/normalisation is already handled in the following configurations:
1) `use_stableadamw=True,eps=1e8` (or any reasonable positive epsilon)
2) `eps=None` (Adam-atan2, scale invariant and can mess with Prodigy's stepsize calculations in some scenarios)
By default, `split_groups` is set to `True`, so each parameter group will have its own adaptation values. So if you're training
different networks together, they won't contaminate each other's learning rates. The disadvantage of this approach is that some
networks can take a long time to reach a good learning rate when trained alongside others (for example, SDXL's Unet).
It's recommended to use a higher `d0` (1e-5, 5e-5, 1e-4) so these networks don't get stuck at a low learning rate.
For Prodigy's reference behaviour, which lumps all parameter groups together, set `split_groups` to `False`.
In some scenarios, it can be advantageous to freeze Prodigy's adaptive stepsize after a certain number of steps. This
can be controlled via the `prodigy_steps` settings. This will also free any Prodigy-specific memory used by the
optimiser (though with all the memory-related improvements, this should not be significant unless you're training
very large models).
Arguments:
params (iterable):
Iterable of parameters to optimize or dicts defining parameter groups.
lr (float):
Learning rate adjustment parameter. Increases or decreases the Prodigy learning rate.
(default: 1.0)
betas (Tuple[float, float], optional):
Coefficients used for computing running averages of gradient and its square.
(default: (0.9, 0.99))
eps (float):
Term added to the denominator outside of the root operation to improve numerical stability. If set to None,
Adam-atan2 is used instead. This removes the need for epsilon tuning, but may not work well in all situations.
(default: 1e-8).
beta3 (float):
Coefficient for computing the Prodigy stepsize using running averages. If set to None, uses the value of
square root of beta2
(default: None).
weight_decay (float):
Decoupled weight decay. Use the weight_decay_by_lr setting to determine if decay should be multiplied by the
adaptive learning rate.
(default: 0).
weight_decay_by_lr (boolean):
If True, weight_decay is multiplied by the adaptive learning rate (as per the PyTorch implementation of AdamW).
If False, weight_decay will have a much stronger effect.
(default: True).
use_bias_correction (boolean):
Turn on Adafactor-style bias correction, which scales beta2 directly. (default: False).
d0 (float):
Initial estimate for Prodigy. Also serves as the minimum learning rate.
(default: 1e-6).
d_coef (float):
Coefficient in the expression for the estimate of d. Values such as 0.5 and 2.0 typically work as well.
Changing this parameter is the preferred way to tune the method.
(default: 1.0)
prodigy_steps (int):
Freeze Prodigy stepsize adjustments after a certain optimiser step and releases all state memory required
by Prodigy.
(default: 0)
split_groups (boolean):
Track individual adaptation values for each parameter group. For example, if training
a text encoder beside a Unet. Note this can have a significant impact on training dynamics.
Set to False for original Prodigy behaviour, where all groups share the same values.
(default: True)
split_groups_mean (boolean):
When split_groups is True, use the harmonic mean of learning rates for all groups. This favours
a more conservative LR. Calculation remains per-group. If split_groups is False, this value has no effect.
Set to False to have each group use its own learning rate.
(default: True)
factored (boolean):
Use factored approximation of the second moment, similar to Adafactor. Reduces memory usage. Disable
if training results in NaNs or the learning rate fails to grow.
(default: True)
fused_back_pass (boolean):
Stops the optimiser from running the normal step method. Set to True if using fused backward pass.
(default: False)
use_stableadamw (boolean):
Scales parameter updates by the root-mean-square of the normalised gradient, in essence identical to
Adafactor's gradient scaling. Set to False if the adaptive learning rate never improves.
(default: True)
use_muon_pp (boolean):
Experimental. Perform orthogonalisation on the gradient before it is used for updates ala Shampoo/SOAP/Muon.
(https://github.com/KellerJordan/Muon/blob/master/muon.py). Not suitable for all training scenarios.
May not work well with small batch sizes or finetuning.
(default: False)
use_cautious (boolean):
Experimental. Perform "cautious" updates, as proposed in https://arxiv.org/pdf/2411.16085. Modifies
the update to isolate and boost values that align with the current gradient. Note that we do not have
access to a first moment, so this deviates from the paper (we apply the mask directly to the update).
May have a limited effect.
(default: False)
use_adopt (boolean):
Experimental. Performs a modified step where the second moment is updated after the parameter update,
so as not to include the current gradient in the denominator. This is a partial implementation of ADOPT
(https://arxiv.org/abs/2411.02853), as we don't have a first moment to use for the update.
(default: False)
stochastic_rounding (boolean):
Use stochastic rounding for bfloat16 weights (https://github.com/pytorch/pytorch/issues/120376). Brings
bfloat16 training performance close to that of float32.
(default: True)
"""
def __init__(self, params, lr=1.0,
betas=(0.9, 0.99), beta3=None,
weight_decay=0.0,
weight_decay_by_lr=True,
use_bias_correction=False,
d0=1e-6, d_coef=1.0,
prodigy_steps=0,
eps=1e-8,
split_groups=True,
split_groups_mean=True,
factored=True,
fused_back_pass=False,
use_stableadamw=True,
use_muon_pp=False,
use_cautious=False,
use_adopt=False,
stochastic_rounding=True):
super().__init__(params=params, lr=lr, betas=betas, beta3=beta3,
weight_decay=weight_decay, weight_decay_by_lr=weight_decay_by_lr,
use_bias_correction=use_bias_correction,
d0=d0, d_coef=d_coef, prodigy_steps=prodigy_steps,
eps=eps, split_groups=split_groups,
split_groups_mean=split_groups_mean, factored=factored,
fused_back_pass=fused_back_pass, use_stableadamw=use_stableadamw,
use_muon_pp=use_muon_pp, use_cautious=use_cautious, use_adopt=use_adopt,
stochastic_rounding=stochastic_rounding)
@torch.no_grad()
def eval(self):
for group in self.param_groups:
if not group['train_mode']:
continue
beta1, _ = group['betas']
for p in group['params']:
z = self.state[p].get('z')
if z is not None:
# Set p to x
p.lerp_(end=z.to(device=p.device), weight=1 - 1 / beta1)
group['train_mode'] = False
@torch.no_grad()
def train(self):
for group in self.param_groups:
if group['train_mode']:
continue
beta1, _ = group['betas']
for p in group['params']:
z = self.state[p].get('z')
if z is not None:
# Set p to y
p.lerp_(end=z.to(device=p.device), weight=1 - beta1)
group['train_mode'] = True
@torch.no_grad()
def initialise_state(self, p, group):
state, needs_init = self.initialise_state_internal(p, group)
if needs_init:
state['z'] = p.detach().clone(memory_format=torch.preserve_format)
return state
@torch.no_grad()
def update_params(self, y, z, update, group):
dlr = self.get_dlr(group)
decay = group['weight_decay']
beta1, _ = group['betas']
weight = self.get_d_max(group) ** 2
weight_sum = group['weight_sum'] + weight
ckp1 = weight / weight_sum if weight_sum else 0
y.lerp_(end=z, weight=ckp1)
# Weight decay at Y.
if decay != 0:
if group['weight_decay_by_lr']:
decay *= dlr
y.sub_(y, alpha=decay * (1 - beta1))
z.sub_(y, alpha=decay)
y.add_(update, alpha=dlr * (beta1 * (1 - ckp1) - 1))
z.sub_(update, alpha=dlr)
return weight_sum
@torch.no_grad()
def step_param(self, p, group):
if not group['train_mode']:
raise Exception("Not in train mode!")
self.on_start_step(group)
weight_sum = group['weight_sum']
if p.grad is not None:
grad = p.grad.float()
grad_mask = grad.clone() if group['use_cautious'] else None
rms_min = 1.0 if group['use_stableadamw'] else None
state = self.initialise_state(p, group)
y, z = p, state['z']
self.update_prodigy(state, group, grad, z)
update = None
if state['muon']:
rms_min = 1e-7
# Use high epsilon at start of training so
# Prodigy doesn't take forever to adapt the stepsize.
eps = max(rms_min, 0.2 ** group['k'] ** 0.5)
update = self.newton_schulz_(grad, eps=eps)
else:
use_adopt = group['use_adopt']
if use_adopt and group['k'] == 1:
self.update_second_moment(state, group, grad, 0, return_denom=False)
else:
_, beta2 = group['betas']
denom = self.update_second_moment(state, group, grad, beta2, denom_before_update=use_adopt)
update = self.update_(grad, denom, group)
del denom
if update is not None:
if rms_min is not None:
self.rms_(update, rms_min)
if grad_mask is not None:
self.cautious_(update, grad_mask)
del grad_mask
if group['stochastic_rounding'] and y.dtype == z.dtype == torch.bfloat16:
y_fp32, z_fp32 = y.float(), z.float()
weight_sum = self.update_params(y_fp32, z_fp32, update, group)
self.copy_stochastic_(y, y_fp32)
self.copy_stochastic_(z, z_fp32)
del y_fp32, z_fp32
else:
weight_sum = self.update_params(y, z, update, group)
del update
if self.on_end_step(group):
group['weight_sum'] = weight_sum

View File

@@ -4631,9 +4631,9 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
optimizer_type = optimizer_type.lower()
if args.fused_backward_pass:
assert (
""" assert (
optimizer_type == "Adafactor".lower()
), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します"
), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します" """
assert (
args.gradient_accumulation_steps == 1
), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません"

View File

@@ -482,6 +482,8 @@ def train(args):
optimizer_train_fn = lambda: None # dummy function
optimizer_eval_fn = lambda: None # dummy function
else:
if args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree" and args.fused_backward_pass:
args.optimizer_args.append("fused_back_pass=True")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
@@ -597,27 +599,29 @@ def train(args):
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
if args.optimizer_type == "AdaFactor":
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
for parameter, param_name in zip(param_group["params"], param_name_group):
if parameter.requires_grad:
def create_grad_hook(p_name, p_group):
def grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, p_group)
tensor.grad = None
return grad_hook
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
elif args.optimizer_type == "ProdigyPlusScheduleFree":
# ProdigyPlus uses its internal fused_back_pass mechanism, pass for now
pass
else:
logger.warning(
f"Fused backward pass is not supported for optimizer type: {args.optimizer_type}. Ignoring."
)
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
for parameter, param_name in zip(param_group["params"], param_name_group):
if parameter.requires_grad:
def create_grad_hook(p_name, p_group):
def grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, p_group)
tensor.grad = None
return grad_hook
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
elif args.blockwise_fused_optimizers:
if args.blockwise_fused_optimizers:
# prepare for additional optimizers and lr schedulers
for i in range(1, len(optimizers)):
optimizers[i] = accelerator.prepare(optimizers[i])

View File

@@ -415,6 +415,8 @@ def train(args):
logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
else:
if args.optimizer_type == "prodigyplus.ProdigyPlusScheduleFree" and args.fused_backward_pass:
args.optimizer_args.append("fused_back_pass=True")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
# prepare dataloader
@@ -520,21 +522,27 @@ def train(args):
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None
parameter.register_post_accumulate_grad_hook(__grad_hook)
if args.optimizer_type == "AdaFactor":
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
for parameter, param_name in zip(param_group["params"], param_name_group):
if parameter.requires_grad:
def create_grad_hook(p_name, p_group):
def grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, p_group)
tensor.grad = None
return grad_hook
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
elif args.optimizer_type == "ProdigyPlusScheduleFree":
# ProdigyPlus uses its internal fused_back_pass mechanism, pass for now
pass
else:
logger.warning(
f"Fused backward pass is not supported for optimizer type: {args.optimizer_type}. Ignoring."
)
elif args.fused_optimizer_groups:
# prepare for additional optimizers and lr schedulers