From cb89e0284e1a25b41401861107159e6b943ee387 Mon Sep 17 00:00:00 2001 From: Hina Chen Date: Sat, 28 Dec 2024 11:57:04 +0800 Subject: [PATCH] Change val latent loss compare --- train_network.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/train_network.py b/train_network.py index 6bce9e96..7276d5dc 100644 --- a/train_network.py +++ b/train_network.py @@ -1350,6 +1350,8 @@ class NetworkTrainer: validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) for val_step in tqdm(range(validation_steps), desc="Validation Steps バリデーションテップ"): + val_latents = None + while True: val_batch = next(cyclic_val_dataloader) @@ -1371,19 +1373,22 @@ class NetworkTrainer: if val_latents.shape == latents.shape: break + if val_latents is not None: + del val_latents + timesteps_list = [10, 350, 500, 650, 990] val_loss = 0.0 for fixed_timesteps in timesteps_list: with torch.set_grad_enabled(False), accelerator.autocast(): - noise = torch.randn_like(val_latents, device=val_latents.device) - b_size = val_latents.shape[0] + noise = torch.randn_like(latents, device=latents.device) + b_size = latents.shape[0] timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device="cpu") - timesteps = timesteps.long().to(val_latents.device) + timesteps = timesteps.long().to(latents.device) - noisy_latents = noise_scheduler.add_noise(val_latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) with accelerator.autocast(): noise_pred = self.call_unet( @@ -1399,7 +1404,7 @@ class NetworkTrainer: if args.v_parameterization: # v-parameterization training - target = noise_scheduler.get_velocity(val_latents, noise, timesteps) + target = noise_scheduler.get_velocity(latents, noise, timesteps) else: target = noise