mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add validation loss
This commit is contained in:
@@ -4736,6 +4736,10 @@ class collator_class:
|
|||||||
else:
|
else:
|
||||||
dataset = self.dataset
|
dataset = self.dataset
|
||||||
|
|
||||||
|
# If we split a dataset we will get a Subset
|
||||||
|
if type(dataset) is torch.utils.data.Subset:
|
||||||
|
dataset = dataset.dataset
|
||||||
|
|
||||||
# set epoch and step
|
# set epoch and step
|
||||||
dataset.set_current_epoch(self.current_epoch.value)
|
dataset.set_current_epoch(self.current_epoch.value)
|
||||||
dataset.set_current_step(self.current_step.value)
|
dataset.set_current_step(self.current_step.value)
|
||||||
|
|||||||
117
train_network.py
117
train_network.py
@@ -345,8 +345,21 @@ class NetworkTrainer:
|
|||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||||
|
|
||||||
|
if args.validation_ratio > 0.0:
|
||||||
|
train_ratio = 1 - args.validation_ratio
|
||||||
|
validation_ratio = args.validation_ratio
|
||||||
|
train, val = torch.utils.data.random_split(
|
||||||
|
train_dataset_group,
|
||||||
|
[train_ratio, validation_ratio]
|
||||||
|
)
|
||||||
|
print(f"split dataset by ratio: train {train_ratio}, validation {validation_ratio}")
|
||||||
|
print(f"train images: {len(train)}, validation images: {len(val)}")
|
||||||
|
else:
|
||||||
|
train = train_dataset_group
|
||||||
|
val = []
|
||||||
|
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
train_dataset_group,
|
train,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=collator,
|
collate_fn=collator,
|
||||||
@@ -354,6 +367,15 @@ class NetworkTrainer:
|
|||||||
persistent_workers=args.persistent_data_loader_workers,
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
val_dataloader = torch.utils.data.DataLoader(
|
||||||
|
val,
|
||||||
|
shuffle=False,
|
||||||
|
batch_size=1,
|
||||||
|
collate_fn=collator,
|
||||||
|
num_workers=n_workers,
|
||||||
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
|
)
|
||||||
|
|
||||||
# 学習ステップ数を計算する
|
# 学習ステップ数を計算する
|
||||||
if args.max_train_epochs is not None:
|
if args.max_train_epochs is not None:
|
||||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||||
@@ -711,6 +733,8 @@ class NetworkTrainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
loss_recorder = train_util.LossRecorder()
|
loss_recorder = train_util.LossRecorder()
|
||||||
|
val_loss_recorder = train_util.LossRecorder()
|
||||||
|
|
||||||
del train_dataset_group
|
del train_dataset_group
|
||||||
|
|
||||||
# callback for step start
|
# callback for step start
|
||||||
@@ -752,6 +776,8 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
network.on_epoch_start(text_encoder, unet)
|
network.on_epoch_start(text_encoder, unet)
|
||||||
|
|
||||||
|
# TRAINING
|
||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
current_step.value = global_step
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(network):
|
with accelerator.accumulate(network):
|
||||||
@@ -877,6 +903,87 @@ class NetworkTrainer:
|
|||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# VALIDATION
|
||||||
|
|
||||||
|
if len(val_dataloader) > 0:
|
||||||
|
print("Validating バリデーション処理...")
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for val_step, batch in enumerate(val_dataloader):
|
||||||
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
|
latents = batch["latents"].to(accelerator.device)
|
||||||
|
else:
|
||||||
|
# latentに変換
|
||||||
|
latents = vae.encode(batch["images"].to(device=accelerator.device, dtype=vae_dtype)).latent_dist.sample()
|
||||||
|
|
||||||
|
# NaNが含まれていれば警告を表示し0に置き換える
|
||||||
|
if torch.any(torch.isnan(latents)):
|
||||||
|
accelerator.print("NaN found in latents, replacing with zeros")
|
||||||
|
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
||||||
|
latents = latents * self.vae_scale_factor
|
||||||
|
b_size = latents.shape[0]
|
||||||
|
|
||||||
|
# Get the text embedding for conditioning
|
||||||
|
if args.weighted_captions:
|
||||||
|
text_encoder_conds = get_weighted_text_embeddings(
|
||||||
|
tokenizer,
|
||||||
|
text_encoder,
|
||||||
|
batch["captions"],
|
||||||
|
accelerator.device,
|
||||||
|
args.max_token_length // 75 if args.max_token_length else 1,
|
||||||
|
clip_skip=args.clip_skip,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
text_encoder_conds = self.get_text_cond(
|
||||||
|
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
|
# with noise offset and/or multires noise if specified
|
||||||
|
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
|
||||||
|
args, noise_scheduler, latents
|
||||||
|
)
|
||||||
|
|
||||||
|
# Predict the noise residual
|
||||||
|
with accelerator.autocast():
|
||||||
|
noise_pred = self.call_unet(
|
||||||
|
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.v_parameterization:
|
||||||
|
# v-parameterization training
|
||||||
|
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||||
|
else:
|
||||||
|
target = noise
|
||||||
|
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||||
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
|
loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
|
||||||
|
|
||||||
|
loss = loss * loss_weights
|
||||||
|
|
||||||
|
if args.min_snr_gamma:
|
||||||
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||||
|
if args.scale_v_pred_loss_like_noise_pred:
|
||||||
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
|
if args.v_pred_like_loss:
|
||||||
|
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||||
|
|
||||||
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
current_loss = loss.detach().item()
|
||||||
|
|
||||||
|
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||||
|
|
||||||
|
if len(val_dataloader) > 0:
|
||||||
|
avr_loss: float = val_loss_recorder.moving_average
|
||||||
|
|
||||||
|
if args.logging_dir is not None:
|
||||||
|
logs = {"loss/validation": avr_loss}
|
||||||
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||||
accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
@@ -999,6 +1106,14 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--validation_ratio",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="Ratio for validation images out of the training dataset"
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user