mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support individual LR for CLIP-L/T5XXL
This commit is contained in:
@@ -11,6 +11,9 @@ The command to install PyTorch is as follows:
|
||||
|
||||
### Recent Updates
|
||||
|
||||
Sep 10, 2024:
|
||||
In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same.
|
||||
|
||||
Sep 9, 2024:
|
||||
Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used.
|
||||
|
||||
@@ -142,6 +145,7 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times
|
||||
- Remove `--network_train_unet_only` from your command.
|
||||
- Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time.
|
||||
- T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
|
||||
- The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL.
|
||||
- The trained LoRA can be used with ComfyUI.
|
||||
- Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet.
|
||||
|
||||
|
||||
@@ -786,28 +786,23 @@ class LoRANetwork(torch.nn.Module):
|
||||
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
|
||||
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
|
||||
|
||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
# TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?)
|
||||
# if (
|
||||
# self.loraplus_lr_ratio is not None
|
||||
# or self.loraplus_text_encoder_lr_ratio is not None
|
||||
# or self.loraplus_unet_lr_ratio is not None
|
||||
# ):
|
||||
# assert (
|
||||
# optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
|
||||
# ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"
|
||||
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
|
||||
# make sure text_encoder_lr as list of two elements
|
||||
if text_encoder_lr is None or len(text_encoder_lr) == 0:
|
||||
text_encoder_lr = [default_lr, default_lr]
|
||||
elif len(text_encoder_lr) == 1:
|
||||
text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]]
|
||||
|
||||
self.requires_grad_(True)
|
||||
|
||||
all_params = []
|
||||
lr_descriptions = []
|
||||
|
||||
def assemble_params(loras, lr, ratio):
|
||||
def assemble_params(loras, lr, loraplus_ratio):
|
||||
param_groups = {"lora": {}, "plus": {}}
|
||||
for lora in loras:
|
||||
for name, param in lora.named_parameters():
|
||||
if ratio is not None and "lora_up" in name:
|
||||
if loraplus_ratio is not None and "lora_up" in name:
|
||||
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
@@ -822,7 +817,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
if lr is not None:
|
||||
if key == "plus":
|
||||
param_data["lr"] = lr * ratio
|
||||
param_data["lr"] = lr * loraplus_ratio
|
||||
else:
|
||||
param_data["lr"] = lr
|
||||
|
||||
@@ -836,41 +831,23 @@ class LoRANetwork(torch.nn.Module):
|
||||
return params, descriptions
|
||||
|
||||
if self.text_encoder_loras:
|
||||
params, descriptions = assemble_params(
|
||||
self.text_encoder_loras,
|
||||
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
||||
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
|
||||
)
|
||||
all_params.extend(params)
|
||||
lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions])
|
||||
loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
|
||||
|
||||
# split text encoder loras for te1 and te3
|
||||
te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP)]
|
||||
te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)]
|
||||
if len(te1_loras) > 0:
|
||||
logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
|
||||
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio)
|
||||
all_params.extend(params)
|
||||
lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions])
|
||||
if len(te3_loras) > 0:
|
||||
logger.info(f"Text Encoder 2 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[1]}")
|
||||
params, descriptions = assemble_params(te3_loras, text_encoder_lr[1], loraplus_lr_ratio)
|
||||
all_params.extend(params)
|
||||
lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions])
|
||||
|
||||
if self.unet_loras:
|
||||
# if self.block_lr:
|
||||
# is_sdxl = False
|
||||
# for lora in self.unet_loras:
|
||||
# if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name:
|
||||
# is_sdxl = True
|
||||
# break
|
||||
|
||||
# # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
|
||||
# block_idx_to_lora = {}
|
||||
# for lora in self.unet_loras:
|
||||
# idx = get_block_index(lora.lora_name, is_sdxl)
|
||||
# if idx not in block_idx_to_lora:
|
||||
# block_idx_to_lora[idx] = []
|
||||
# block_idx_to_lora[idx].append(lora)
|
||||
|
||||
# # blockごとにパラメータを設定する
|
||||
# for idx, block_loras in block_idx_to_lora.items():
|
||||
# params, descriptions = assemble_params(
|
||||
# block_loras,
|
||||
# (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx),
|
||||
# self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
|
||||
# )
|
||||
# all_params.extend(params)
|
||||
# lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions])
|
||||
|
||||
# else:
|
||||
params, descriptions = assemble_params(
|
||||
self.unet_loras,
|
||||
unet_lr if unet_lr is not None else default_lr,
|
||||
|
||||
@@ -466,9 +466,17 @@ class NetworkTrainer:
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
# 後方互換性を確保するよ
|
||||
# make backward compatibility for text_encoder_lr
|
||||
support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs")
|
||||
if support_multiple_lrs:
|
||||
text_encoder_lr = args.text_encoder_lr
|
||||
else:
|
||||
text_encoder_lr = None if args.text_encoder_lr is None or len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0]
|
||||
try:
|
||||
results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
|
||||
if support_multiple_lrs:
|
||||
results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate)
|
||||
else:
|
||||
results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate)
|
||||
if type(results) is tuple:
|
||||
trainable_params = results[0]
|
||||
lr_descriptions = results[1]
|
||||
@@ -476,11 +484,7 @@ class NetworkTrainer:
|
||||
trainable_params = results
|
||||
lr_descriptions = None
|
||||
except TypeError as e:
|
||||
# logger.warning(f"{e}")
|
||||
# accelerator.print(
|
||||
# "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
|
||||
# )
|
||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
||||
trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr)
|
||||
lr_descriptions = None
|
||||
|
||||
# if len(trainable_params) == 0:
|
||||
@@ -713,7 +717,7 @@ class NetworkTrainer:
|
||||
"ss_training_started_at": training_started_at, # unix timestamp
|
||||
"ss_output_name": args.output_name,
|
||||
"ss_learning_rate": args.learning_rate,
|
||||
"ss_text_encoder_lr": args.text_encoder_lr,
|
||||
"ss_text_encoder_lr": text_encoder_lr,
|
||||
"ss_unet_lr": args.unet_lr,
|
||||
"ss_num_train_images": train_dataset_group.num_train_images,
|
||||
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
||||
@@ -760,8 +764,8 @@ class NetworkTrainer:
|
||||
"ss_loss_type": args.loss_type,
|
||||
"ss_huber_schedule": args.huber_schedule,
|
||||
"ss_huber_c": args.huber_c,
|
||||
"ss_fp8_base": args.fp8_base,
|
||||
"ss_fp8_base_unet": args.fp8_base_unet,
|
||||
"ss_fp8_base": bool(args.fp8_base),
|
||||
"ss_fp8_base_unet": bool(args.fp8_base_unet),
|
||||
}
|
||||
|
||||
self.update_metadata(metadata, args) # architecture specific metadata
|
||||
@@ -1303,7 +1307,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
)
|
||||
|
||||
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
||||
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
||||
parser.add_argument(
|
||||
"--text_encoder_lr",
|
||||
type=float,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp8_base_unet",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user