mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
update FLUX LoRA training
This commit is contained in:
@@ -3186,6 +3186,7 @@ def get_sai_model_spec(
|
||||
textual_inversion: bool,
|
||||
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
|
||||
sd3: str = None,
|
||||
flux: str = None,
|
||||
):
|
||||
timestamp = time.time()
|
||||
|
||||
@@ -3220,6 +3221,7 @@ def get_sai_model_spec(
|
||||
timesteps=timesteps,
|
||||
clip_skip=args.clip_skip, # None or int
|
||||
sd3=sd3,
|
||||
flux=flux,
|
||||
)
|
||||
return metadata
|
||||
|
||||
@@ -3642,8 +3644,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
"--loss_type",
|
||||
type=str,
|
||||
default="l2",
|
||||
choices=["l2", "huber", "smooth_l1"],
|
||||
help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2",
|
||||
choices=["l1", "l2", "huber", "smooth_l1"],
|
||||
help="The type of loss function to use (L1, L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L1、L2、Huber、またはsmooth L1)、デフォルトはL2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--huber_schedule",
|
||||
@@ -5359,9 +5361,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
||||
def conditional_loss(
|
||||
model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
|
||||
):
|
||||
|
||||
if loss_type == "l2":
|
||||
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
|
||||
elif loss_type == "l1":
|
||||
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
|
||||
elif loss_type == "huber":
|
||||
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
|
||||
Reference in New Issue
Block a user