From 0087a46e14c8e568982cbe3a5d9b9c561b175abf Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 27 Aug 2024 19:59:40 +0900 Subject: [PATCH] FLUX.1 LoRA supports CLIP-L --- README.md | 8 ++++ flux_train_network.py | 40 +++++++++++++----- library/flux_train_utils.py | 8 ++-- library/strategy_flux.py | 3 +- networks/lora_flux.py | 4 +- train_network.py | 81 ++++++++++++++++++++++++------------- 6 files changed, 101 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 562dcdb2..1203b5eb 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,14 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv The command to install PyTorch is as follows: `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124` +Aug 27, 2024: + +- FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI. + - `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA. +- `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution. + +- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is not required (Flux is fp8, and CLIP-L is bf16/fp16, regardless of the `--fp8_base` option). + Aug 25, 2024: Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`. Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0` diff --git a/flux_train_network.py b/flux_train_network.py index 82f77a77..1a40de61 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -40,9 +40,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): train_dataset_group.is_text_encoder_output_cacheable() ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" - assert ( - args.network_train_unet_only or not args.cache_text_encoder_outputs - ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + # assert ( + # args.network_train_unet_only or not args.cache_text_encoder_outputs + # ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません" + if not args.network_train_unet_only: + logger.info( + "network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません" + ) if args.max_token_length is not None: logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません") @@ -137,12 +141,25 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask) def get_models_for_text_encoding(self, args, accelerator, text_encoders): - return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])] + if args.cache_text_encoder_outputs: + if self.is_train_text_encoder(args): + return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached + else: + return text_encoders # ignored + else: + return text_encoders # both CLIP-L and T5XXL are needed for encoding + + def get_text_encoders_train_flags(self, args, text_encoders): + return [True, False] if self.is_train_text_encoder(args) else [False, False] def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: return strategy_flux.FluxTextEncoderOutputsCachingStrategy( - args.cache_text_encoder_outputs_to_disk, None, False, apply_t5_attn_mask=args.apply_t5_attn_mask + args.cache_text_encoder_outputs_to_disk, + None, + False, + is_partial=self.is_train_text_encoder(args), + apply_t5_attn_mask=args.apply_t5_attn_mask, ) else: return None @@ -190,9 +207,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): accelerator.wait_for_everyone() # move back to cpu - logger.info("move text encoders back to cpu") - text_encoders[0].to("cpu") # , dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU - text_encoders[1].to("cpu") # , dtype=torch.float32) + if not self.is_train_text_encoder(args): + logger.info("move CLIP-L back to cpu") + text_encoders[0].to("cpu") + logger.info("move t5XXL back to cpu") + text_encoders[1].to("cpu") clean_memory_on_device(accelerator.device) if not args.lowram: @@ -297,7 +316,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): if args.gradient_checkpointing: noisy_model_input.requires_grad_(True) for t in text_encoder_conds: - t.requires_grad_(True) + if t.dtype.is_floating_point: + t.requires_grad_(True) img_ids.requires_grad_(True) guidance_vec.requires_grad_(True) @@ -384,7 +404,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift def is_text_encoder_not_needed_for_training(self, args): - return args.cache_text_encoder_outputs + return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args) def setup_parser() -> argparse.ArgumentParser: diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 75f70a54..a8e94ac0 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -58,7 +58,7 @@ def sample_images( logger.info("") logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}") - if not os.path.isfile(args.sample_prompts): + if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None: logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") return @@ -66,7 +66,8 @@ def sample_images( # unwrap unet and text_encoder(s) flux = accelerator.unwrap_model(flux) - text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] + if text_encoders is not None: + text_encoders = [accelerator.unwrap_model(te) for te in text_encoders] # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) prompts = load_prompts(args.sample_prompts) @@ -134,7 +135,7 @@ def sample_image_inference( accelerator: Accelerator, args: argparse.Namespace, flux: flux_models.Flux, - text_encoders: List[CLIPTextModel], + text_encoders: Optional[List[CLIPTextModel]], ae: flux_models.AutoEncoder, save_dir, prompt_dict, @@ -387,6 +388,7 @@ def get_noisy_model_input_and_timesteps( elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) diff --git a/library/strategy_flux.py b/library/strategy_flux.py index d52b3b8d..5d083913 100644 --- a/library/strategy_flux.py +++ b/library/strategy_flux.py @@ -60,7 +60,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy): if apply_t5_attn_mask is None: apply_t5_attn_mask = self.apply_t5_attn_mask - clip_l, t5xxl = models + clip_l, t5xxl = models if len(models) == 2 else (models[0], None) l_tokens, t5_tokens = tokens[:2] t5_attn_mask = tokens[2] if len(tokens) > 2 else None @@ -81,6 +81,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy): else: t5_out = None txt_ids = None + t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer diff --git a/networks/lora_flux.py b/networks/lora_flux.py index 07a80f0b..fcb56a46 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -401,7 +401,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( # single_qkv_rank is not None and single_qkv_rank != rank # ) - split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined module_class = LoRAInfModule if for_inference else LoRAModule @@ -421,7 +421,7 @@ class LoRANetwork(torch.nn.Module): # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"] LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2" diff --git a/train_network.py b/train_network.py index cab0ec52..048c7e7b 100644 --- a/train_network.py +++ b/train_network.py @@ -127,8 +127,15 @@ class NetworkTrainer: return None def get_models_for_text_encoding(self, args, accelerator, text_encoders): + """ + Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models. + """ return text_encoders + # returns a list of bool values indicating whether each text encoder should be trained + def get_text_encoders_train_flags(self, args, text_encoders): + return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders) + def is_train_text_encoder(self, args): return not args.network_train_unet_only @@ -136,11 +143,6 @@ class NetworkTrainer: for t_enc in text_encoders: t_enc.to(accelerator.device, dtype=weight_dtype) - def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype): - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], weight_dtype) - return encoder_hidden_states - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred @@ -313,7 +315,7 @@ class NetworkTrainer: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: - train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: @@ -437,8 +439,10 @@ class NetworkTrainer: if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - for t_enc in text_encoders: - t_enc.gradient_checkpointing_enable() + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): + if flag: + if t_enc.supports_gradient_checkpointing: + t_enc.gradient_checkpointing_enable() del t_enc network.enable_gradient_checkpointing() # may have no effect @@ -522,14 +526,17 @@ class NetworkTrainer: unet_weight_dtype = te_weight_dtype = weight_dtype # Experimental Feature: Put base model into fp8 to save vram - if args.fp8_base: + if args.fp8_base or args.fp8_base_unet: assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。" assert ( args.mixed_precision != "no" ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。" - accelerator.print("enable fp8 training.") + accelerator.print("enable fp8 training for U-Net.") unet_weight_dtype = torch.float8_e4m3fn - te_weight_dtype = torch.float8_e4m3fn + + if not args.fp8_base_unet: + accelerator.print("enable fp8 training for Text Encoder.") + te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn # unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM # unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory @@ -546,19 +553,18 @@ class NetworkTrainer: t_enc.to(dtype=te_weight_dtype) if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"): # nn.Embedding not support FP8 - t_enc.text_model.embeddings.to( - dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"): - t_enc.encoder.embeddings.to( - dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) + t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good if args.deepspeed: + flags = self.get_text_encoders_train_flags(args, text_encoders) ds_model = deepspeed_utils.prepare_deepspeed_model( args, unet=unet if train_unet else None, - text_encoder1=text_encoders[0] if train_text_encoder else None, - text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None, + text_encoder1=text_encoders[0] if flags[0] else None, + text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None, network=network, ) ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -571,11 +577,14 @@ class NetworkTrainer: else: unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator if train_text_encoder: + text_encoders = [ + (accelerator.prepare(t_enc) if flag else t_enc) + for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)) + ] if len(text_encoders) > 1: - text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders] + text_encoder = text_encoders else: - text_encoder = accelerator.prepare(text_encoder) - text_encoders = [text_encoder] + text_encoder = text_encoders[0] else: pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set @@ -587,11 +596,11 @@ class NetworkTrainer: if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() - for t_enc in text_encoders: + for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)): t_enc.train() # set top parameter requires_grad = True for gradient checkpointing works - if train_text_encoder: + if frag: t_enc.text_model.embeddings.requires_grad_(True) else: @@ -736,6 +745,7 @@ class NetworkTrainer: "ss_huber_schedule": args.huber_schedule, "ss_huber_c": args.huber_c, "ss_fp8_base": args.fp8_base, + "ss_fp8_base_unet": args.fp8_base_unet, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1004,6 +1014,7 @@ class NetworkTrainer: for t_enc in text_encoders: del t_enc text_encoders = [] + text_encoder = None # For --sample_at_first self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet) @@ -1018,7 +1029,7 @@ class NetworkTrainer: # log device and dtype for each model logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}") for t_enc in text_encoders: - logger.info(f"text_encoder dtype: {te_weight_dtype}, device: {t_enc.device}") + logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}") clean_memory_on_device(accelerator.device) @@ -1073,12 +1084,17 @@ class NetworkTrainer: text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs - else: + if ( + text_encoder_conds is None + or len(text_encoder_conds) == 0 + or text_encoder_conds[0] is None + or train_text_encoder + ): with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: # SD only - text_encoder_conds = get_weighted_text_embeddings( + encoded_text_encoder_conds = get_weighted_text_embeddings( tokenizers[0], text_encoder, batch["captions"], @@ -1088,13 +1104,18 @@ class NetworkTrainer: ) else: input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] - text_encoder_conds = text_encoding_strategy.encode_tokens( + encoded_text_encoder_conds = text_encoding_strategy.encode_tokens( tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids, ) if args.full_fp16: - text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] + + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target( @@ -1257,6 +1278,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument( + "--fp8_base_unet", + action="store_true", + help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16" + " / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16", + ) parser.add_argument( "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"