Implement pseudo Huber loss for Flux and SD3

This commit is contained in:
recris
2024-11-27 18:11:51 +00:00
parent 2a61fc0784
commit 420a180d93
15 changed files with 76 additions and 61 deletions

View File

@@ -192,7 +192,7 @@ class NetworkTrainer:
):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
-noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# ensure the hidden state will require grad
if args.gradient_checkpointing:
@@ -244,7 +244,7 @@ class NetworkTrainer:
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
-return noise_pred, target, timesteps, huber_c, None
+return noise_pred, target, timesteps, None
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
if args.min_snr_gamma:
@@ -806,6 +806,7 @@ class NetworkTrainer:
"ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
"ss_loss_type": args.loss_type,
"ss_huber_schedule": args.huber_schedule,
+"ss_huber_scale": args.huber_scale,
"ss_huber_c": args.huber_c,
"ss_fp8_base": bool(args.fp8_base),
"ss_fp8_base_unet": bool(args.fp8_base_unet),
@@ -1193,7 +1194,7 @@ class NetworkTrainer:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
# sample noise, call unet, get target
-noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
+noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
@@ -1207,7 +1208,7 @@ class NetworkTrainer:
)
loss = train_util.conditional_loss(
-noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
)
if weighting is not None:
loss = loss * weighting