Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-16 00:49:40 +00:00
Add wavelet loss recording
@@ -804,9 +804,6 @@ class WaveletLoss(nn.Module):
         # Combine high frequency bands for visualization
         if combined_hf_pred and combined_hf_target:
             combined_hf_pred = self._pad_tensors(combined_hf_pred)
             combined_hf_target = self._pad_tensors(combined_hf_target)

             combined_hf_pred = torch.cat(combined_hf_pred, dim=1)
             combined_hf_target = torch.cat(combined_hf_target, dim=1)
         else:
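For context, the visualization path above collects per-level high-frequency (detail) sub-bands, pads them to a common spatial size, and concatenates them along the channel axis. A minimal sketch of that idea, assuming a single-level Haar decomposition done with 2x2 average/difference pooling; the helper names here are illustrative, not the repo's:

    import torch
    import torch.nn.functional as F

    def haar_bands(x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
        # One Haar DWT level on x: (B, C, H, W) with even H, W.
        # Returns the low-pass band and the three detail (high-frequency) bands.
        a = x[:, :, 0::2, 0::2]
        b = x[:, :, 0::2, 1::2]
        c = x[:, :, 1::2, 0::2]
        d = x[:, :, 1::2, 1::2]
        low = (a + b + c + d) / 4
        lh = (a + b - c - d) / 4   # horizontal-edge detail
        hl = (a - b + c - d) / 4   # vertical-edge detail
        hh = (a - b - c + d) / 4   # diagonal detail
        return low, [lh, hl, hh]

    def combine_hf(bands: list[torch.Tensor]) -> torch.Tensor:
        # Pad every band to the largest spatial size, then stack on channels,
        # mirroring the _pad_tensors + torch.cat(dim=1) step in the hunk above.
        h = max(t.shape[-2] for t in bands)
        w = max(t.shape[-1] for t in bands)
        padded = [F.pad(t, (0, w - t.shape[-1], 0, h - t.shape[-2])) for t in bands]
        return torch.cat(padded, dim=1)

    x = torch.randn(1, 4, 64, 64)   # e.g. a latent
    low, hf = haar_bands(x)
    vis = combine_hf(hf)            # (1, 12, 32, 32)

With several decomposition levels the detail bands have different spatial sizes, which is why the padding step is needed before concatenation.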
@@ -64,6 +64,7 @@ class NetworkTrainer:
         args: argparse.Namespace,
         current_loss,
         avr_loss,
+        avr_wav_loss,
         lr_scheduler,
         lr_descriptions,
         optimizer=None,
@@ -75,6 +76,9 @@ class NetworkTrainer:
     ):
         logs = {"loss/current": current_loss, "loss/average": avr_loss}

+        if avr_wav_loss is not None:
+            logs['loss/wavelet_average'] = avr_wav_loss
+
         if keys_scaled is not None:
             logs["max_norm/keys_scaled"] = keys_scaled
             logs["max_norm/max_key_norm"] = maximum_norm
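These logs are eventually handed to the experiment trackers. A minimal sketch of the guarded-key pattern used here, assuming an accelerate Accelerator with trackers initialized (the function name is illustrative):

    from typing import Optional
    from accelerate import Accelerator

    def log_step(accelerator: Accelerator, current_loss: float, avr_loss: float,
                 avr_wav_loss: Optional[float], global_step: int) -> None:
        # Base losses are always logged; the wavelet key only appears when the
        # wavelet loss was actually computed, so tracker dashboards stay clean
        # for runs that never enable it.
        logs = {"loss/current": current_loss, "loss/average": avr_loss}
        if avr_wav_loss is not None:
            logs["loss/wavelet_average"] = avr_wav_loss
        accelerator.log(logs, step=global_step)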
@@ -381,7 +385,7 @@ class NetworkTrainer:
         is_train=True,
         train_text_encoder=True,
         train_unet=True,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """
         Process a batch for the network
         """
@@ -464,6 +468,7 @@ class NetworkTrainer:
         huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
         loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)

+        wav_loss = None
         if args.wavelet_loss_alpha:
             if args.wavelet_loss_rectified_flow:
                 # Calculate flow-based clean estimate using the target
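The rectified-flow branch needs a clean-image estimate before the wavelet transform can compare frequency content. A sketch of one common way to get it, assuming the model predicts velocity v = noise - x0 and the noisy latent is x_t = (1 - t) * x0 + t * noise; the exact parameterization in the repo may differ:

    import torch

    def rf_clean_estimate(x_t: torch.Tensor, v_pred: torch.Tensor,
                          timesteps: torch.Tensor,
                          num_train_timesteps: int = 1000) -> torch.Tensor:
        # Rectified flow: x_t = (1 - t) * x0 + t * noise with v = noise - x0,
        # so the clean estimate is x0 = x_t - t * v.
        t = (timesteps.float() / num_train_timesteps).view(-1, 1, 1, 1)
        return x_t - t * v_pred

Applying the same formula to the target velocity yields the reference image, so prediction and target can then be compared band by band.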
@@ -503,7 +508,7 @@ class NetworkTrainer:

         loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)

-        return loss.mean()
+        return loss.mean(), wav_loss

     def train(self, args):
         session_id = random.randint(0, 2**32)
@@ -1296,8 +1301,11 @@ class NetworkTrainer:
         train_util.init_trackers(accelerator, args, "network_train")

         loss_recorder = train_util.LossRecorder()
+        wav_loss_recorder = train_util.LossRecorder()
         val_step_loss_recorder = train_util.LossRecorder()
+        val_step_wav_loss_recorder = train_util.LossRecorder()
         val_epoch_loss_recorder = train_util.LossRecorder()
+        val_epoch_wav_loss_recorder = train_util.LossRecorder()

         if args.wavelet_loss:
             self.wavelet_loss = WaveletLoss(
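The three new recorders mirror the existing ones, one per loss stream. train_util.LossRecorder is the repo's own helper; a minimal stand-in with the same add/moving_average surface used throughout this diff might look like:

    class LossRecorder:
        # Minimal sketch: records one loss per step and exposes a running mean.
        # On a repeated step index (a new epoch), the old value is replaced so
        # the average tracks roughly the last epoch's worth of steps.
        def __init__(self) -> None:
            self.loss_list: list[float] = []
            self.loss_total: float = 0.0

        def add(self, *, epoch: int, step: int, loss: float) -> None:
            if step >= len(self.loss_list):
                self.loss_list.append(loss)
            else:
                self.loss_total -= self.loss_list[step]
                self.loss_list[step] = loss
            self.loss_total += loss

        @property
        def moving_average(self) -> float:
            return self.loss_total / max(len(self.loss_list), 1)

Note that the later hunks record 0.0 when wav_loss is None; that keeps the wavelet recorder in step with the main one, at the cost of diluting its average on steps where the wavelet loss was skipped.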
@@ -1456,7 +1464,7 @@ class NetworkTrainer:
                 # preprocess batch for each model
                 self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)

-                loss = self.process_batch(
+                loss, wav_loss = self.process_batch(
                     batch,
                     text_encoders,
                     unet,
@@ -1540,7 +1548,9 @@ class NetworkTrainer:

                 current_loss = loss.detach().item()
                 loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+                wav_loss_recorder.add(epoch=epoch, step=step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0)
                 avr_loss: float = loss_recorder.moving_average
+                avr_wav_loss: float = wav_loss_recorder.moving_average
                 logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
                 progress_bar.set_postfix(**{**max_mean_logs, **logs})
@@ -1549,6 +1559,7 @@ class NetworkTrainer:
                         args,
                         current_loss,
                         avr_loss,
+                        avr_wav_loss,
                         lr_scheduler,
                         lr_descriptions,
                         optimizer,
@@ -1584,7 +1595,7 @@ class NetworkTrainer:

                             args.min_timestep = args.max_timestep = timestep  # dirty hack to change timestep

-                            loss = self.process_batch(
+                            loss, wav_loss = self.process_batch(
                                 batch,
                                 text_encoders,
                                 unet,
@@ -1604,6 +1615,7 @@ class NetworkTrainer:

                             current_loss = loss.detach().item()
                             val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
+                            val_step_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0)
                             val_progress_bar.update(1)
                             val_progress_bar.set_postfix(
                                 {"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep}
@@ -1620,6 +1632,7 @@ class NetworkTrainer:
                         loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
                         logs = {
                             "loss/validation/step_average": val_step_loss_recorder.moving_average,
+                            "loss/validation/step_wavelet_average": val_step_wav_loss_recorder.moving_average,
                             "loss/validation/step_divergence": loss_validation_divergence,
                         }
                         self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)
@@ -1662,7 +1675,7 @@ class NetworkTrainer:
                         # temporary, for batch processing
                         self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)

-                        loss = self.process_batch(
+                        loss, wav_loss = self.process_batch(
                             batch,
                             text_encoders,
                             unet,
@@ -1682,6 +1695,7 @@ class NetworkTrainer:

                         current_loss = loss.detach().item()
                         val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
+                        val_epoch_wav_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=wav_loss.detach().item() if wav_loss is not None else 0.0)
                         val_progress_bar.update(1)
                         val_progress_bar.set_postfix(
                             {"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}
@@ -1696,10 +1710,12 @@ class NetworkTrainer:

                 if is_tracking:
                     avr_loss: float = val_epoch_loss_recorder.moving_average
+                    avr_wav_loss: float = val_epoch_wav_loss_recorder.moving_average
                     loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average
                     logs = {
                         "loss/validation/epoch_average": avr_loss,
                         "loss/validation/epoch_divergence": loss_validation_divergence,
+                        "loss/validation/epoch_wavelet_average": avr_wav_loss,
                     }
                     self.epoch_logging(accelerator, logs, global_step, epoch + 1)