From d0ce8674987dc4097b5d792bd2e38381af2a9379 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 6 May 2025 00:29:27 -0400 Subject: [PATCH] Fix loss/wavelet metric --- train_network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_network.py b/train_network.py index 5515765d..5b2d2751 100644 --- a/train_network.py +++ b/train_network.py @@ -1559,7 +1559,7 @@ class NetworkTrainer: current_loss = loss.detach().item() loss_recorder.add(epoch=epoch, step=step, loss=current_loss) - wav_loss_recorder.add(epoch=epoch, step=step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0) + wav_loss_recorder.add(epoch=epoch, step=step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0) avr_loss: float = loss_recorder.moving_average avr_wav_loss: float = wav_loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} @@ -1627,7 +1627,7 @@ class NetworkTrainer: current_loss = loss.detach().item() val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) - val_step_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0) + val_step_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep} @@ -1707,7 +1707,7 @@ class NetworkTrainer: current_loss = loss.detach().item() val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss) - val_epoch_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0) + val_epoch_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0) val_progress_bar.update(1) val_progress_bar.set_postfix( {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}