Add option to use Scheduled Huber Loss in all training pipelines to improve resilience to data corruption (#1228)

* add huber loss and huber_c compute to train_util

* add reduction modes

* add huber_c retrieval from timestep getter

* move get timesteps and huber to own function

* add conditional loss to all training scripts

* add cond loss to train network

* add (scheduled) huber_loss to args

* fixup twice timesteps getting

* PHL-schedule should depend on noise scheduler's num timesteps

* *2 multiplier to huber loss cause of 1/2 a^2 conv.

The Taylor expansion of sqrt near zero gives 1/2 a^2, which differs from a^2 of the standard MSE loss. This change scales them better against one another

* add option for smooth l1 (huber / delta)

* unify huber scheduling

* add snr huber scheduler

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
This commit is contained in:
kabachuha
2024-04-07 07:54:21 +03:00
committed by GitHub
parent 089727b5ee
commit 90b18795fc
10 changed files with 95 additions and 29 deletions

View File

@@ -3236,6 +3236,26 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する1~1000で指定、省略時はデフォルト値(1000)",
)
parser.add_argument(
"--loss_type",
type=str,
default="l2",
choices=["l2", "huber", "smooth_l1"],
help="The type of loss to use and whether it's scheduled based on the timestep"
)
parser.add_argument(
"--huber_schedule",
type=str,
default="exponential",
choices=["constant", "exponential", "snr"],
help="The type of loss to use and whether it's scheduled based on the timestep"
)
parser.add_argument(
"--huber_c",
type=float,
default=0.1,
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
)
parser.add_argument(
"--lowram",
@@ -4842,6 +4862,38 @@ def save_sd_model_on_train_end_common(
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
#TODO: if a huber loss is selected, it will use constant timesteps for each batch
# as. In the future there may be a smarter way
if args.loss_type == 'huber' or args.loss_type == 'smooth_l1':
timesteps = torch.randint(
min_timestep, max_timestep, (1,), device='cpu'
)
timestep = timesteps.item()
if args.huber_schedule == "exponential":
alpha = - math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
huber_c = math.exp(-alpha * timestep)
elif args.huber_schedule == "snr":
alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
huber_c = (1 - args.huber_c) / (1 + sigmas)**2 + args.huber_c
elif args.huber_schedule == "constant":
huber_c = args.huber_c
else:
raise NotImplementedError(f'Unknown Huber loss schedule {args.huber_schedule}!')
timesteps = timesteps.repeat(b_size).to(device)
elif args.loss_type == 'l2':
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
huber_c = 1 # may be anything, as it's not used
else:
raise NotImplementedError(f'Unknown loss type {args.loss_type}')
timesteps = timesteps.long()
return timesteps, huber_c
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
# Sample noise that we'll add to the latents
@@ -4862,8 +4914,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device)
timesteps = timesteps.long()
timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
@@ -4876,8 +4927,28 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
return noise, noisy_latents, timesteps
return noise, noisy_latents, timesteps, huber_c
# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
def conditional_loss(model_pred:torch.Tensor, target:torch.Tensor, reduction:str="mean", loss_type:str="l2", huber_c:float=0.1):
if loss_type == 'l2':
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
elif loss_type == 'huber':
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
elif loss_type == 'smooth_l1':
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
else:
raise NotImplementedError(f'Unsupported Loss Type {loss_type}')
return loss
def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
names = []