mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Split out the validation latents/batch, and pick a validation batch whose latents shape equals that of the training batch.
This commit is contained in:
@@ -1349,7 +1349,27 @@ class NetworkTrainer:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
||||||
for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"):
|
for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"):
|
||||||
batch = next(cyclic_val_dataloader)
|
|
||||||
|
while True:
|
||||||
|
val_batch = next(cyclic_val_dataloader)
|
||||||
|
|
||||||
|
if "latents" in val_batch and val_batch["latents"] is not None:
|
||||||
|
val_latents = val_batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
# latentに変換
|
||||||
|
val_latents = self.encode_images_to_latents(args, accelerator, vae, val_batch["images"].to(vae_dtype))
|
||||||
|
val_latents = val_latents.to(dtype=weight_dtype)
|
||||||
|
|
||||||
|
# NaNが含まれていれば警告を表示し0に置き換える
|
||||||
|
if torch.any(torch.isnan(val_latents)):
|
||||||
|
accelerator.print("NaN found in validation latents, replacing with zeros")
|
||||||
|
val_latents = torch.nan_to_num(val_latents, 0, out=val_latents)
|
||||||
|
|
||||||
|
val_latents = self.shift_scale_latents(args, val_latents)
|
||||||
|
|
||||||
|
if val_latents.shape == latents.shape:
|
||||||
|
break
|
||||||
|
|
||||||
timesteps_list = [10, 350, 500, 650, 990]
|
timesteps_list = [10, 350, 500, 650, 990]
|
||||||
|
|
||||||
@@ -1357,13 +1377,13 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
for fixed_timesteps in timesteps_list:
|
for fixed_timesteps in timesteps_list:
|
||||||
with torch.set_grad_enabled(False), accelerator.autocast():
|
with torch.set_grad_enabled(False), accelerator.autocast():
|
||||||
noise = torch.randn_like(latents, device=latents.device)
|
noise = torch.randn_like(val_latents, device=val_latents.device)
|
||||||
b_size = latents.shape[0]
|
b_size = val_latents.shape[0]
|
||||||
|
|
||||||
timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu")
|
timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu")
|
||||||
timesteps = timesteps.long().to(latents.device)
|
timesteps = timesteps.long().to(val_latents.device)
|
||||||
|
|
||||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
noisy_latents = noise_scheduler.add_noise(val_latents, noise, timesteps)
|
||||||
|
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
noise_pred = self.call_unet(
|
noise_pred = self.call_unet(
|
||||||
@@ -1373,27 +1393,16 @@ class NetworkTrainer:
|
|||||||
noisy_latents.requires_grad_(False),
|
noisy_latents.requires_grad_(False),
|
||||||
timesteps,
|
timesteps,
|
||||||
text_encoder_conds,
|
text_encoder_conds,
|
||||||
batch,
|
val_batch,
|
||||||
weight_dtype,
|
weight_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.v_parameterization:
|
if args.v_parameterization:
|
||||||
# v-parameterization training
|
# v-parameterization training
|
||||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
target = noise_scheduler.get_velocity(val_latents, noise, timesteps)
|
||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
# huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
|
||||||
# loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
|
||||||
# if weighting is not None:
|
|
||||||
# loss = loss * weighting
|
|
||||||
# if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
|
||||||
# loss = apply_masked_loss(loss, batch)
|
|
||||||
# loss = loss.mean([1, 2, 3])
|
|
||||||
|
|
||||||
# min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
|
|
||||||
# loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
||||||
|
|||||||
Reference in New Issue
Block a user