Add Validation loss for LoRA training
Changed file: train_network.py (131 changed lines)
@@ -9,6 +9,7 @@ import json
 from multiprocessing import Value
 from typing import Any, List
 import toml
+import itertools

 from tqdm import tqdm

@@ -114,7 +115,7 @@ class NetworkTrainer:
                 )
                 if (
                     args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
-                ):
+                ):
                     logs[f"lr/d*lr/group{i}"] = (
                         optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
                     )
@@ -373,10 +374,11 @@ class NetworkTrainer:
             }

             blueprint = blueprint_generator.generate(user_config, args)
-            train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+            train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
         else:
             # use arbitrary dataset class
             train_dataset_group = train_util.load_arbitrary_dataset(args)
+            val_dataset_group = None

         current_epoch = Value("i", 0)
         current_step = Value("i", 0)
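With the config-based path, generate_dataset_group_by_blueprint now returns a (train, val) pair, driven by the --validation_split and --validation_seed arguments added at the end of this diff; the arbitrary-dataset path simply gets no validation group. Below is a minimal sketch of the split semantics those arguments imply; the helper name and split strategy are illustrative assumptions, not the actual config_util code.

# Illustrative sketch only (NOT the real config_util implementation): hold out a fraction
# of the training images as a validation set, deterministically when a seed is given.
import random
from typing import List, Optional, Tuple

def split_train_val(image_paths: List[str], validation_split: float, validation_seed: Optional[int]) -> Tuple[List[str], List[str]]:
    paths = list(image_paths)
    random.Random(validation_seed).shuffle(paths)  # deterministic shuffle when a seed is given
    n_val = int(len(paths) * validation_split)     # e.g. validation_split=0.1 holds out 10%
    return paths[n_val:], paths[:n_val]            # (train, val)

# Hypothetical file names, for illustration only.
train_paths, val_paths = split_train_val([f"img_{i:03}.png" for i in range(100)], 0.1, 42)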
@@ -398,6 +400,11 @@ class NetworkTrainer:
                 train_dataset_group.is_latent_cacheable()
             ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

+            if val_dataset_group is not None:
+                assert (
+                    val_dataset_group.is_latent_cacheable()
+                ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+
         self.assert_extra_args(args, train_dataset_group)  # may change some args

         # acceleratorを準備する
@@ -444,6 +451,8 @@ class NetworkTrainer:
             vae.eval()

             train_dataset_group.new_cache_latents(vae, accelerator)
+            if val_dataset_group is not None:
+                val_dataset_group.new_cache_latents(vae, accelerator)

             vae.to("cpu")
             clean_memory_on_device(accelerator.device)
@@ -459,6 +468,8 @@ class NetworkTrainer:
         if text_encoder_outputs_caching_strategy is not None:
             strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
         self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, train_dataset_group, weight_dtype)
+        if val_dataset_group is not None:
+            self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype)

         # prepare network
         net_kwargs = {}
@@ -567,6 +578,8 @@ class NetworkTrainer:
         # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
         # some strategies can be None
         train_dataset_group.set_current_strategies()
+        if val_dataset_group is not None:
+            val_dataset_group.set_current_strategies()

         # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
         n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
@@ -580,6 +593,17 @@ class NetworkTrainer:
             persistent_workers=args.persistent_data_loader_workers,
         )

+        val_dataloader = torch.utils.data.DataLoader(
+            val_dataset_group if val_dataset_group is not None else [],
+            batch_size=1,
+            shuffle=False,
+            collate_fn=collator,
+            num_workers=n_workers,
+            persistent_workers=args.persistent_data_loader_workers,
+        )
+
+        cyclic_val_dataloader = itertools.cycle(val_dataloader)
+
         # 学習ステップ数を計算する
         if args.max_train_epochs is not None:
             args.max_train_steps = args.max_train_epochs * math.ceil(
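The validation DataLoader mirrors the training loader settings but never shuffles, and it is wrapped in itertools.cycle so that next(cyclic_val_dataloader) keeps yielding batches across epochs without raising StopIteration. A minimal standalone sketch of that pattern follows; the toy dataset is illustrative only.

# Minimal sketch: cycling a finite DataLoader so next() never raises StopIteration.
# The toy list dataset stands in for val_dataset_group.
import itertools
import torch

val_loader = torch.utils.data.DataLoader([1, 2, 3], batch_size=1, shuffle=False)
cyclic = itertools.cycle(val_loader)

for _ in range(5):          # more draws than the loader has batches
    batch = next(cyclic)    # wraps around: 1, 2, 3, 1, 2

Note that itertools.cycle caches the batches from its first full pass and replays that cached copy on later passes rather than re-iterating the DataLoader.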
@@ -592,6 +616,10 @@ class NetworkTrainer:
         # データセット側にも学習ステップを送信
         train_dataset_group.set_max_train_steps(args.max_train_steps)

+        # Not sure if this is needed here.
+        # if val_dataset_group is not None:
+        #     val_dataset_group.set_max_train_steps(args.max_train_steps)
+
         # lr schedulerを用意する
         lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

@@ -1064,7 +1092,11 @@ class NetworkTrainer:
             )

         loss_recorder = train_util.LossRecorder()
+        # val_loss_recorder = train_util.LossRecorder()
+
         del train_dataset_group
+        if val_dataset_group is not None:
+            del val_dataset_group

         # callback for step start
         if hasattr(accelerator.unwrap_model(network), "on_step_start"):
@@ -1308,6 +1340,77 @@ class NetworkTrainer:
                     )
                     accelerator.log(logs, step=global_step)

+                if len(val_dataloader) > 0:
+                    if ((args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps):
+                        accelerator.print("\nValidating バリデーション処理...")
+
+                        total_loss = 0.0
+
+                        with torch.no_grad():
+                            validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
+                            for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションステップ"):
+                                batch = next(cyclic_val_dataloader)
+
+                                timesteps_list = [10, 350, 500, 650, 990]
+
+                                val_loss = 0.0
+
+                                for fixed_timesteps in timesteps_list:
+                                    with torch.set_grad_enabled(False), accelerator.autocast():
+                                        noise = torch.randn_like(latents, device=latents.device)
+                                        b_size = latents.shape[0]
+
+                                        timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu")
+                                        timesteps = timesteps.long().to(latents.device)
+
+                                        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                                    with accelerator.autocast():
+                                        noise_pred = self.call_unet(
+                                            args,
+                                            accelerator,
+                                            unet,
+                                            noisy_latents.requires_grad_(False),
+                                            timesteps,
+                                            text_encoder_conds,
+                                            batch,
+                                            weight_dtype,
+                                        )
+
+                                    if args.v_parameterization:
+                                        # v-parameterization training
+                                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                                    else:
+                                        target = noise
+
+                                    huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+                                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
+                                    if weighting is not None:
+                                        loss = loss * weighting
+                                    if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
+                                        loss = apply_masked_loss(loss, batch)
+                                    loss = loss.mean([1, 2, 3])
+
+                                    # min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
+                                    loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
+
+                                    loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
+
+                                    val_loss += loss / len(timesteps_list)
+
+                                total_loss += val_loss.detach().item()
+
+                        current_val_loss = total_loss / validation_steps
+                        # val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_val_loss)
+
+                        if len(accelerator.trackers) > 0:
+                            logs = {"loss/current_val_loss": current_val_loss}
+                            accelerator.log(logs, step=global_step)
+
+                        # avr_loss: float = val_loss_recorder.moving_average
+                        # logs = {"loss/average_val_loss": avr_loss}
+                        # accelerator.log(logs, step=global_step)
+
                 if global_step >= args.max_train_steps:
                     break

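The validation pass runs whenever --validation_every_n_step divides the global step (or, when that flag is unset, at the last step of each epoch, and always once max_train_steps is reached). Unlike training, which draws random timesteps, each validation batch is evaluated at the five fixed timesteps 10, 350, 500, 650 and 990 and the per-timestep losses are averaged; those per-batch averages are then accumulated over the validation steps and divided by their count to give the value logged as loss/current_val_loss. A small arithmetic sketch of that averaging, with made-up loss values, follows.

# Made-up per-timestep losses for a single validation batch, illustrating the averaging above.
timesteps_list = [10, 350, 500, 650, 990]
per_timestep_losses = [0.12, 0.09, 0.08, 0.07, 0.05]

val_loss = sum(l / len(timesteps_list) for l in per_timestep_losses)  # 0.082, mean over the fixed timesteps
total_loss = val_loss                  # accumulated over validation batches (only one batch here)
validation_steps = 1
current_val_loss = total_loss / validation_steps  # 0.082, the value logged as loss/current_val_loss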
@@ -1496,6 +1599,30 @@ def setup_parser() -> argparse.ArgumentParser:
         help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
         + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
     )
+    parser.add_argument(
+        "--validation_seed",
+        type=int,
+        default=None,
+        help="Validation seed / 検証シード"
+    )
+    parser.add_argument(
+        "--validation_split",
+        type=float,
+        default=0.0,
+        help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合"
+    )
+    parser.add_argument(
+        "--validation_every_n_step",
+        type=int,
+        default=None,
+        help="Number of train steps for counting validation loss. By default, validation per train epoch is performed / 学習エポックごとに検証を行う場合はNoneを指定する"
+    )
+    parser.add_argument(
+        "--max_validation_steps",
+        type=int,
+        default=None,
+        help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset / 検証データセット全体を検証する場合はNoneを指定する"
+    )
     # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
     # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
     # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
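One possible invocation combining the new flags with an otherwise ordinary LoRA run; everything except the four validation arguments is a placeholder drawn from the existing sd-scripts CLI, not part of this commit.

accelerate launch train_network.py \
  --pretrained_model_name_or_path=/path/to/model.safetensors \
  --dataset_config=dataset_config.toml \
  --network_module=networks.lora \
  --output_dir=./output \
  --validation_split=0.1 \
  --validation_seed=42 \
  --validation_every_n_step=100 \
  --max_validation_steps=20

With --validation_every_n_step omitted, validation instead runs once at the end of each epoch; with --max_validation_steps omitted, it walks the entire validation dataloader each time.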