diff --git a/train_network.py b/train_network.py index 5515765d..5b2d2751 100644 --- a/train_network.py +++ b/train_network.py @@ -1559,7 +1559,7 @@ class NetworkTrainer: current_loss = loss.detach().item() loss_recorder.add(epoch=epoch, step=step, loss=current_loss) - wav_loss_recorder.add(epoch=epoch, step=step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0) + wav_loss_recorder.add(epoch=epoch, step=step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0) avr_loss: float = loss_recorder.moving_average avr_wav_loss: float = wav_loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} @@ -1627,7 +1627,7 @@ class NetworkTrainer: current_loss = loss.detach().item() val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) - val_step_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0) + val_step_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep} @@ -1707,7 +1707,7 @@ class NetworkTrainer: current_loss = loss.detach().item() val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) - val_epoch_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0) + val_epoch_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}