Implement pseudo Huber loss for Flux and SD3

This commit is contained in:
recris
2024-11-27 18:11:51 +00:00
parent 2a61fc0784
commit 420a180d93
15 changed files with 76 additions and 61 deletions

View File

@@ -3905,7 +3905,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--huber_c",
type=float,
default=0.1,
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
help="The Huber loss decay parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
)
parser.add_argument(
"--huber_scale",
type=float,
default=1.0,
help="The Huber loss scale parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 1.0 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
)
parser.add_argument(
@@ -5821,29 +5828,10 @@ def save_sd_model_on_train_end_common(
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
if args.loss_type == "huber" or args.loss_type == "smooth_l1":
if args.huber_schedule == "exponential":
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
huber_c = torch.exp(-alpha * timesteps)
elif args.huber_schedule == "snr":
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps)
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
elif args.huber_schedule == "constant":
huber_c = torch.full((b_size,), args.huber_c)
else:
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
huber_c = huber_c.to(device)
elif args.loss_type == "l2":
huber_c = None # may be anything, as it's not used
else:
raise NotImplementedError(f"Unknown loss type {args.loss_type}")
timesteps = timesteps.long().to(device)
return timesteps, huber_c
def get_timesteps(min_timestep, max_timestep, b_size, device):
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
timesteps = timesteps.long()
return timesteps
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
@@ -5865,7 +5853,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)
timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
@@ -5878,24 +5866,46 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
return noise, noisy_latents, timesteps, huber_c
return noise, noisy_latents, timesteps
def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor:
b_size = timesteps.shape[0]
if args.huber_schedule == "exponential":
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
result = torch.exp(-alpha * timesteps) * args.huber_scale
elif args.huber_schedule == "snr":
if not hasattr(noise_scheduler, 'alphas_cumprod'):
raise NotImplementedError(f"Huber schedule 'snr' is not supported with the current model.")
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
result = result.to(timesteps.device)
elif args.huber_schedule == "constant":
result = torch.full((b_size,), args.huber_c * args.huber_scale, device=timesteps.device)
else:
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
return result
def conditional_loss(
model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor]
args, model_pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, reduction: str, noise_scheduler
):
if loss_type == "l2":
if args.loss_type == "l2":
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
elif loss_type == "l1":
elif args.loss_type == "l1":
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
elif loss_type == "huber":
elif args.loss_type == "huber":
huber_c = get_huber_threshold(args, timesteps, noise_scheduler)
huber_c = huber_c.view(-1, 1, 1, 1)
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
elif loss_type == "smooth_l1":
elif args.loss_type == "smooth_l1":
huber_c = get_huber_threshold(args, timesteps, noise_scheduler)
huber_c = huber_c.view(-1, 1, 1, 1)
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
@@ -5903,7 +5913,7 @@ def conditional_loss(
elif reduction == "sum":
loss = torch.sum(loss)
else:
raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
raise NotImplementedError(f"Unsupported Loss Type: {args.loss_type}")
return loss