mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix: refactor huber-loss calculation in multiple training scripts
This commit is contained in:
@@ -5869,7 +5869,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
||||
return noise, noisy_latents, timesteps
|
||||
|
||||
|
||||
def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor:
|
||||
def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
|
||||
if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"):
|
||||
return None
|
||||
|
||||
b_size = timesteps.shape[0]
|
||||
if args.huber_schedule == "exponential":
|
||||
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
|
||||
@@ -5890,22 +5893,20 @@ def get_huber_threshold(args, timesteps: torch.Tensor, noise_scheduler) -> torch
|
||||
|
||||
|
||||
def conditional_loss(
|
||||
args, model_pred: torch.Tensor, target: torch.Tensor, timesteps: torch.Tensor, reduction: str, noise_scheduler
|
||||
model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None
|
||||
):
|
||||
if args.loss_type == "l2":
|
||||
if loss_type == "l2":
|
||||
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
|
||||
elif args.loss_type == "l1":
|
||||
elif loss_type == "l1":
|
||||
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
|
||||
elif args.loss_type == "huber":
|
||||
huber_c = get_huber_threshold(args, timesteps, noise_scheduler)
|
||||
elif loss_type == "huber":
|
||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
loss = torch.mean(loss)
|
||||
elif reduction == "sum":
|
||||
loss = torch.sum(loss)
|
||||
elif args.loss_type == "smooth_l1":
|
||||
huber_c = get_huber_threshold(args, timesteps, noise_scheduler)
|
||||
elif loss_type == "smooth_l1":
|
||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
@@ -5913,7 +5914,7 @@ def conditional_loss(
|
||||
elif reduction == "sum":
|
||||
loss = torch.sum(loss)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported Loss Type: {args.loss_type}")
|
||||
raise NotImplementedError(f"Unsupported Loss Type: {loss_type}")
|
||||
return loss
|
||||
|
||||
|
||||
@@ -5923,7 +5924,7 @@ def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
|
||||
names.append("unet")
|
||||
names.append("text_encoder1")
|
||||
names.append("text_encoder2")
|
||||
names.append("text_encoder3") # SD3
|
||||
names.append("text_encoder3") # SD3
|
||||
|
||||
append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user