Add wavelet loss recording

This commit is contained in:
rockerBOO
2025-04-11 19:27:16 -04:00
parent 6d42b95e2b
commit f553b7bf31
2 changed files with 21 additions and 8 deletions

View File

@@ -804,9 +804,6 @@ class WaveletLoss(nn.Module):
# Combine high frequency bands for visualization
if combined_hf_pred and combined_hf_target:
combined_hf_pred = self._pad_tensors(combined_hf_pred)
combined_hf_target = self._pad_tensors(combined_hf_target)
combined_hf_pred = torch.cat(combined_hf_pred, dim=1)
combined_hf_target = torch.cat(combined_hf_target, dim=1)
else:

View File

@@ -64,6 +64,7 @@ class NetworkTrainer:
args: argparse.Namespace,
current_loss,
avr_loss,
avr_wav_loss,
lr_scheduler,
lr_descriptions,
optimizer=None,
@@ -75,6 +76,9 @@ class NetworkTrainer:
):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
if avr_wav_loss is not None:
logs['loss/wavelet_average'] = avr_wav_loss
if keys_scaled is not None:
logs["max_norm/keys_scaled"] = keys_scaled
logs["max_norm/max_key_norm"] = maximum_norm
@@ -381,7 +385,7 @@ class NetworkTrainer:
is_train=True,
train_text_encoder=True,
train_unet=True,
) -> torch.Tensor:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Process a batch for the network
"""
@@ -464,6 +468,7 @@ class NetworkTrainer:
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
wav_loss = None
if args.wavelet_loss_alpha:
if args.wavelet_loss_rectified_flow:
# Calculate flow-based clean estimate using the target
@@ -503,7 +508,7 @@ class NetworkTrainer:
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
return loss.mean()
return loss.mean(), wav_loss
def train(self, args):
session_id = random.randint(0, 2**32)
@@ -1296,8 +1301,11 @@ class NetworkTrainer:
train_util.init_trackers(accelerator, args, "network_train")
loss_recorder = train_util.LossRecorder()
wav_loss_recorder = train_util.LossRecorder()
val_step_loss_recorder = train_util.LossRecorder()
val_step_wav_loss_recorder = train_util.LossRecorder()
val_epoch_loss_recorder = train_util.LossRecorder()
val_epoch_wav_loss_recorder = train_util.LossRecorder()
if args.wavelet_loss:
self.wavelet_loss = WaveletLoss(
@@ -1456,7 +1464,7 @@ class NetworkTrainer:
# preprocess batch for each model
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
loss = self.process_batch(
loss, wav_loss = self.process_batch(
batch,
text_encoders,
unet,
@@ -1540,7 +1548,9 @@ class NetworkTrainer:
current_loss = loss.detach().item()
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
wav_loss_recorder.add(epoch=epoch, step=step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0)
avr_loss: float = loss_recorder.moving_average
avr_wav_loss: float = wav_loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**{**max_mean_logs, **logs})
@@ -1549,6 +1559,7 @@ class NetworkTrainer:
args,
current_loss,
avr_loss,
avr_wav_loss,
lr_scheduler,
lr_descriptions,
optimizer,
@@ -1584,7 +1595,7 @@ class NetworkTrainer:
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
loss = self.process_batch(
loss, wav_loss = self.process_batch(
batch,
text_encoders,
unet,
@@ -1604,6 +1615,7 @@ class NetworkTrainer:
current_loss = loss.detach().item()
val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
val_step_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0)
val_progress_bar.update(1)
val_progress_bar.set_postfix(
{"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep}
@@ -1620,6 +1632,7 @@ class NetworkTrainer:
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
logs = {
"loss/validation/step_average": val_step_loss_recorder.moving_average,
"loss/validation/step_wavelet_average": val_step_wav_loss_recorder.moving_average,
"loss/validation/step_divergence": loss_validation_divergence,
}
self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)
@@ -1662,7 +1675,7 @@ class NetworkTrainer:
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
loss = self.process_batch(
loss, wav_loss = self.process_batch(
batch,
text_encoders,
unet,
@@ -1682,6 +1695,7 @@ class NetworkTrainer:
current_loss = loss.detach().item()
val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
val_epoch_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0)
val_progress_bar.update(1)
val_progress_bar.set_postfix(
{"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}
@@ -1696,10 +1710,12 @@ class NetworkTrainer:
if is_tracking:
avr_loss: float = val_epoch_loss_recorder.moving_average
avr_wav_loss: float = val_epoch_wav_loss_recorder.moving_average
loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average
logs = {
"loss/validation/epoch_average": avr_loss,
"loss/validation/epoch_divergence": loss_validation_divergence,
"loss/validation/epoch_wavelet_average": avr_wav_loss,
}
self.epoch_logging(accelerator, logs, global_step, epoch + 1)