diff --git a/train_network.py b/train_network.py index 61e6369a..5a80d825 100644 --- a/train_network.py +++ b/train_network.py @@ -318,8 +318,27 @@ class NetworkTrainer: # endregion - def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True) -> torch.Tensor: - + def process_batch( + self, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, + tokenize_strategy: strategy_sd.SdTokenizeStrategy, + is_train=True, + train_text_encoder=True, + train_unet=True + ) -> torch.Tensor: + """ + Process a batch for the network + """ with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) @@ -334,7 +353,6 @@ class NetworkTrainer: latents = self.shift_scale_latents(args, latents) - text_encoder_conds = [] text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: @@ -371,13 +389,6 @@ class NetworkTrainer: if encoded_text_encoder_conds[i] is not None: text_encoder_conds[i] = encoded_text_encoder_conds[i] - batch_size = latents.shape[0] - - - # Predict the noise residual - # and add noise to the latents - # with noise offset and/or multires noise if specified - # sample noise, call unet, get target noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( args, @@ -1288,7 +1299,23 @@ class NetworkTrainer: # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet) + loss = self.process_batch(batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=True, + train_text_encoder=train_text_encoder, + train_unet=train_unet + ) + accelerator.backward(loss) if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually @@ -1366,12 +1393,26 @@ class NetworkTrainer: if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False) - + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False + ) + val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) val_progress_bar.update(1) val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average }) - + if is_tracking: logs = {"loss/current_val_loss": loss.detach().item()} # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step) @@ -1397,7 +1438,21 @@ class NetworkTrainer: if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False) + loss = self.process_batch( + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=False + ) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)