Remove the unnecessary is_train plumbing and use apply_debiased_estimation when calculating the validation loss. This balances the influence of the different fixed timesteps on the reported validation loss, without affecting actual training results.

This commit is contained in:
gesen2egee
2024-08-02 13:15:21 +08:00
committed by GitHub
parent fde8026c2d
commit 31507b9901

View File

@@ -135,7 +135,7 @@ class NetworkTrainer:
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): def process_val_batch(self, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
total_loss = 0.0 total_loss = 0.0
timesteps_list = [10, 350, 500, 650, 990] timesteps_list = [10, 350, 500, 650, 990]
@@ -153,7 +153,7 @@ class NetworkTrainer:
latents = latents * self.vae_scale_factor latents = latents * self.vae_scale_factor
b_size = latents.shape[0] b_size = latents.shape[0]
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): with torch.set_grad_enabled(False), accelerator.autocast():
# Get the text embedding for conditioning # Get the text embedding for conditioning
if args.weighted_captions: if args.weighted_captions:
text_encoder_conds = get_weighted_text_embeddings( text_encoder_conds = get_weighted_text_embeddings(
@@ -173,7 +173,7 @@ class NetworkTrainer:
# with noise offset and/or multires noise if specified # with noise offset and/or multires noise if specified
for fixed_timesteps in timesteps_list: for fixed_timesteps in timesteps_list:
with torch.set_grad_enabled(is_train), accelerator.autocast(): with torch.set_grad_enabled(False), accelerator.autocast():
noise = torch.randn_like(latents, device=latents.device) noise = torch.randn_like(latents, device=latents.device)
b_size = latents.shape[0] b_size = latents.shape[0]
timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device)
@@ -189,6 +189,7 @@ class NetworkTrainer:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3]) loss = loss.mean([1, 2, 3])
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
total_loss += loss total_loss += loss
@@ -885,8 +886,7 @@ class NetworkTrainer:
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
current_step.value = global_step current_step.value = global_step
with accelerator.accumulate(training_model): with accelerator.accumulate(training_model):
on_step_start(text_encoder, unet) on_step_start(text_encoder, unet)
is_train = True
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else: else:
@@ -911,7 +911,7 @@ class NetworkTrainer:
# print(f"set multiplier: {multipliers}") # print(f"set multiplier: {multipliers}")
accelerator.unwrap_model(network).set_multiplier(multipliers) accelerator.unwrap_model(network).set_multiplier(multipliers)
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning # Get the text embedding for conditioning
if args.weighted_captions: if args.weighted_captions:
text_encoder_conds = get_weighted_text_embeddings( text_encoder_conds = get_weighted_text_embeddings(
@@ -941,7 +941,7 @@ class NetworkTrainer:
t.requires_grad_(True) t.requires_grad_(True)
# Predict the noise residual # Predict the noise residual
with torch.set_grad_enabled(is_train), accelerator.autocast(): with accelerator.autocast():
noise_pred = self.call_unet( noise_pred = self.call_unet(
args, args,
accelerator, accelerator,
@@ -1040,10 +1040,9 @@ class NetworkTrainer:
total_loss = 0.0 total_loss = 0.0
with torch.no_grad(): with torch.no_grad():
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
for val_step in tqdm(range(validation_steps), desc='Validation Steps'): for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
is_train = False
batch = next(cyclic_val_dataloader) batch = next(cyclic_val_dataloader)
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
total_loss += loss.detach().item() total_loss += loss.detach().item()
current_loss = total_loss / validation_steps current_loss = total_loss / validation_steps
val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)