mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support individual LR for CLIP-L/T5XXL
This commit is contained in:
@@ -11,6 +11,9 @@ The command to install PyTorch is as follows:
|
|||||||
|
|
||||||
### Recent Updates
|
### Recent Updates
|
||||||
|
|
||||||
|
Sep 10, 2024:
|
||||||
|
In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same.
|
||||||
|
|
||||||
Sep 9, 2024:
|
Sep 9, 2024:
|
||||||
Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used.
|
Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used.
|
||||||
|
|
||||||
@@ -142,6 +145,7 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times
|
|||||||
- Remove `--network_train_unet_only` from your command.
|
- Remove `--network_train_unet_only` from your command.
|
||||||
- Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time.
|
- Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time.
|
||||||
- T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
|
- T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
|
||||||
|
- The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL.
|
||||||
- The trained LoRA can be used with ComfyUI.
|
- The trained LoRA can be used with ComfyUI.
|
||||||
- Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet.
|
- Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet.
|
||||||
|
|
||||||
|
|||||||
@@ -786,28 +786,23 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
|
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
|
||||||
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
|
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
|
||||||
|
|
||||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
|
||||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
# make sure text_encoder_lr as list of two elements
|
||||||
# TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?)
|
if text_encoder_lr is None or len(text_encoder_lr) == 0:
|
||||||
# if (
|
text_encoder_lr = [default_lr, default_lr]
|
||||||
# self.loraplus_lr_ratio is not None
|
elif len(text_encoder_lr) == 1:
|
||||||
# or self.loraplus_text_encoder_lr_ratio is not None
|
text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]]
|
||||||
# or self.loraplus_unet_lr_ratio is not None
|
|
||||||
# ):
|
|
||||||
# assert (
|
|
||||||
# optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
|
|
||||||
# ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"
|
|
||||||
|
|
||||||
self.requires_grad_(True)
|
self.requires_grad_(True)
|
||||||
|
|
||||||
all_params = []
|
all_params = []
|
||||||
lr_descriptions = []
|
lr_descriptions = []
|
||||||
|
|
||||||
def assemble_params(loras, lr, ratio):
|
def assemble_params(loras, lr, loraplus_ratio):
|
||||||
param_groups = {"lora": {}, "plus": {}}
|
param_groups = {"lora": {}, "plus": {}}
|
||||||
for lora in loras:
|
for lora in loras:
|
||||||
for name, param in lora.named_parameters():
|
for name, param in lora.named_parameters():
|
||||||
if ratio is not None and "lora_up" in name:
|
if loraplus_ratio is not None and "lora_up" in name:
|
||||||
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
||||||
else:
|
else:
|
||||||
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
||||||
@@ -822,7 +817,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
if lr is not None:
|
if lr is not None:
|
||||||
if key == "plus":
|
if key == "plus":
|
||||||
param_data["lr"] = lr * ratio
|
param_data["lr"] = lr * loraplus_ratio
|
||||||
else:
|
else:
|
||||||
param_data["lr"] = lr
|
param_data["lr"] = lr
|
||||||
|
|
||||||
@@ -836,41 +831,23 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
return params, descriptions
|
return params, descriptions
|
||||||
|
|
||||||
if self.text_encoder_loras:
|
if self.text_encoder_loras:
|
||||||
params, descriptions = assemble_params(
|
loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
|
||||||
self.text_encoder_loras,
|
|
||||||
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
# split text encoder loras for te1 and te3
|
||||||
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
|
te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP)]
|
||||||
)
|
te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)]
|
||||||
all_params.extend(params)
|
if len(te1_loras) > 0:
|
||||||
lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions])
|
logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
|
||||||
|
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio)
|
||||||
|
all_params.extend(params)
|
||||||
|
lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions])
|
||||||
|
if len(te3_loras) > 0:
|
||||||
|
logger.info(f"Text Encoder 2 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[1]}")
|
||||||
|
params, descriptions = assemble_params(te3_loras, text_encoder_lr[1], loraplus_lr_ratio)
|
||||||
|
all_params.extend(params)
|
||||||
|
lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions])
|
||||||
|
|
||||||
if self.unet_loras:
|
if self.unet_loras:
|
||||||
# if self.block_lr:
|
|
||||||
# is_sdxl = False
|
|
||||||
# for lora in self.unet_loras:
|
|
||||||
# if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name:
|
|
||||||
# is_sdxl = True
|
|
||||||
# break
|
|
||||||
|
|
||||||
# # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
|
|
||||||
# block_idx_to_lora = {}
|
|
||||||
# for lora in self.unet_loras:
|
|
||||||
# idx = get_block_index(lora.lora_name, is_sdxl)
|
|
||||||
# if idx not in block_idx_to_lora:
|
|
||||||
# block_idx_to_lora[idx] = []
|
|
||||||
# block_idx_to_lora[idx].append(lora)
|
|
||||||
|
|
||||||
# # blockごとにパラメータを設定する
|
|
||||||
# for idx, block_loras in block_idx_to_lora.items():
|
|
||||||
# params, descriptions = assemble_params(
|
|
||||||
# block_loras,
|
|
||||||
# (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx),
|
|
||||||
# self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
|
|
||||||
# )
|
|
||||||
# all_params.extend(params)
|
|
||||||
# lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions])
|
|
||||||
|
|
||||||
# else:
|
|
||||||
params, descriptions = assemble_params(
|
params, descriptions = assemble_params(
|
||||||
self.unet_loras,
|
self.unet_loras,
|
||||||
unet_lr if unet_lr is not None else default_lr,
|
unet_lr if unet_lr is not None else default_lr,
|
||||||
|
|||||||
@@ -466,9 +466,17 @@ class NetworkTrainer:
|
|||||||
# 学習に必要なクラスを準備する
|
# 学習に必要なクラスを準備する
|
||||||
accelerator.print("prepare optimizer, data loader etc.")
|
accelerator.print("prepare optimizer, data loader etc.")
|
||||||
|
|
||||||
# 後方互換性を確保するよ
|
# make backward compatibility for text_encoder_lr
|
||||||
|
support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs")
|
||||||
|
if support_multiple_lrs:
|
||||||
|
text_encoder_lr = args.text_encoder_lr
|
||||||
|
else:
|
||||||
|
text_encoder_lr = None if args.text_encoder_lr is None or len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0]
|
||||||
try:
|
try:
|
||||||
results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
|
if support_multiple_lrs:
|
||||||
|
results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate)
|
||||||
|
else:
|
||||||
|
results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate)
|
||||||
if type(results) is tuple:
|
if type(results) is tuple:
|
||||||
trainable_params = results[0]
|
trainable_params = results[0]
|
||||||
lr_descriptions = results[1]
|
lr_descriptions = results[1]
|
||||||
@@ -476,11 +484,7 @@ class NetworkTrainer:
|
|||||||
trainable_params = results
|
trainable_params = results
|
||||||
lr_descriptions = None
|
lr_descriptions = None
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
# logger.warning(f"{e}")
|
trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr)
|
||||||
# accelerator.print(
|
|
||||||
# "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
|
|
||||||
# )
|
|
||||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
|
||||||
lr_descriptions = None
|
lr_descriptions = None
|
||||||
|
|
||||||
# if len(trainable_params) == 0:
|
# if len(trainable_params) == 0:
|
||||||
@@ -713,7 +717,7 @@ class NetworkTrainer:
|
|||||||
"ss_training_started_at": training_started_at, # unix timestamp
|
"ss_training_started_at": training_started_at, # unix timestamp
|
||||||
"ss_output_name": args.output_name,
|
"ss_output_name": args.output_name,
|
||||||
"ss_learning_rate": args.learning_rate,
|
"ss_learning_rate": args.learning_rate,
|
||||||
"ss_text_encoder_lr": args.text_encoder_lr,
|
"ss_text_encoder_lr": text_encoder_lr,
|
||||||
"ss_unet_lr": args.unet_lr,
|
"ss_unet_lr": args.unet_lr,
|
||||||
"ss_num_train_images": train_dataset_group.num_train_images,
|
"ss_num_train_images": train_dataset_group.num_train_images,
|
||||||
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
||||||
@@ -760,8 +764,8 @@ class NetworkTrainer:
|
|||||||
"ss_loss_type": args.loss_type,
|
"ss_loss_type": args.loss_type,
|
||||||
"ss_huber_schedule": args.huber_schedule,
|
"ss_huber_schedule": args.huber_schedule,
|
||||||
"ss_huber_c": args.huber_c,
|
"ss_huber_c": args.huber_c,
|
||||||
"ss_fp8_base": args.fp8_base,
|
"ss_fp8_base": bool(args.fp8_base),
|
||||||
"ss_fp8_base_unet": args.fp8_base_unet,
|
"ss_fp8_base_unet": bool(args.fp8_base_unet),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.update_metadata(metadata, args) # architecture specific metadata
|
self.update_metadata(metadata, args) # architecture specific metadata
|
||||||
@@ -1303,7 +1307,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
||||||
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
parser.add_argument(
|
||||||
|
"--text_encoder_lr",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
nargs="*",
|
||||||
|
help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--fp8_base_unet",
|
"--fp8_base_unet",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|||||||
Reference in New Issue
Block a user