mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Implement pseudo Huber loss for Flux and SD3
This commit is contained in:
@@ -192,7 +192,7 @@ class NetworkTrainer:
|
||||
):
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
if args.gradient_checkpointing:
|
||||
@@ -244,7 +244,7 @@ class NetworkTrainer:
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
||||
|
||||
return noise_pred, target, timesteps, huber_c, None
|
||||
return noise_pred, target, timesteps, None
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
if args.min_snr_gamma:
|
||||
@@ -806,6 +806,7 @@ class NetworkTrainer:
|
||||
"ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
|
||||
"ss_loss_type": args.loss_type,
|
||||
"ss_huber_schedule": args.huber_schedule,
|
||||
"ss_huber_scale": args.huber_scale,
|
||||
"ss_huber_c": args.huber_c,
|
||||
"ss_fp8_base": bool(args.fp8_base),
|
||||
"ss_fp8_base_unet": bool(args.fp8_base_unet),
|
||||
@@ -1193,7 +1194,7 @@ class NetworkTrainer:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
|
||||
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
@@ -1207,7 +1208,7 @@ class NetworkTrainer:
|
||||
)
|
||||
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler
|
||||
)
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
|
||||
Reference in New Issue
Block a user