diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index b43dba50..40ba51df 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -804,9 +804,6 @@ class WaveletLoss(nn.Module): # Combine high frequency bands for visualization if combined_hf_pred and combined_hf_target: - combined_hf_pred = self._pad_tensors(combined_hf_pred) - combined_hf_target = self._pad_tensors(combined_hf_target) - combined_hf_pred = torch.cat(combined_hf_pred, dim=1) combined_hf_target = torch.cat(combined_hf_target, dim=1) else: diff --git a/train_network.py b/train_network.py index e1ccdc4a..68123a8e 100644 --- a/train_network.py +++ b/train_network.py @@ -64,6 +64,7 @@ class NetworkTrainer: args: argparse.Namespace, current_loss, avr_loss, + avr_wav_loss, lr_scheduler, lr_descriptions, optimizer=None, @@ -75,6 +76,9 @@ class NetworkTrainer: ): logs = {"loss/current": current_loss, "loss/average": avr_loss} + if avr_wav_loss is not None: + logs['loss/wavelet_average'] = avr_wav_loss + if keys_scaled is not None: logs["max_norm/keys_scaled"] = keys_scaled logs["max_norm/max_key_norm"] = maximum_norm @@ -381,7 +385,7 @@ class NetworkTrainer: is_train=True, train_text_encoder=True, train_unet=True, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ Process a batch for the network """ @@ -464,6 +468,7 @@ class NetworkTrainer: huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + wav_loss = None if args.wavelet_loss_alpha: if args.wavelet_loss_rectified_flow: # Calculate flow-based clean estimate using the target @@ -503,7 +508,7 @@ class NetworkTrainer: loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - return loss.mean() + return loss.mean(), wav_loss def train(self, args): session_id = random.randint(0, 2**32) @@ -1296,8 +1301,11 @@ class NetworkTrainer: train_util.init_trackers(accelerator, args, "network_train") loss_recorder = train_util.LossRecorder() + wav_loss_recorder = train_util.LossRecorder() val_step_loss_recorder = train_util.LossRecorder() + val_step_wav_loss_recorder = train_util.LossRecorder() val_epoch_loss_recorder = train_util.LossRecorder() + val_epoch_wav_loss_recorder = train_util.LossRecorder() if args.wavelet_loss: self.wavelet_loss = WaveletLoss( @@ -1456,7 +1464,7 @@ class NetworkTrainer: # preprocess batch for each model self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True) - loss = self.process_batch( + loss, wav_loss = self.process_batch( batch, text_encoders, unet, @@ -1540,7 +1548,9 @@ class NetworkTrainer: current_loss = loss.detach().item() loss_recorder.add(epoch=epoch, step=step, loss=current_loss) + wav_loss_recorder.add(epoch=epoch, step=step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0) avr_loss: float = loss_recorder.moving_average + avr_wav_loss: float = wav_loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**{**max_mean_logs, **logs}) @@ -1549,6 +1559,7 @@ class NetworkTrainer: args, current_loss, avr_loss, + avr_wav_loss, lr_scheduler, lr_descriptions, optimizer, @@ -1584,7 +1595,7 @@ class NetworkTrainer: args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep - loss = self.process_batch( + loss, wav_loss = self.process_batch( batch, text_encoders, unet, @@ -1604,6 +1615,7 @@ class NetworkTrainer: current_loss = loss.detach().item() val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) + val_step_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep} @@ -1620,6 +1632,7 @@ class NetworkTrainer: loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average logs = { "loss/validation/step_average": val_step_loss_recorder.moving_average, + "loss/validation/step_wavelet_average": val_step_wav_loss_recorder.moving_average, "loss/validation/step_divergence": loss_validation_divergence, } self.step_logging(accelerator, logs, global_step, epoch=epoch + 1) @@ -1662,7 +1675,7 @@ class NetworkTrainer: # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) - loss = self.process_batch( + loss, wav_loss = self.process_batch( batch, text_encoders, unet, @@ -1682,6 +1695,7 @@ class NetworkTrainer: current_loss = loss.detach().item() val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) + val_epoch_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep} @@ -1696,10 +1710,12 @@ class NetworkTrainer: if is_tracking: avr_loss: float = val_epoch_loss_recorder.moving_average + avr_wav_loss: float = val_epoch_wav_loss_recorder.moving_average loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average logs = { "loss/validation/epoch_average": avr_loss, "loss/validation/epoch_divergence": loss_validation_divergence, + "loss/validation/epoch_wavelet_average": avr_wav_loss, } self.epoch_logging(accelerator, logs, global_step, epoch + 1)