mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Suppor LR graphs for each block, base lr
This commit is contained in:
@@ -32,16 +32,31 @@ from library.custom_train_functions import apply_snr_weight
|
||||
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||
|
||||
if args.network_train_unet_only:
|
||||
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
|
||||
elif args.network_train_text_encoder_only:
|
||||
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
||||
else:
|
||||
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
|
||||
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
|
||||
lrs = lr_scheduler.get_last_lr()
|
||||
|
||||
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
|
||||
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
||||
if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block)
|
||||
if args.network_train_unet_only:
|
||||
logs["lr/unet"] = float(lrs[0])
|
||||
elif args.network_train_text_encoder_only:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
else:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder
|
||||
|
||||
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
|
||||
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
||||
else:
|
||||
idx = 0
|
||||
if not args.network_train_unet_only:
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
idx = 1
|
||||
|
||||
for i in range(idx, len(lrs)):
|
||||
logs[f"lr/block{i}"] = float(lrs[i])
|
||||
if args.optimizer_type.lower() == "DAdaptation".lower():
|
||||
logs[f"lr/d*lr/block{i}"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||
)
|
||||
|
||||
return logs
|
||||
|
||||
@@ -99,10 +114,10 @@ def train(args):
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
|
||||
current_epoch = Value('i',0)
|
||||
current_step = Value('i',0)
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
|
||||
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
@@ -146,7 +161,6 @@ def train(args):
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
|
||||
@@ -214,7 +228,9 @@ def train(args):
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
if is_main_process:
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
@@ -518,7 +534,7 @@ def train(args):
|
||||
for epoch in range(num_train_epochs):
|
||||
if is_main_process:
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch+1
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
metadata["ss_epoch"] = str(epoch + 1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user