Fused backward pass

This commit is contained in:
2kpr
2024-04-14 09:56:58 -05:00
parent 71e2c91330
commit 4f203ce40d
3 changed files with 142 additions and 6 deletions

106
library/adafactor_fused.py Normal file
View File

@@ -0,0 +1,106 @@
import math
import torch
from transformers import Adafactor
@torch.no_grad()
def adafactor_step_param(self, p, group):
if p.grad is None:
return
grad = p.grad
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")
state = self.state[p]
grad_shape = grad.shape
factored, use_first_moment = Adafactor._get_options(group, grad_shape)
# State Initialization
if len(state) == 0:
state["step"] = 0
if use_first_moment:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
if factored:
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
else:
state["exp_avg_sq"] = torch.zeros_like(grad)
state["RMS"] = 0
else:
if use_first_moment:
state["exp_avg"] = state["exp_avg"].to(grad)
if factored:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
else:
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
p_data_fp32 = p
if p.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()
state["step"] += 1
state["RMS"] = Adafactor._rms(p_data_fp32)
lr = Adafactor._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad ** 2) + group["eps"][0]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
# Approximation of exponential moving average of square of gradient
update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
update = exp_avg_sq.rsqrt().mul_(grad)
update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
update.mul_(lr)
if use_first_moment:
exp_avg = state["exp_avg"]
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
update = exp_avg
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
p_data_fp32.add_(-update)
if p.dtype in {torch.float16, torch.bfloat16}:
p.copy_(p_data_fp32)
@torch.no_grad()
def adafactor_step(self, closure=None):
"""
Performs a single optimization step
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
adafactor_step_param(self, p, group)
return loss
def patch_adafactor_fused(optimizer: Adafactor):
optimizer.step_param = adafactor_step_param.__get__(optimizer)
optimizer.step = adafactor_step.__get__(optimizer)

View File

@@ -2920,6 +2920,11 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
default=1, default=1,
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
) )
parser.add_argument(
"--fused_backward_pass",
action="store_true",
help="Combines backward pass and optimizer step to reduce VRAM usage / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。",
)
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
@@ -3846,6 +3851,14 @@ def get_optimizer(args, trainable_params):
optimizer_type = "AdamW" optimizer_type = "AdamW"
optimizer_type = optimizer_type.lower() optimizer_type = optimizer_type.lower()
if args.fused_backward_pass:
assert (
optimizer_type == "Adafactor".lower()
), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します"
assert (
args.gradient_accumulation_steps == 1
), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません"
# 引数を分解する # 引数を分解する
optimizer_kwargs = {} optimizer_kwargs = {}
if args.optimizer_args is not None and len(args.optimizer_args) > 0: if args.optimizer_args is not None and len(args.optimizer_args) > 0:

View File

@@ -430,6 +430,20 @@ def train(args):
text_encoder2 = accelerator.prepare(text_encoder2) text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if args.fused_backward_pass:
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None
parameter.register_post_accumulate_grad_hook(__grad_hook)
# TextEncoderの出力をキャッシュするときにはCPUへ移動する # TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
@@ -619,6 +633,8 @@ def train(args):
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
accelerator.backward(loss) accelerator.backward(loss)
if not args.fused_backward_pass:
if accelerator.sync_gradients and args.max_grad_norm != 0.0: if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = [] params_to_clip = []
for m in training_models: for m in training_models:
@@ -626,6 +642,7 @@ def train(args):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step() optimizer.step()
lr_scheduler.step() lr_scheduler.step()
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)