mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
update FLUX LoRA training
This commit is contained in:
29
README.md
29
README.md
@@ -2,24 +2,41 @@ This repository contains training, generation and utility scripts for Stable Dif
|
|||||||
|
|
||||||
## FLUX.1 LoRA training (WIP)
|
## FLUX.1 LoRA training (WIP)
|
||||||
|
|
||||||
__Aug 9, 2024__:
|
This feature is experimental. The options and the training script may change in the future. Please let us know if you have any idea to improve the training.
|
||||||
|
|
||||||
|
Aug 10, 2024: LoRA key prefix is changed to `lora_unet` from `lora_flux` to make it compatible with ComfyUI.
|
||||||
|
|
||||||
Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe.
|
Please update PyTorch to 2.4.0. We have tested with PyTorch 2.4.0 with CUDA 12.4. We also updated `accelerate` to 0.33.0 just to be safe.
|
||||||
|
|
||||||
We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options.
|
We have added a new training script for LoRA training. The script is `flux_train_network.py`. See `--help` for options. Sample command is below. It will work with 24GB VRAM GPUs.
|
||||||
|
|
||||||
```
|
```
|
||||||
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name
|
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py --pretrained_model_name_or_path flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --network_module networks.lora_flux --network_dim 4 --optimizer_type adamw8bit --learning_rate 1e-4 --network_train_unet_only --cache_text_encoder_outputs --cache_text_encoder_outputs_to_disk --fp8_base --highvram --max_train_epochs 4 --save_every_n_epochs 1 --dataset_config dataset_1024_bs2.toml --output_dir path/to/output/dir --output_name flux-lora-name --timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0 --loss_type l2
|
||||||
```
|
```
|
||||||
|
|
||||||
|
LoRAs for Text Encoders are not tested yet.
|
||||||
|
|
||||||
|
We have added some new options (Aug 10, 2024): `--timestep_sampling`, `--sigmoid_scale`, `--model_prediction_type` and `--discrete_flow_shift`. The options are as follows:
|
||||||
|
|
||||||
|
- `--timestep_sampling` is the method to sample timesteps (0-1): `sigma` (sigma-based, same as SD3), `uniform` (uniform random), or `sigmoid` (sigmoid of random normal, same as x-flux).
|
||||||
|
- `--sigmoid_scale` is the scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). The default is 1.0. Larger values will make the sampling more uniform.
|
||||||
|
- `--model_prediction_type` is how to interpret and process the model prediction: `raw` (use as is, same as x-flux), `additive` (add to noisy input), `sigma_scaled` (apply sigma scaling, same as SD3).
|
||||||
|
- `--discrete_flow_shift` is the discrete flow shift for the Euler Discrete Scheduler, default is 3.0 (same as SD3).
|
||||||
|
|
||||||
|
`--loss_type` may be useful for FLUX.1 training. The default is `l2`.
|
||||||
|
|
||||||
|
In our experiments, `--timestep_sampling sigma --model_prediction_type raw --discrete_flow_shift 1.0` with `--loss_type l2` seems to work better than the default (SD3) settings. The multiplier of LoRA should be adjusted. Other settings may work better, so please try different settings.
|
||||||
|
|
||||||
|
We are also not sure how many epochs are needed for convergence, or how the learning rate should be adjusted.
|
||||||
|
|
||||||
|
The trained LoRA model can be used with ComfyUI.
|
||||||
|
|
||||||
The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.
|
The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.
|
||||||
|
|
||||||
```
|
```
|
||||||
python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors
|
python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
|
||||||
```
|
```
|
||||||
|
|
||||||
Unfortunately, the training result is not good. Please let us know if you have any idea to improve the training.
|
|
||||||
|
|
||||||
## SD3 training
|
## SD3 training
|
||||||
|
|
||||||
SD3 training is done with `sd3_train.py`.
|
SD3 training is done with `sd3_train.py`.
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
||||||
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
|
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||||
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||||
return noise_scheduler
|
return noise_scheduler
|
||||||
|
|
||||||
@@ -211,21 +211,32 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
noise = torch.randn_like(latents)
|
noise = torch.randn_like(latents)
|
||||||
bsz = latents.shape[0]
|
bsz = latents.shape[0]
|
||||||
|
|
||||||
# Sample a random timestep for each image
|
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||||
# for weighting schemes where we sample timesteps non-uniformly
|
# Simple random t-based noise sampling
|
||||||
u = compute_density_for_timestep_sampling(
|
if args.timestep_sampling == "sigmoid":
|
||||||
weighting_scheme=args.weighting_scheme,
|
# https://github.com/XLabs-AI/x-flux/tree/main
|
||||||
batch_size=bsz,
|
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=accelerator.device))
|
||||||
logit_mean=args.logit_mean,
|
else:
|
||||||
logit_std=args.logit_std,
|
t = torch.rand((bsz,), device=accelerator.device)
|
||||||
mode_scale=args.mode_scale,
|
timesteps = t * 1000.0
|
||||||
)
|
t = t.view(-1, 1, 1, 1)
|
||||||
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
|
noisy_model_input = (1 - t) * latents + t * noise
|
||||||
timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device)
|
else:
|
||||||
|
# Sample a random timestep for each image
|
||||||
|
# for weighting schemes where we sample timesteps non-uniformly
|
||||||
|
u = compute_density_for_timestep_sampling(
|
||||||
|
weighting_scheme=args.weighting_scheme,
|
||||||
|
batch_size=bsz,
|
||||||
|
logit_mean=args.logit_mean,
|
||||||
|
logit_std=args.logit_std,
|
||||||
|
mode_scale=args.mode_scale,
|
||||||
|
)
|
||||||
|
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
|
||||||
|
timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device)
|
||||||
|
|
||||||
# Add noise according to flow matching.
|
# Add noise according to flow matching.
|
||||||
sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype)
|
sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype)
|
||||||
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
||||||
|
|
||||||
# pack latents and get img_ids
|
# pack latents and get img_ids
|
||||||
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
||||||
@@ -264,11 +275,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
# unpack latents
|
# unpack latents
|
||||||
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
||||||
|
|
||||||
model_pred = model_pred * (-sigmas) + noisy_model_input
|
if args.model_prediction_type == "raw":
|
||||||
|
# use model_pred as is
|
||||||
|
weighting = None
|
||||||
|
elif args.model_prediction_type == "additive":
|
||||||
|
# add the model_pred to the noisy_model_input
|
||||||
|
model_pred = model_pred + noisy_model_input
|
||||||
|
weighting = None
|
||||||
|
elif args.model_prediction_type == "sigma_scaled":
|
||||||
|
# apply sigma scaling
|
||||||
|
model_pred = model_pred * (-sigmas) + noisy_model_input
|
||||||
|
|
||||||
# these weighting schemes use a uniform timestep sampling
|
# these weighting schemes use a uniform timestep sampling
|
||||||
# and instead post-weight the loss
|
# and instead post-weight the loss
|
||||||
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
||||||
|
|
||||||
# flow matching loss: this is different from SD3
|
# flow matching loss: this is different from SD3
|
||||||
target = noise - latents
|
target = noise - latents
|
||||||
@@ -278,6 +298,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
def get_sai_model_spec(self, args):
|
||||||
|
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
|
||||||
|
|
||||||
|
def update_metadata(self, metadata, args):
|
||||||
|
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
|
||||||
|
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
||||||
|
metadata["ss_logit_mean"] = args.logit_mean
|
||||||
|
metadata["ss_logit_std"] = args.logit_std
|
||||||
|
metadata["ss_mode_scale"] = args.mode_scale
|
||||||
|
metadata["ss_guidance_scale"] = args.guidance_scale
|
||||||
|
metadata["ss_timestep_sampling"] = args.timestep_sampling
|
||||||
|
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
|
||||||
|
metadata["ss_model_prediction_type"] = args.model_prediction_type
|
||||||
|
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
parser = train_network.setup_parser()
|
parser = train_network.setup_parser()
|
||||||
@@ -318,6 +353,34 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
default=3.5,
|
default=3.5,
|
||||||
help="the FLUX.1 dev variant is a guidance distilled model",
|
help="the FLUX.1 dev variant is a guidance distilled model",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--timestep_sampling",
|
||||||
|
choices=["sigma", "uniform", "sigmoid"],
|
||||||
|
default="sigma",
|
||||||
|
help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sigmoid_scale",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_prediction_type",
|
||||||
|
choices=["raw", "additive", "sigma_scaled"],
|
||||||
|
default="sigma_scaled",
|
||||||
|
help="How to interpret and process the model prediction: "
|
||||||
|
"raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
|
||||||
|
" / モデル予測の解釈と処理方法:"
|
||||||
|
"raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--discrete_flow_shift",
|
||||||
|
type=float,
|
||||||
|
default=3.0,
|
||||||
|
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -59,6 +59,8 @@ ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
|
|||||||
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
|
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
|
||||||
ARCH_SD3_M = "stable-diffusion-3-medium"
|
ARCH_SD3_M = "stable-diffusion-3-medium"
|
||||||
ARCH_SD3_UNKNOWN = "stable-diffusion-3"
|
ARCH_SD3_UNKNOWN = "stable-diffusion-3"
|
||||||
|
ARCH_FLUX_1_DEV = "flux-1-dev"
|
||||||
|
ARCH_FLUX_1_UNKNOWN = "flux-1"
|
||||||
|
|
||||||
ADAPTER_LORA = "lora"
|
ADAPTER_LORA = "lora"
|
||||||
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
|
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
|
||||||
@@ -66,6 +68,7 @@ ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
|
|||||||
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
|
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
|
||||||
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
|
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
|
||||||
IMPL_DIFFUSERS = "diffusers"
|
IMPL_DIFFUSERS = "diffusers"
|
||||||
|
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
|
||||||
|
|
||||||
PRED_TYPE_EPSILON = "epsilon"
|
PRED_TYPE_EPSILON = "epsilon"
|
||||||
PRED_TYPE_V = "v"
|
PRED_TYPE_V = "v"
|
||||||
@@ -118,10 +121,11 @@ def build_metadata(
|
|||||||
merged_from: Optional[str] = None,
|
merged_from: Optional[str] = None,
|
||||||
timesteps: Optional[Tuple[int, int]] = None,
|
timesteps: Optional[Tuple[int, int]] = None,
|
||||||
clip_skip: Optional[int] = None,
|
clip_skip: Optional[int] = None,
|
||||||
sd3: str = None,
|
sd3: Optional[str] = None,
|
||||||
|
flux: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
sd3: only supports "m"
|
sd3: only supports "m", flux: only supports "dev"
|
||||||
"""
|
"""
|
||||||
# if state_dict is None, hash is not calculated
|
# if state_dict is None, hash is not calculated
|
||||||
|
|
||||||
@@ -140,6 +144,11 @@ def build_metadata(
|
|||||||
arch = ARCH_SD3_M
|
arch = ARCH_SD3_M
|
||||||
else:
|
else:
|
||||||
arch = ARCH_SD3_UNKNOWN
|
arch = ARCH_SD3_UNKNOWN
|
||||||
|
elif flux is not None:
|
||||||
|
if flux == "dev":
|
||||||
|
arch = ARCH_FLUX_1_DEV
|
||||||
|
else:
|
||||||
|
arch = ARCH_FLUX_1_UNKNOWN
|
||||||
elif v2:
|
elif v2:
|
||||||
if v_parameterization:
|
if v_parameterization:
|
||||||
arch = ARCH_SD_V2_768_V
|
arch = ARCH_SD_V2_768_V
|
||||||
@@ -158,7 +167,10 @@ def build_metadata(
|
|||||||
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
|
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
|
||||||
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
|
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
|
||||||
|
|
||||||
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
if flux is not None:
|
||||||
|
# Flux
|
||||||
|
impl = IMPL_FLUX
|
||||||
|
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
||||||
# Stable Diffusion ckpt, TI, SDXL LoRA
|
# Stable Diffusion ckpt, TI, SDXL LoRA
|
||||||
impl = IMPL_STABILITY_AI
|
impl = IMPL_STABILITY_AI
|
||||||
else:
|
else:
|
||||||
@@ -216,7 +228,7 @@ def build_metadata(
|
|||||||
reso = (reso[0], reso[0])
|
reso = (reso[0], reso[0])
|
||||||
else:
|
else:
|
||||||
# resolution is defined in dataset, so use default
|
# resolution is defined in dataset, so use default
|
||||||
if sdxl or sd3 is not None:
|
if sdxl or sd3 is not None or flux is not None:
|
||||||
reso = 1024
|
reso = 1024
|
||||||
elif v2 and v_parameterization:
|
elif v2 and v_parameterization:
|
||||||
reso = 768
|
reso = 768
|
||||||
@@ -227,7 +239,9 @@ def build_metadata(
|
|||||||
|
|
||||||
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
|
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
|
||||||
|
|
||||||
if v_parameterization:
|
if flux is not None:
|
||||||
|
del metadata["modelspec.prediction_type"]
|
||||||
|
elif v_parameterization:
|
||||||
metadata["modelspec.prediction_type"] = PRED_TYPE_V
|
metadata["modelspec.prediction_type"] = PRED_TYPE_V
|
||||||
else:
|
else:
|
||||||
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
|
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
|
||||||
|
|||||||
@@ -63,11 +63,11 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
|
|||||||
l_pooled = None
|
l_pooled = None
|
||||||
|
|
||||||
if t5xxl is not None and t5_tokens is not None:
|
if t5xxl is not None and t5_tokens is not None:
|
||||||
# t5_out is [1, max length, 4096]
|
# t5_out is [b, max length, 4096]
|
||||||
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True)
|
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), return_dict=False, output_hidden_states=True)
|
||||||
if apply_t5_attn_mask:
|
if apply_t5_attn_mask:
|
||||||
t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
|
t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
|
||||||
txt_ids = torch.zeros(1, t5_out.shape[1], 3, device=t5_out.device)
|
txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
|
||||||
else:
|
else:
|
||||||
t5_out = None
|
t5_out = None
|
||||||
txt_ids = None
|
txt_ids = None
|
||||||
|
|||||||
@@ -3186,6 +3186,7 @@ def get_sai_model_spec(
|
|||||||
textual_inversion: bool,
|
textual_inversion: bool,
|
||||||
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
|
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
|
||||||
sd3: str = None,
|
sd3: str = None,
|
||||||
|
flux: str = None,
|
||||||
):
|
):
|
||||||
timestamp = time.time()
|
timestamp = time.time()
|
||||||
|
|
||||||
@@ -3220,6 +3221,7 @@ def get_sai_model_spec(
|
|||||||
timesteps=timesteps,
|
timesteps=timesteps,
|
||||||
clip_skip=args.clip_skip, # None or int
|
clip_skip=args.clip_skip, # None or int
|
||||||
sd3=sd3,
|
sd3=sd3,
|
||||||
|
flux=flux,
|
||||||
)
|
)
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
@@ -3642,8 +3644,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
"--loss_type",
|
"--loss_type",
|
||||||
type=str,
|
type=str,
|
||||||
default="l2",
|
default="l2",
|
||||||
choices=["l2", "huber", "smooth_l1"],
|
choices=["l1", "l2", "huber", "smooth_l1"],
|
||||||
help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2",
|
help="The type of loss function to use (L1, L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L1、L2、Huber、またはsmooth L1)、デフォルトはL2",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--huber_schedule",
|
"--huber_schedule",
|
||||||
@@ -5359,9 +5361,10 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
|||||||
def conditional_loss(
|
def conditional_loss(
|
||||||
model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
|
model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
|
||||||
):
|
):
|
||||||
|
|
||||||
if loss_type == "l2":
|
if loss_type == "l2":
|
||||||
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
|
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
|
||||||
|
elif loss_type == "l1":
|
||||||
|
loss = torch.nn.functional.l1_loss(model_pred, target, reduction=reduction)
|
||||||
elif loss_type == "huber":
|
elif loss_type == "huber":
|
||||||
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||||
if reduction == "mean":
|
if reduction == "mean":
|
||||||
|
|||||||
@@ -316,7 +316,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
|
|||||||
class LoRANetwork(torch.nn.Module):
|
class LoRANetwork(torch.nn.Module):
|
||||||
FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
|
FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
|
||||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||||
LORA_PREFIX_FLUX = "lora_flux"
|
LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible
|
||||||
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
|
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
|
||||||
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2"
|
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2"
|
||||||
|
|
||||||
|
|||||||
@@ -226,6 +226,12 @@ class NetworkTrainer:
|
|||||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
def get_sai_model_spec(self, args):
|
||||||
|
return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
|
||||||
|
|
||||||
|
def update_metadata(self, metadata, args):
|
||||||
|
pass
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
def train(self, args):
|
def train(self, args):
|
||||||
@@ -521,10 +527,13 @@ class NetworkTrainer:
|
|||||||
unet_weight_dtype = torch.float8_e4m3fn
|
unet_weight_dtype = torch.float8_e4m3fn
|
||||||
te_weight_dtype = torch.float8_e4m3fn
|
te_weight_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
unet.to(accelerator.device) # this makes faster `to(dtype)` below
|
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
|
||||||
|
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
|
||||||
|
|
||||||
|
unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
|
||||||
|
|
||||||
unet.requires_grad_(False)
|
unet.requires_grad_(False)
|
||||||
unet.to(dtype=unet_weight_dtype) # this takes long time and large memory
|
unet.to(dtype=unet_weight_dtype)
|
||||||
for t_enc in text_encoders:
|
for t_enc in text_encoders:
|
||||||
t_enc.requires_grad_(False)
|
t_enc.requires_grad_(False)
|
||||||
|
|
||||||
@@ -718,8 +727,11 @@ class NetworkTrainer:
|
|||||||
"ss_loss_type": args.loss_type,
|
"ss_loss_type": args.loss_type,
|
||||||
"ss_huber_schedule": args.huber_schedule,
|
"ss_huber_schedule": args.huber_schedule,
|
||||||
"ss_huber_c": args.huber_c,
|
"ss_huber_c": args.huber_c,
|
||||||
|
"ss_fp8_base": args.fp8_base,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.update_metadata(metadata, args) # architecture specific metadata
|
||||||
|
|
||||||
if use_user_config:
|
if use_user_config:
|
||||||
# save metadata of multiple datasets
|
# save metadata of multiple datasets
|
||||||
# NOTE: pack "ss_datasets" value as json one time
|
# NOTE: pack "ss_datasets" value as json one time
|
||||||
@@ -964,7 +976,7 @@ class NetworkTrainer:
|
|||||||
metadata["ss_epoch"] = str(epoch_no)
|
metadata["ss_epoch"] = str(epoch_no)
|
||||||
|
|
||||||
metadata_to_save = minimum_metadata if args.no_metadata else metadata
|
metadata_to_save = minimum_metadata if args.no_metadata else metadata
|
||||||
sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
|
sai_metadata = self.get_sai_model_spec(args)
|
||||||
metadata_to_save.update(sai_metadata)
|
metadata_to_save.update(sai_metadata)
|
||||||
|
|
||||||
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
|
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
|
||||||
|
|||||||
Reference in New Issue
Block a user