diff --git a/README.md b/README.md index 126516f9..b5799dd6 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,9 @@ The command to install PyTorch is as follows: ### Recent Updates +Sep 10, 2024: +In FLUX.1 LoRA training, individual learning rates can be specified for CLIP-L and T5XXL. By specifying multiple numbers in `--text_encoder_lr`, you can set the learning rates for CLIP-L and T5XXL separately, in that order. For example, `--text_encoder_lr 1e-4 1e-5` sets 1e-4 for CLIP-L and 1e-5 for T5XXL. If you specify only one value, it is used for both CLIP-L and T5XXL. + Sep 9, 2024: Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used. @@ -142,6 +145,7 @@ The effect of `--timestep_sampling sigmoid` and `--sigmoid_scale` (when `--times - Remove `--network_train_unet_only` from your command. - Add `train_t5xxl=True` to `--network_args` to train T5XXL LoRA. CLIP-L is also trained at the same time. - T5XXL output can be cached for CLIP-L LoRA training. So, `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is also available. + - The learning rates for CLIP-L and T5XXL can be specified separately. Specify multiple numbers in `--text_encoder_lr`: the first is for CLIP-L and the second is for T5XXL. For example, `--text_encoder_lr 1e-4 1e-5`. If you specify only one value, it is used for both CLIP-L and T5XXL. If `--text_encoder_lr` is not specified, the default learning rate `--learning_rate` is used for both CLIP-L and T5XXL. - The trained LoRA can be used with ComfyUI. - Note: `flux_extract_lora.py`, `convert_flux_lora.py`and `merge_flux_lora.py` do not support CLIP-L and T5XXL LoRA yet. 
diff --git a/networks/lora_flux.py b/networks/lora_flux.py index ab9ccc4d..d540c221 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -786,28 +786,23 @@ class LoRANetwork(torch.nn.Module): logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") - # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): - # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?) - # if ( - # self.loraplus_lr_ratio is not None - # or self.loraplus_text_encoder_lr_ratio is not None - # or self.loraplus_unet_lr_ratio is not None - # ): - # assert ( - # optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower() - # ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません" + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of two elements + if text_encoder_lr is None or len(text_encoder_lr) == 0: + text_encoder_lr = [default_lr, default_lr] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] self.requires_grad_(True) all_params = [] lr_descriptions = [] - def assemble_params(loras, lr, ratio): + def assemble_params(loras, lr, loraplus_ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: for name, param in lora.named_parameters(): - if ratio is not None and "lora_up" in name: + if loraplus_ratio is not None and "lora_up" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param @@ -822,7 +817,7 @@ class LoRANetwork(torch.nn.Module): if lr is not None: if key == "plus": - param_data["lr"] = lr * ratio + param_data["lr"] = lr * loraplus_ratio else: 
param_data["lr"] = lr @@ -836,41 +831,23 @@ class LoRANetwork(torch.nn.Module): return params, descriptions if self.text_encoder_loras: - params, descriptions = assemble_params( - self.text_encoder_loras, - text_encoder_lr if text_encoder_lr is not None else default_lr, - self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, - ) - all_params.extend(params) - lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions]) + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP)] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 2 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions]) if self.unet_loras: - # if self.block_lr: - # is_sdxl = False - # for lora in self.unet_loras: - # if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name: - # is_sdxl = True - # break - - # # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 - # block_idx_to_lora = {} - # for lora in self.unet_loras: - # idx = get_block_index(lora.lora_name, is_sdxl) - # if idx not in block_idx_to_lora: - # block_idx_to_lora[idx] = [] - # block_idx_to_lora[idx].append(lora) - - # # blockごとにパラメータを設定する - # for idx, 
block_loras in block_idx_to_lora.items(): - # params, descriptions = assemble_params( - # block_loras, - # (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx), - # self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, - # ) - # all_params.extend(params) - # lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions]) - - # else: params, descriptions = assemble_params( self.unet_loras, unet_lr if unet_lr is not None else default_lr, diff --git a/train_network.py b/train_network.py index ad97491d..e45db052 100644 --- a/train_network.py +++ b/train_network.py @@ -466,9 +466,17 @@ class NetworkTrainer: # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - # 後方互換性を確保するよ + # make backward compatibility for text_encoder_lr + support_multiple_lrs = hasattr(network, "prepare_optimizer_params_with_multiple_te_lrs") + if support_multiple_lrs: + text_encoder_lr = args.text_encoder_lr + else: + text_encoder_lr = None if args.text_encoder_lr is None or len(args.text_encoder_lr) == 0 else args.text_encoder_lr[0] try: - results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + if support_multiple_lrs: + results = network.prepare_optimizer_params_with_multiple_te_lrs(text_encoder_lr, args.unet_lr, args.learning_rate) + else: + results = network.prepare_optimizer_params(text_encoder_lr, args.unet_lr, args.learning_rate) if type(results) is tuple: trainable_params = results[0] lr_descriptions = results[1] @@ -476,11 +484,7 @@ class NetworkTrainer: trainable_params = results lr_descriptions = None except TypeError as e: - # logger.warning(f"{e}") - # accelerator.print( - # "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" - # ) - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + trainable_params = 
network.prepare_optimizer_params(text_encoder_lr, args.unet_lr) lr_descriptions = None # if len(trainable_params) == 0: @@ -713,7 +717,7 @@ class NetworkTrainer: "ss_training_started_at": training_started_at, # unix timestamp "ss_output_name": args.output_name, "ss_learning_rate": args.learning_rate, - "ss_text_encoder_lr": args.text_encoder_lr, + "ss_text_encoder_lr": text_encoder_lr, "ss_unet_lr": args.unet_lr, "ss_num_train_images": train_dataset_group.num_train_images, "ss_num_reg_images": train_dataset_group.num_reg_images, @@ -760,8 +764,8 @@ class NetworkTrainer: "ss_loss_type": args.loss_type, "ss_huber_schedule": args.huber_schedule, "ss_huber_c": args.huber_c, - "ss_fp8_base": args.fp8_base, - "ss_fp8_base_unet": args.fp8_base_unet, + "ss_fp8_base": bool(args.fp8_base), + "ss_fp8_base_unet": bool(args.fp8_base_unet), } self.update_metadata(metadata, args) # architecture specific metadata @@ -1303,7 +1307,13 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") - parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument( + "--text_encoder_lr", + type=float, + default=None, + nargs="*", + help="learning rate for Text Encoder, can be multiple / Text Encoderの学習率、複数指定可能", + ) parser.add_argument( "--fp8_base_unet", action="store_true",