update FLUX LoRA training

This commit is contained in:
Kohya S
2024-08-10 23:42:05 +09:00
parent 358f13f2c9
commit 8a0f12dde8
7 changed files with 148 additions and 39 deletions

View File

@@ -59,6 +59,8 @@ ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ARCH_SD3_M = "stable-diffusion-3-medium"
ARCH_SD3_UNKNOWN = "stable-diffusion-3"
ARCH_FLUX_1_DEV = "flux-1-dev"
ARCH_FLUX_1_UNKNOWN = "flux-1"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
@@ -66,6 +68,7 @@ ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
IMPL_DIFFUSERS = "diffusers"
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -118,10 +121,11 @@ def build_metadata(
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
sd3: str = None,
sd3: Optional[str] = None,
flux: Optional[str] = None,
):
"""
sd3: only supports "m"
sd3: only supports "m", flux: only supports "dev"
"""
# if state_dict is None, hash is not calculated
@@ -140,6 +144,11 @@ def build_metadata(
arch = ARCH_SD3_M
else:
arch = ARCH_SD3_UNKNOWN
elif flux is not None:
if flux == "dev":
arch = ARCH_FLUX_1_DEV
else:
arch = ARCH_FLUX_1_UNKNOWN
elif v2:
if v_parameterization:
arch = ARCH_SD_V2_768_V
@@ -158,7 +167,10 @@ def build_metadata(
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
if flux is not None:
# Flux
impl = IMPL_FLUX
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
else:
@@ -216,7 +228,7 @@ def build_metadata(
reso = (reso[0], reso[0])
else:
# resolution is defined in dataset, so use default
if sdxl or sd3 is not None:
if sdxl or sd3 is not None or flux is not None:
reso = 1024
elif v2 and v_parameterization:
reso = 768
@@ -227,7 +239,9 @@ def build_metadata(
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
if v_parameterization:
if flux is not None:
del metadata["modelspec.prediction_type"]
elif v_parameterization:
metadata["modelspec.prediction_type"] = PRED_TYPE_V
else:
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON

View File

@@ -63,11 +63,11 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
l_pooled = None
if t5xxl is not None and t5_tokens is not None:
# t5_out is [1, max length, 4096]
# t5_out is [b, max length, 4096]
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True)
if apply_t5_attn_mask:
t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
txt_ids = torch.zeros(1, t5_out.shape[1], 3, device=t5_out.device)
txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
else:
t5_out = None
txt_ids = None

View File

@@ -3186,6 +3186,7 @@ def get_sai_model_spec(
textual_inversion: bool,
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
sd3: str = None,
flux: str = None,
):
timestamp = time.time()
@@ -3220,6 +3221,7 @@ def get_sai_model_spec(
timesteps=timesteps,
clip_skip=args.clip_skip, # None or int
sd3=sd3,
flux=flux,
)
return metadata
@@ -3642,8 +3644,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--loss_type",
type=str,
default="l2",
choices=["l2", "huber", "smooth_l1"],
help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類L2、Huber、またはsmooth L1、デフォルトはL2",
choices=["l1", "l2", "huber", "smooth_l1"],
help="The type of loss function to use (L1, L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L1、L2、Huber、またはsmooth L1、デフォルトはL2",
)
parser.add_argument(
"--huber_schedule",
@@ -5359,9 +5361,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
def conditional_loss(
model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
):
if loss_type == "l2":
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
elif loss_type == "l1":
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
elif loss_type == "huber":
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":