update FLUX LoRA training

This commit is contained in:
Kohya S
2024-08-10 23:42:05 +09:00
parent 358f13f2c9
commit 8a0f12dde8
7 changed files with 148 additions and 39 deletions

View File

@@ -3186,6 +3186,7 @@ def get_sai_model_spec(
textual_inversion: bool,
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
sd3: str = None,
flux: str = None,
):
timestamp = time.time()
@@ -3220,6 +3221,7 @@ def get_sai_model_spec(
timesteps=timesteps,
clip_skip=args.clip_skip, # None or int
sd3=sd3,
flux=flux,
)
return metadata
@@ -3642,8 +3644,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--loss_type",
type=str,
default="l2",
choices=["l2", "huber", "smooth_l1"],
help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類L2、Huber、またはsmooth L1、デフォルトはL2",
choices=["l1", "l2", "huber", "smooth_l1"],
help="The type of loss function to use (L1, L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L1、L2、Huber、またはsmooth L1、デフォルトはL2",
)
parser.add_argument(
"--huber_schedule",
@@ -5359,9 +5361,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
def conditional_loss(
model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
):
if loss_type == "l2":
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
elif loss_type == "l1":
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
elif loss_type == "huber":
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":