From befbec5335ed1f8018d22b65993b376571ea2989 Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Mon, 11 Mar 2024 18:47:04 +0800 Subject: [PATCH] Update train_network.py --- train_network.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/train_network.py b/train_network.py index f0f27ea7..cbc107b6 100644 --- a/train_network.py +++ b/train_network.py @@ -174,7 +174,7 @@ class NetworkTrainer: # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - for fixed_timesteps in timesteps_list: + for fixed_timesteps in tqdm(timesteps_list, desc='Training Progress'): with torch.set_grad_enabled(is_train), accelerator.autocast(): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] @@ -184,16 +184,16 @@ class NetworkTrainer: noise_pred = self.call_unet( args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype ) - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - total_loss += loss + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + total_loss += loss average_loss = total_loss / len(timesteps_list) return average_loss @@ -985,7 +985,7 @@ class NetworkTrainer: if args.validation_every_n_step is not None: if global_step % (args.validation_every_n_step) == 0: if len(val_dataloader) > 0: - print("Validating バリデーション処理...") + print(f"\nValidating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) @@ -994,10 +994,12 @@ class NetworkTrainer: batch = next(cyclic_val_dataloader) loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) total_loss += loss.detach().item() - current_loss = total_loss / args.validation_batches - val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) + current_loss = total_loss / args.validation_batches + val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss) if args.logging_dir is not None: + logs = {"loss/current_val_loss": current_loss} + accelerator.log(logs, step=global_step) avr_loss: float = val_loss_recorder.moving_average logs = {"loss/average_val_loss": avr_loss} accelerator.log(logs, step=global_step) @@ -1011,7 +1013,7 @@ class NetworkTrainer: if args.validation_every_n_step is None: if len(val_dataloader) > 0: - print("Validating バリデーション処理...") + print(f"\nValidating バリデーション処理...") total_loss = 0.0 with torch.no_grad(): validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader) @@ -1025,7 +1027,7 @@ class NetworkTrainer: if args.logging_dir is not None: avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/val_epoch_average": avr_loss} + logs = {"loss/epoch_val_average": avr_loss} accelerator.log(logs, step=epoch + 1) accelerator.wait_for_everyone()