Fix loss/wavelet metric

This commit is contained in:
rockerBOO
2025-05-06 00:29:27 -04:00
parent 869dc000d9
commit d0ce867498

View File

@@ -1559,7 +1559,7 @@ class NetworkTrainer:
current_loss = loss.detach().item()
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
wav_loss_recorder.add(epoch=epoch, step=step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0)
wav_loss_recorder.add(epoch=epoch, step=step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0)
avr_loss: float = loss_recorder.moving_average
avr_wav_loss: float = wav_loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
@@ -1627,7 +1627,7 @@ class NetworkTrainer:
current_loss = loss.detach().item()
val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
val_step_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0)
val_step_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0)
val_progress_bar.update(1)
val_progress_bar.set_postfix(
{"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep}
@@ -1707,7 +1707,7 @@ class NetworkTrainer:
current_loss = loss.detach().item()
val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
val_epoch_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0)
val_epoch_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=metrics['loss/wavelet'] if 'loss/wavelet' in metrics else 0.0)
val_progress_bar.update(1)
val_progress_bar.set_postfix(
{"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}