From c89252101e8e8bd74cb3ab09ae33b548fd828e15 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 16:27:36 -0500 Subject: [PATCH] Add process_batch for train_network --- train_network.py | 211 ++++++++++++++++++----------------------------- 1 file changed, 82 insertions(+), 129 deletions(-) diff --git a/train_network.py b/train_network.py index f9e5debd..387b94b1 100644 --- a/train_network.py +++ b/train_network.py @@ -130,6 +130,75 @@ class NetworkTrainer: def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) + def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * self.vae_scale_factor + b_size = latents.shape[0] + + with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): + # Get the text embedding for conditioning + if args.weighted_captions: + text_encoder_conds = get_weighted_text_embeddings( + tokenizers[0], + text_encoders[0], + batch["captions"], + accelerator.device, + args.max_token_length // 75 if args.max_token_length else 1, + clip_skip=args.clip_skip, + ) + else: + text_encoder_conds = self.get_text_cond( + args, accelerator, batch, tokenizers, text_encoders, weight_dtype + ) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) + + # Predict the noise residual + with torch.set_grad_enabled(is_train), accelerator.autocast(): + noise_pred = self.call_unet( + args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype + ) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + return loss + + def train(self, args): session_id = random.randint(0, 2**32) training_started_at = time.time() @@ -777,71 +846,8 @@ class NetworkTrainer: current_step.value = global_step with accelerator.accumulate(network): on_step_start(text_encoder, unet) - - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) - latents = latents * self.vae_scale_factor - b_size = latents.shape[0] - - with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): - # Get the text embedding for conditioning - if args.weighted_captions: - text_encoder_conds = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - text_encoder_conds = self.get_text_cond( - args, accelerator, batch, tokenizers, text_encoders, weight_dtype - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + is_train = True + loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=train_text_encoder) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -893,7 +899,7 @@ class NetworkTrainer: if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) - accelerator.log(logs) + accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break @@ -905,80 +911,27 @@ class NetworkTrainer: with torch.no_grad(): for val_step, batch in enumerate(val_dataloader): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(device=accelerator.device, dtype=vae_dtype)).latent_dist.sample() - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) - latents = latents * self.vae_scale_factor - b_size = latents.shape[0] - - # Get the text embedding for conditioning - if args.weighted_captions: - text_encoder_conds = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - text_encoder_conds = self.get_text_cond( - args, accelerator, batch, tokenizers, text_encoders, weight_dtype - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps( - args, noise_scheduler, latents - ) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = self.call_unet( - args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight - - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) - if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + is_train = False + loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) + if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average + logs = {"loss/validation_current": current_loss} + accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) + if len(val_dataloader) > 0: if args.logging_dir is not None: avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/validation": avr_loss} + logs = {"loss/validation_average": avr_loss} accelerator.log(logs, step=epoch + 1) if args.logging_dir is not None: - logs = {"loss/epoch": loss_recorder.moving_average} + # logs = {"loss/epoch": loss_recorder.moving_average} + logs = {"loss/epoch_average": loss_recorder.moving_average} accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone()