mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add validation loss
This commit is contained in:
117
train_network.py
117
train_network.py
@@ -345,8 +345,21 @@ class NetworkTrainer:
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
if args.validation_ratio > 0.0:
|
||||
train_ratio = 1 - args.validation_ratio
|
||||
validation_ratio = args.validation_ratio
|
||||
train, val = torch.utils.data.random_split(
|
||||
train_dataset_group,
|
||||
[train_ratio, validation_ratio]
|
||||
)
|
||||
print(f"split dataset by ratio: train {train_ratio}, validation {validation_ratio}")
|
||||
print(f"train images: {len(train)}, validation images: {len(val)}")
|
||||
else:
|
||||
train = train_dataset_group
|
||||
val = []
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
train,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
@@ -354,6 +367,15 @@ class NetworkTrainer:
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val,
|
||||
shuffle=False,
|
||||
batch_size=1,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
@@ -711,6 +733,8 @@ class NetworkTrainer:
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
val_loss_recorder = train_util.LossRecorder()
|
||||
|
||||
del train_dataset_group
|
||||
|
||||
# callback for step start
|
||||
@@ -752,6 +776,8 @@ class NetworkTrainer:
|
||||
|
||||
network.on_epoch_start(text_encoder, unet)
|
||||
|
||||
# TRAINING
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(network):
|
||||
@@ -877,6 +903,87 @@ class NetworkTrainer:
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
# VALIDATION
|
||||
|
||||
if len(val_dataloader) > 0:
|
||||
print("Validating バリデーション処理...")
|
||||
|
||||
with torch.no_grad():
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
else:
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(device=accelerator.device, dtype=vae_dtype)).latent_dist.sample()
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
||||
latents = latents * self.vae_scale_factor
|
||||
b_size = latents.shape[0]
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
text_encoder_conds = get_weighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
accelerator.device,
|
||||
args.max_token_length // 75 if args.max_token_length else 1,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
else:
|
||||
text_encoder_conds = self.get_text_cond(
|
||||
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
|
||||
)
|
||||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
|
||||
args, noise_scheduler, latents
|
||||
)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
||||
)
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
|
||||
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
|
||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
|
||||
if len(val_dataloader) > 0:
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/validation": avr_loss}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
@@ -999,6 +1106,14 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--validation_ratio",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Ratio for validation images out of the training dataset"
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user