Mirror of https://github.com/kohya-ss/sd-scripts.git
chore: Refactor optimizer group
@@ -357,27 +357,37 @@ def train(args):
     accelerator.print("prepare optimizer, data loader etc.")
 
     if args.fused_optimizer_groups:
+        # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
+        # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters.
+        # This balances memory usage and management complexity.
 
         # calculate total number of parameters
         n_total_params = sum(len(params["params"]) for params in params_to_optimize)
         params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
 
-        # split params into groups
+        # split params into groups, keeping the learning rate the same for all params in a group
+        # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders)
         grouped_params = []
         param_group = []
         param_group_lr = -1
         for group in params_to_optimize:
             lr = group["lr"]
             for p in group["params"]:
+                # if the learning rate is different for different params, start a new group
                 if lr != param_group_lr:
                     if param_group:
                         grouped_params.append({"params": param_group, "lr": param_group_lr})
                     param_group = []
                     param_group_lr = lr
 
                 param_group.append(p)
 
+                # if the group has enough parameters, start a new group
                 if len(param_group) == params_per_group:
                     grouped_params.append({"params": param_group, "lr": param_group_lr})
                     param_group = []
                     param_group_lr = -1
 
         if param_group:
             grouped_params.append({"params": param_group, "lr": param_group_lr})
 
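As an aside, here is a minimal, self-contained sketch of the grouping logic in the hunk above: parameter groups are split into roughly equal-sized chunks while keeping a single learning rate per resulting group. The helper name split_into_fused_groups and the toy parameter lists are made up for illustration; the input mirrors the params_to_optimize format ([{"params": [...], "lr": ...}, ...]).

import math

def split_into_fused_groups(params_to_optimize, num_groups):
    n_total_params = sum(len(g["params"]) for g in params_to_optimize)
    params_per_group = math.ceil(n_total_params / num_groups)

    grouped_params, param_group, param_group_lr = [], [], -1
    for group in params_to_optimize:
        lr = group["lr"]
        for p in group["params"]:
            # a different lr forces a new group, so each optimizer sees exactly one lr
            if lr != param_group_lr:
                if param_group:
                    grouped_params.append({"params": param_group, "lr": param_group_lr})
                param_group, param_group_lr = [], lr
            param_group.append(p)
            # close the group once it reaches the target size
            if len(param_group) == params_per_group:
                grouped_params.append({"params": param_group, "lr": param_group_lr})
                param_group, param_group_lr = [], -1
    if param_group:
        grouped_params.append({"params": param_group, "lr": param_group_lr})
    return grouped_params

# toy check: 5 "U-Net" params at lr=1e-4 and 3 "text encoder" params at lr=1e-5,
# requested 2 groups -> at least 3 groups, because the lr changes part-way through
toy = [{"params": list(range(5)), "lr": 1e-4}, {"params": list(range(5, 8)), "lr": 1e-5}]
for g in split_into_fused_groups(toy, 2):
    print(len(g["params"]), g["lr"])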
@@ -388,7 +398,6 @@ def train(args):
             optimizers.append(optimizer)
         optimizer = optimizers[0]  # avoid error in the following code
 
-        print(len(grouped_params))
         logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
 
     else:
@@ -420,6 +429,7 @@ def train(args):
 
     # lr schedulerを用意する (prepare the lr scheduler)
     if args.fused_optimizer_groups:
+        # prepare lr schedulers for each optimizer
         lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
         lr_scheduler = lr_schedulers[0]  # avoid error in the following code
     else:
@@ -472,6 +482,7 @@ def train(args):
         optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
     if args.fused_backward_pass:
+        # use fused optimizer for backward pass: other optimizers will be supported in the future
         import library.adafactor_fused
 
         library.adafactor_fused.patch_adafactor_fused(optimizer)
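For context, the fused backward pass follows the PyTorch tutorial linked in the first hunk: each parameter's optimizer steps inside a post-accumulate-grad hook, so full-model gradients never have to be held in memory at once. The sketch below shows that underlying pattern with plain SGD on a toy model; it is not the repo's patch_adafactor_fused implementation, just the idea it builds on.

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))

# one tiny optimizer per parameter, as in the tutorial
optimizer_dict = {p: torch.optim.SGD([p], lr=1e-2) for p in model.parameters()}

def optimizer_hook(parameter: torch.Tensor) -> None:
    # called right after .grad is ready for this parameter during backward()
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad(set_to_none=True)

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

# training step: no explicit optimizer.step()/zero_grad() after backward()
x, y = torch.randn(4, 8), torch.randn(4, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()  # parameters are updated inside the hooks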
@@ -488,16 +499,20 @@ def train(args):
                     parameter.register_post_accumulate_grad_hook(__grad_hook)
 
     elif args.fused_optimizer_groups:
+        # prepare for additional optimizers and lr schedulers
         for i in range(1, len(optimizers)):
             optimizers[i] = accelerator.prepare(optimizers[i])
             lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
 
+        # counters are used to determine when to step the optimizer
         global optimizer_hooked_count
         global num_parameters_per_group
         global parameter_optimizer_map
 
         optimizer_hooked_count = {}
         num_parameters_per_group = [0] * len(optimizers)
         parameter_optimizer_map = {}
 
         for opt_idx, optimizer in enumerate(optimizers):
             for param_group in optimizer.param_groups:
                 for parameter in param_group["params"]:
@@ -511,7 +526,7 @@ def train(args):
                             optimizer_hooked_count[i] += 1
                             if optimizer_hooked_count[i] == num_parameters_per_group[i]:
                                 optimizers[i].step()
-                                optimizers[i].zero_grad()
+                                optimizers[i].zero_grad(set_to_none=True)
 
                         parameter.register_post_accumulate_grad_hook(optimizer_hook)
                         parameter_optimizer_map[parameter] = opt_idx
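The two hunks above wire one hook per parameter and use counters so that each group's optimizer steps only after the last gradient of that group has been accumulated. A self-contained sketch of the same bookkeeping, without Accelerate and with an arbitrary toy model, group split, and AdamW learning rate, might look like this:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 1))
params = list(model.parameters())
groups = [params[:2], params[2:]]  # pretend these came from the grouping step
optimizers = [torch.optim.AdamW(g, lr=1e-3) for g in groups]

optimizer_hooked_count = {}
num_parameters_per_group = [0] * len(optimizers)
parameter_optimizer_map = {}

for opt_idx, optimizer in enumerate(optimizers):
    for param_group in optimizer.param_groups:
        for parameter in param_group["params"]:
            if parameter.requires_grad:

                def optimizer_hook(parameter: torch.Tensor):
                    # count this parameter's gradient; step once the whole group is done
                    i = parameter_optimizer_map[parameter]
                    optimizer_hooked_count[i] += 1
                    if optimizer_hooked_count[i] == num_parameters_per_group[i]:
                        optimizers[i].step()
                        optimizers[i].zero_grad(set_to_none=True)

                parameter.register_post_accumulate_grad_hook(optimizer_hook)
                parameter_optimizer_map[parameter] = opt_idx
                num_parameters_per_group[opt_idx] += 1

# one training step: reset the counters, then backward() drives the updates
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}
loss = model(torch.randn(2, 4)).sum()
loss.backward()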
@@ -593,7 +608,7 @@ def train(args):
             current_step.value = global_step
 
             if args.fused_optimizer_groups:
-                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}
+                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # reset counter for each step
 
             with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:
@@ -725,15 +740,15 @@ def train(args):
                         accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                     optimizer.step()
-                elif args.fused_optimizer_groups:
-                    for i in range(1, len(optimizers)):
-                        lr_schedulers[i].step()
-
-                lr_scheduler.step()
-
-                if not (args.fused_backward_pass or args.fused_optimizer_groups):
-                    optimizer.zero_grad(set_to_none=True)
+                    lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)
+                else:
+                    # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
+                    lr_scheduler.step()
+                    if args.fused_optimizer_groups:
+                        for i in range(1, len(optimizers)):
+                            lr_schedulers[i].step()
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
                     progress_bar.update(1)
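Continuing the counter sketch after the earlier hunk (it reuses that sketch's model, optimizers, and installed hooks): with step() and zero_grad() fused into the backward hooks, the refactored loop only has to reset the per-group counters and advance the lr schedulers each iteration. LambdaLR is an arbitrary stand-in for the repo's scheduler setup.

lr_schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, lambda _: 1.0) for opt in optimizers]
lr_scheduler = lr_schedulers[0]

for _ in range(2):  # two toy training steps
    optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # reset counter for each step
    loss = model(torch.randn(2, 4)).sum()
    loss.backward()  # optimizer.step() and zero_grad() run inside the hooks
    lr_scheduler.step()
    for i in range(1, len(optimizers)):
        lr_schedulers[i].step()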