Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
add experimental option to fuse params to optimizer groups
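The new --fused_optimizer_groups N option (experimental) splits the trainable parameters into roughly N equal-sized groups, creates one optimizer and one lr scheduler per group, and registers a post-accumulate-grad hook on every trainable parameter so that a group's optimizer steps and zeroes its gradients as soon as all gradients of that group have been accumulated during backward. Gradients are therefore released group by group instead of being held until a single optimizer.step() at the end of the training step. Unlike --fused_backward_pass, which relies on the patched Adafactor in library.adafactor_fused, each group here uses an ordinary optimizer from train_util.get_optimizer, so the grouping itself does not depend on a particular optimizer implementation.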
Changed file: sdxl_train.py (114 lines changed)
@@ -345,8 +345,8 @@ def train(args):

     # calculate number of trainable parameters
     n_params = 0
-    for params in params_to_optimize:
-        for p in params["params"]:
+    for group in params_to_optimize:
+        for p in group["params"]:
             n_params += p.numel()

     accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
@@ -355,7 +355,44 @@ def train(args):

     # 学習に必要なクラスを準備する
     accelerator.print("prepare optimizer, data loader etc.")
-    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+
+    if args.fused_optimizer_groups:
+        # calculate total number of parameters
+        n_total_params = sum(len(params["params"]) for params in params_to_optimize)
+        params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
+
+        # split params into groups
+        grouped_params = []
+        param_group = []
+        param_group_lr = -1
+        for group in params_to_optimize:
+            lr = group["lr"]
+            for p in group["params"]:
+                if lr != param_group_lr:
+                    if param_group:
+                        grouped_params.append({"params": param_group, "lr": param_group_lr})
+                        param_group = []
+                    param_group_lr = lr
+                param_group.append(p)
+                if len(param_group) == params_per_group:
+                    grouped_params.append({"params": param_group, "lr": param_group_lr})
+                    param_group = []
+                    param_group_lr = -1
+        if param_group:
+            grouped_params.append({"params": param_group, "lr": param_group_lr})
+
+        # prepare optimizers for each group
+        optimizers = []
+        for group in grouped_params:
+            _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
+            optimizers.append(optimizer)
+        optimizer = optimizers[0]  # avoid error in the following code
+
+        print(len(grouped_params))
+        logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
+
+    else:
+        _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)

     # dataloaderを準備する
     # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
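For reference, the grouping logic in the hunk above can be read as a standalone helper (split_into_groups is a hypothetical name, not part of the commit). It counts parameter tensors rather than elements, so the resulting groups are only roughly balanced, and a new group is started whenever the learning rate changes so that every optimizer sees a single lr:

import math

def split_into_groups(params_to_optimize, num_groups):
    # params_to_optimize is a list of {"params": [...], "lr": float} dicts,
    # as prepared by the training script before the optimizer is created
    n_total_params = sum(len(g["params"]) for g in params_to_optimize)
    params_per_group = math.ceil(n_total_params / num_groups)

    grouped_params = []
    param_group = []
    param_group_lr = -1
    for group in params_to_optimize:
        lr = group["lr"]
        for p in group["params"]:
            if lr != param_group_lr:
                # lr changed: close the current group so each optimizer keeps one lr
                if param_group:
                    grouped_params.append({"params": param_group, "lr": param_group_lr})
                    param_group = []
                param_group_lr = lr
            param_group.append(p)
            if len(param_group) == params_per_group:
                # size limit reached: close the group and start a new one
                grouped_params.append({"params": param_group, "lr": param_group_lr})
                param_group = []
                param_group_lr = -1
    if param_group:
        grouped_params.append({"params": param_group, "lr": param_group_lr})
    return grouped_params

# toy check: 10 + 30 parameter "tensors" with two different lrs, fanned out to 4 groups
demo = [{"params": list(range(10)), "lr": 1e-5}, {"params": list(range(10, 40)), "lr": 1e-4}]
print([(len(g["params"]), g["lr"]) for g in split_into_groups(demo, 4)])
# -> [(10, 1e-05), (10, 0.0001), (10, 0.0001), (10, 0.0001)]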
@@ -382,7 +419,11 @@ def train(args):
     train_dataset_group.set_max_train_steps(args.max_train_steps)

     # lr schedulerを用意する
-    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
+    if args.fused_optimizer_groups:
+        lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
+        lr_scheduler = lr_schedulers[0]  # avoid error in the following code
+    else:
+        lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

     # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
     if args.full_fp16:
@@ -432,10 +473,12 @@ def train(args):

     if args.fused_backward_pass:
         import library.adafactor_fused

         library.adafactor_fused.patch_adafactor_fused(optimizer)
         for param_group in optimizer.param_groups:
             for parameter in param_group["params"]:
                 if parameter.requires_grad:

                     def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                         if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                             accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
@@ -444,6 +487,36 @@ def train(args):

                     parameter.register_post_accumulate_grad_hook(__grad_hook)

+    elif args.fused_optimizer_groups:
+        for i in range(1, len(optimizers)):
+            optimizers[i] = accelerator.prepare(optimizers[i])
+            lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
+
+        global optimizer_hooked_count
+        global num_parameters_per_group
+        global parameter_optimizer_map
+        optimizer_hooked_count = {}
+        num_parameters_per_group = [0] * len(optimizers)
+        parameter_optimizer_map = {}
+        for opt_idx, optimizer in enumerate(optimizers):
+            for param_group in optimizer.param_groups:
+                for parameter in param_group["params"]:
+                    if parameter.requires_grad:
+
+                        def optimizer_hook(parameter: torch.Tensor):
+                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
+
+                            i = parameter_optimizer_map[parameter]
+                            optimizer_hooked_count[i] += 1
+                            if optimizer_hooked_count[i] == num_parameters_per_group[i]:
+                                optimizers[i].step()
+                                optimizers[i].zero_grad()
+
+                        parameter.register_post_accumulate_grad_hook(optimizer_hook)
+                        parameter_optimizer_map[parameter] = opt_idx
+                        num_parameters_per_group[opt_idx] += 1
+
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
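The per-group stepping above builds on Tensor.register_post_accumulate_grad_hook, available in PyTorch 2.1 and later: the hook fires for each parameter once its gradient has been fully accumulated during backward, so an optimizer can step as soon as the last parameter of its group reports in. A minimal self-contained sketch of that mechanism outside the training script (names are illustrative, not from the repository):

import torch

# two small parameter "groups", each with its own optimizer
model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.Linear(16, 1))
params = list(model.parameters())
groups = [params[:2], params[2:]]                                  # [w1, b1], [w2, b2]
optimizers = [torch.optim.SGD(g, lr=1e-3) for g in groups]

param_to_group = {p: i for i, g in enumerate(groups) for p in g}   # parameter -> group index
hooked_count = {i: 0 for i in range(len(groups))}                  # grads seen per group this step

def optimizer_hook(param: torch.Tensor):
    i = param_to_group[param]
    hooked_count[i] += 1
    if hooked_count[i] == len(groups[i]):   # last gradient of this group has arrived
        optimizers[i].step()                # step during backward ...
        optimizers[i].zero_grad()           # ... and release the group's gradients

for p in params:
    p.register_post_accumulate_grad_hook(optimizer_hook)

x, y = torch.randn(4, 8), torch.randn(4, 1)
for _ in range(3):
    for i in hooked_count:                  # reset the counters at the start of each step
        hooked_count[i] = 0
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()                         # hooks fire here; no explicit optimizer.step() afterwards

In the commit itself the bookkeeping lives in the globals optimizer_hooked_count, num_parameters_per_group and parameter_optimizer_map, and the per-group counters are reset at the top of every training step, which is what the next hunk adds.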
@@ -518,6 +591,10 @@ def train(args):

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step
+
+            if args.fused_optimizer_groups:
+                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}
+
            with accelerator.accumulate(*training_models):
                if "latents" in batch and batch["latents"] is not None:
                    latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
@@ -596,7 +673,9 @@ def train(args):

                # Sample noise, sample a random timestep for each image, and add noise to the latents,
                # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )

                noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

@@ -614,7 +693,9 @@ def train(args):
                    or args.masked_loss
                ):
                    # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    )
                    if args.masked_loss:
                        loss = apply_masked_loss(loss, batch)
                    loss = loss.mean([1, 2, 3])
@@ -630,11 +711,13 @@ def train(args):

                    loss = loss.mean()  # mean over batch dimension
                else:
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+                    )

                accelerator.backward(loss)

-                if not args.fused_backward_pass:
+                if not (args.fused_backward_pass or args.fused_optimizer_groups):
                    if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                        params_to_clip = []
                        for m in training_models:
@@ -642,9 +725,14 @@ def train(args):
                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                    optimizer.step()
+                elif args.fused_optimizer_groups:
+                    for i in range(1, len(optimizers)):
+                        lr_schedulers[i].step()

                lr_scheduler.step()
-                optimizer.zero_grad(set_to_none=True)
+
+                if not (args.fused_backward_pass or args.fused_optimizer_groups):
+                    optimizer.zero_grad(set_to_none=True)

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
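Note the reworked end of the step: with fused optimizer groups there is no explicit optimizer.step() or zero_grad() here, because the hooks already performed both during backward. Only the lr schedulers still need stepping, so schedulers 1..N-1 are stepped in the new elif branch while lr_scheduler (an alias for lr_schedulers[0]) keeps going through the existing call.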
@@ -753,7 +841,7 @@ def train(args):

    accelerator.end_training()

    if args.save_state or args.save_state_on_train_end:
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator  # この後メモリを使うのでこれは消す
@@ -822,6 +910,12 @@ def setup_parser() -> argparse.ArgumentParser:
        help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
        + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
    )
+    parser.add_argument(
+        "--fused_optimizer_groups",
+        type=int,
+        default=None,
+        help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
+    )
    return parser

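Usage note: the option is off by default (default=None); passing an integer on the sdxl_train.py command line, e.g. --fused_optimizer_groups 10, enables the grouped optimizers. The example value is illustrative only; the commit does not document a recommended setting.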