Add validation loss

This commit is contained in:
rockerBOO
2023-11-05 12:35:46 -05:00
parent 95ae56bd22
commit 5b19bda85c
2 changed files with 120 additions and 1 deletions

View File

@@ -4736,6 +4736,10 @@ class collator_class:
else:
dataset = self.dataset
# If we split a dataset we will get a Subset
if type(dataset) is torch.utils.data.Subset:
dataset = dataset.dataset
# set epoch and step
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)

View File

@@ -345,8 +345,21 @@ class NetworkTrainer:
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
if args.validation_ratio > 0.0:
train_ratio = 1 - args.validation_ratio
validation_ratio = args.validation_ratio
train, val = torch.utils.data.random_split(
train_dataset_group,
[train_ratio, validation_ratio]
)
print(f"split dataset by ratio: train {train_ratio}, validation {validation_ratio}")
print(f"train images: {len(train)}, validation images: {len(val)}")
else:
train = train_dataset_group
val = []
train_dataloader = torch.utils.data.DataLoader(
train,
batch_size=1,
shuffle=True,
collate_fn=collator,
@@ -354,6 +367,15 @@ class NetworkTrainer:
persistent_workers=args.persistent_data_loader_workers,
)
val_dataloader = torch.utils.data.DataLoader(
val,
shuffle=False,
batch_size=1,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
@@ -711,6 +733,8 @@ class NetworkTrainer:
)
loss_recorder = train_util.LossRecorder()
val_loss_recorder = train_util.LossRecorder()
del train_dataset_group
# callback for step start
@@ -752,6 +776,8 @@ class NetworkTrainer:
network.on_epoch_start(text_encoder, unet)
# TRAINING
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(network):
@@ -877,6 +903,87 @@ class NetworkTrainer:
if global_step >= args.max_train_steps:
break
# VALIDATION
if len(val_dataloader) > 0:
print("Validating バリデーション処理...")
with torch.no_grad():
for val_step, batch in enumerate(val_dataloader):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(device=accelerator.device, dtype=vae_dtype)).latent_dist.sample()
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = latents * self.vae_scale_factor
b_size = latents.shape[0]
# Get the text embedding for conditioning
if args.weighted_captions:
text_encoder_conds = get_weighted_text_embeddings(
tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
text_encoder_conds = self.get_text_cond(
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
# Predict the noise residual
with accelerator.autocast():
noise_pred = self.call_unet(
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
)
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
current_loss = loss.detach().item()
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
if len(val_dataloader) > 0:
avr_loss: float = val_loss_recorder.moving_average
if args.logging_dir is not None:
logs = {"loss/validation": avr_loss}
accelerator.log(logs, step=epoch + 1)
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
@@ -999,6 +1106,14 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--validation_ratio",
type=float,
default=0.0,
help="Ratio for validation images out of the training dataset"
)
return parser