mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Change val latent loss compare
This commit is contained in:
@@ -1350,6 +1350,8 @@ class NetworkTrainer:
|
||||
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
||||
for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"):
|
||||
|
||||
val_latents = None
|
||||
|
||||
while True:
|
||||
val_batch = next(cyclic_val_dataloader)
|
||||
|
||||
@@ -1371,19 +1373,22 @@ class NetworkTrainer:
|
||||
if val_latents.shape == latents.shape:
|
||||
break
|
||||
|
||||
if val_latents is not None:
|
||||
del val_latents
|
||||
|
||||
timesteps_list = [10, 350, 500, 650, 990]
|
||||
|
||||
val_loss = 0.0
|
||||
|
||||
for fixed_timesteps in timesteps_list:
|
||||
with torch.set_grad_enabled(False), accelerator.autocast():
|
||||
noise = torch.randn_like(val_latents, device=val_latents.device)
|
||||
b_size = val_latents.shape[0]
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
b_size = latents.shape[0]
|
||||
|
||||
timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu")
|
||||
timesteps = timesteps.long().to(val_latents.device)
|
||||
timesteps = timesteps.long().to(latents.device)
|
||||
|
||||
noisy_latents = noise_scheduler.add_noise(val_latents, noise, timesteps)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
with accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
@@ -1399,7 +1404,7 @@ class NetworkTrainer:
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(val_latents, noise, timesteps)
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
|
||||
Reference in New Issue
Block a user