support individual LR for CLIP-L/T5XXL

This commit is contained in:
Kohya S
2024-09-10 20:32:09 +09:00
parent d29af146b8
commit d10ff62a78
3 changed files with 49 additions and 58 deletions

View File

@@ -11,6 +11,9 @@ The command to install PyTorch is as follows:
### Recent Updates ### Recent Updates
Sep 10, 2024:
In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately. Specify like `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same.
Sep 9, 2024: Sep 9, 2024:
Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used.
@@ -142,6 +145,7 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times
- Remove `--network_train_unet_only` from your command. - Remove `--network_train_unet_only` from your command.
- Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time. - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time.
- T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. - T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available.
- The learning rates for CLIP-L and T5XXL can be specified separately. Multiple numbers can be specified in `--text_encoder_lr`. For example, `--text_encoder_lr 1e-4 1e-5`. If you specify only one, the learning rates for CLIP-L and T5XXL will be the same. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL.
- The trained LoRA can be used with ComfyUI. - The trained LoRA can be used with ComfyUI.
- Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet. - Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet.

View File

@@ -786,28 +786,23 @@ class LoRANetwork(torch.nn.Module):
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): # make sure text_encoder_lr as list of two elements
# TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?) if text_encoder_lr is None or len(text_encoder_lr) == 0:
# if ( text_encoder_lr = [default_lr, default_lr]
# self.loraplus_lr_ratio is not None elif len(text_encoder_lr) == 1:
# or self.loraplus_text_encoder_lr_ratio is not None text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]]
# or self.loraplus_unet_lr_ratio is not None
# ):
# assert (
# optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
# ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"
self.requires_grad_(True) self.requires_grad_(True)
all_params = [] all_params = []
lr_descriptions = [] lr_descriptions = []
def assemble_params(loras, lr, ratio): def assemble_params(loras, lr, loraplus_ratio):
param_groups = {"lora": {}, "plus": {}} param_groups = {"lora": {}, "plus": {}}
for lora in loras: for lora in loras:
for name, param in lora.named_parameters(): for name, param in lora.named_parameters():
if ratio is not None and "lora_up" in name: if loraplus_ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else: else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param param_groups["lora"][f"{lora.lora_name}.{name}"] = param
@@ -822,7 +817,7 @@ class LoRANetwork(torch.nn.Module):
if lr is not None: if lr is not None:
if key == "plus": if key == "plus":
param_data["lr"] = lr * ratio param_data["lr"] = lr * loraplus_ratio
else: else:
param_data["lr"] = lr param_data["lr"] = lr
@@ -836,41 +831,23 @@ class LoRANetwork(torch.nn.Module):
return params, descriptions return params, descriptions
if self.text_encoder_loras: if self.text_encoder_loras:
params, descriptions = assemble_params( loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr, # split text encoder loras for te1 and te3
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP)]
) te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)]
all_params.extend(params) if len(te1_loras) > 0:
lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions]) logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio)
all_params.extend(params)
lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions])
if len(te3_loras) > 0:
logger.info(f"Text Encoder 2 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[1]}")
params, descriptions = assemble_params(te3_loras, text_encoder_lr[1], loraplus_lr_ratio)
all_params.extend(params)
lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions])
if self.unet_loras: if self.unet_loras:
# if self.block_lr:
# is_sdxl = False
# for lora in self.unet_loras:
# if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name:
# is_sdxl = True
# break
# # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
# block_idx_to_lora = {}
# for lora in self.unet_loras:
# idx = get_block_index(lora.lora_name, is_sdxl)
# if idx not in block_idx_to_lora:
# block_idx_to_lora[idx] = []
# block_idx_to_lora[idx].append(lora)
# # blockごとにパラメータを設定する
# for idx, block_loras in block_idx_to_lora.items():
# params, descriptions = assemble_params(
# block_loras,
# (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx),
# self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
# )
# all_params.extend(params)
# lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions])
# else:
params, descriptions = assemble_params( params, descriptions = assemble_params(
self.unet_loras, self.unet_loras,
unet_lr if unet_lr is not None else default_lr, unet_lr if unet_lr is not None else default_lr,

View File

@@ -466,9 +466,17 @@ class NetworkTrainer:
# 学習に必要なクラスを準備する # 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.") accelerator.print("prepare optimizer, data loader etc.")
# 後方互換性を確保するよ # make backward compatibility for text_encoder_lr
support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs")
if support_multiple_lrs:
text_encoder_lr = args.text_encoder_lr
else:
text_encoder_lr = None if args.text_encoder_lr is None or len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0]
try: try:
results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) if support_multiple_lrs:
results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate)
else:
results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate)
if type(results) is tuple: if type(results) is tuple:
trainable_params = results[0] trainable_params = results[0]
lr_descriptions = results[1] lr_descriptions = results[1]
@@ -476,11 +484,7 @@ class NetworkTrainer:
trainable_params = results trainable_params = results
lr_descriptions = None lr_descriptions = None
except TypeError as e: except TypeError as e:
# logger.warning(f"{e}") trainable_params = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr)
# accelerator.print(
# "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
# )
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
lr_descriptions = None lr_descriptions = None
# if len(trainable_params) == 0: # if len(trainable_params) == 0:
@@ -713,7 +717,7 @@ class NetworkTrainer:
"ss_training_started_at": training_started_at, # unix timestamp "ss_training_started_at": training_started_at, # unix timestamp
"ss_output_name": args.output_name, "ss_output_name": args.output_name,
"ss_learning_rate": args.learning_rate, "ss_learning_rate": args.learning_rate,
"ss_text_encoder_lr": args.text_encoder_lr, "ss_text_encoder_lr": text_encoder_lr,
"ss_unet_lr": args.unet_lr, "ss_unet_lr": args.unet_lr,
"ss_num_train_images": train_dataset_group.num_train_images, "ss_num_train_images": train_dataset_group.num_train_images,
"ss_num_reg_images": train_dataset_group.num_reg_images, "ss_num_reg_images": train_dataset_group.num_reg_images,
@@ -760,8 +764,8 @@ class NetworkTrainer:
"ss_loss_type": args.loss_type, "ss_loss_type": args.loss_type,
"ss_huber_schedule": args.huber_schedule, "ss_huber_schedule": args.huber_schedule,
"ss_huber_c": args.huber_c, "ss_huber_c": args.huber_c,
"ss_fp8_base": args.fp8_base, "ss_fp8_base": bool(args.fp8_base),
"ss_fp8_base_unet": args.fp8_base_unet, "ss_fp8_base_unet": bool(args.fp8_base_unet),
} }
self.update_metadata(metadata, args) # architecture specific metadata self.update_metadata(metadata, args) # architecture specific metadata
@@ -1303,7 +1307,13 @@ def setup_parser() -> argparse.ArgumentParser:
) )
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") parser.add_argument(
"--text_encoder_lr",
type=float,
default=None,
nargs="*",
help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能",
)
parser.add_argument( parser.add_argument(
"--fp8_base_unet", "--fp8_base_unet",
action="store_true", action="store_true",