FLUX.1 LoRA supports CLIP-L

This commit is contained in:
Kohya S
2024-08-27 19:59:40 +09:00
parent 72287d39c7
commit 0087a46e14
6 changed files with 101 additions and 43 deletions

View File

@@ -9,6 +9,14 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
Aug 27, 2024:
- FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI.
- `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA.
- `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution.
- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is not required (Flux is fp8, and CLIP-L is bf16/fp16, regardless of the `--fp8_base` option).
Aug 25, 2024:
Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`.
Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0`

View File

@@ -40,9 +40,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
assert (
args.network_train_unet_only or not args.cache_text_encoder_outputs
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
# assert (
# args.network_train_unet_only or not args.cache_text_encoder_outputs
# ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
if not args.network_train_unet_only:
logger.info(
"network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません"
)
if args.max_token_length is not None:
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
@@ -137,12 +141,25 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])]
if args.cache_text_encoder_outputs:
if self.is_train_text_encoder(args):
return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
else:
return text_encoders # ignored
else:
return text_encoders # both CLIP-L and T5XXL are needed for encoding
def get_text_encoders_train_flags(self, args, text_encoders):
return [True, False] if self.is_train_text_encoder(args) else [False, False]
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, False, apply_t5_attn_mask=args.apply_t5_attn_mask
args.cache_text_encoder_outputs_to_disk,
None,
False,
is_partial=self.is_train_text_encoder(args),
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
else:
return None
@@ -190,9 +207,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
accelerator.wait_for_everyone()
# move back to cpu
logger.info("move text encoders back to cpu")
text_encoders[0].to("cpu") # , dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu") # , dtype=torch.float32)
if not self.is_train_text_encoder(args):
logger.info("move CLIP-L back to cpu")
text_encoders[0].to("cpu")
logger.info("move t5XXL back to cpu")
text_encoders[1].to("cpu")
clean_memory_on_device(accelerator.device)
if not args.lowram:
@@ -297,6 +316,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
for t in text_encoder_conds:
if t.dtype.is_floating_point:
t.requires_grad_(True)
img_ids.requires_grad_(True)
guidance_vec.requires_grad_(True)
@@ -384,7 +404,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -58,7 +58,7 @@ def sample_images(
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
@@ -66,6 +66,7 @@ def sample_images(
# unwrap unet and text_encoder(s)
flux = accelerator.unwrap_model(flux)
if text_encoders is not None:
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
@@ -134,7 +135,7 @@ def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
flux: flux_models.Flux,
text_encoders: List[CLIPTextModel],
text_encoders: Optional[List[CLIPTextModel]],
ae: flux_models.AutoEncoder,
save_dir,
prompt_dict,
@@ -387,6 +388,7 @@ def get_noisy_model_input_and_timesteps(
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)

View File

@@ -60,7 +60,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
if apply_t5_attn_mask is None:
apply_t5_attn_mask = self.apply_t5_attn_mask
clip_l, t5xxl = models
clip_l, t5xxl = models if len(models) == 2 else (models[0], None)
l_tokens, t5_tokens = tokens[:2]
t5_attn_mask = tokens[2] if len(tokens) > 2 else None
@@ -81,6 +81,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
else:
t5_out = None
txt_ids = None
t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one
return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer

View File

@@ -421,7 +421,7 @@ class LoRANetwork(torch.nn.Module):
# FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"]
FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2"

View File

@@ -127,8 +127,15 @@ class NetworkTrainer:
return None
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
"""
Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models.
"""
return text_encoders
# returns a list of bool values indicating whether each text encoder should be trained
def get_text_encoders_train_flags(self, args, text_encoders):
return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders)
def is_train_text_encoder(self, args):
return not args.network_train_unet_only
@@ -136,11 +143,6 @@ class NetworkTrainer:
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], weight_dtype)
return encoder_hidden_states
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
return noise_pred
@@ -437,7 +439,9 @@ class NetworkTrainer:
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
for t_enc in text_encoders:
for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
if flag:
if t_enc.supports_gradient_checkpointing:
t_enc.gradient_checkpointing_enable()
del t_enc
network.enable_gradient_checkpointing() # may have no effect
@@ -522,14 +526,17 @@ class NetworkTrainer:
unet_weight_dtype = te_weight_dtype = weight_dtype
# Experimental Feature: Put base model into fp8 to save vram
if args.fp8_base:
if args.fp8_base or args.fp8_base_unet:
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
assert (
args.mixed_precision != "no"
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
accelerator.print("enable fp8 training.")
accelerator.print("enable fp8 training for U-Net.")
unet_weight_dtype = torch.float8_e4m3fn
te_weight_dtype = torch.float8_e4m3fn
if not args.fp8_base_unet:
accelerator.print("enable fp8 training for Text Encoder.")
te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
@@ -546,19 +553,18 @@ class NetworkTrainer:
t_enc.to(dtype=te_weight_dtype)
if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
# nn.Embedding not support FP8
t_enc.text_model.embeddings.to(
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
t_enc.encoder.embeddings.to(
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if args.deepspeed:
flags = self.get_text_encoders_train_flags(args, text_encoders)
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
unet=unet if train_unet else None,
text_encoder1=text_encoders[0] if train_text_encoder else None,
text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
text_encoder1=text_encoders[0] if flags[0] else None,
text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None,
network=network,
)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -571,11 +577,14 @@ class NetworkTrainer:
else:
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
text_encoders = [
(accelerator.prepare(t_enc) if flag else t_enc)
for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))
]
if len(text_encoders) > 1:
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
text_encoder = text_encoders
else:
text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder]
text_encoder = text_encoders[0]
else:
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
@@ -587,11 +596,11 @@ class NetworkTrainer:
if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
unet.train()
for t_enc in text_encoders:
for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
t_enc.train()
# set top parameter requires_grad = True for gradient checkpointing works
if train_text_encoder:
if frag:
t_enc.text_model.embeddings.requires_grad_(True)
else:
@@ -736,6 +745,7 @@ class NetworkTrainer:
"ss_huber_schedule": args.huber_schedule,
"ss_huber_c": args.huber_c,
"ss_fp8_base": args.fp8_base,
"ss_fp8_base_unet": args.fp8_base_unet,
}
self.update_metadata(metadata, args) # architecture specific metadata
@@ -1004,6 +1014,7 @@ class NetworkTrainer:
for t_enc in text_encoders:
del t_enc
text_encoders = []
text_encoder = None
# For --sample_at_first
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
@@ -1018,7 +1029,7 @@ class NetworkTrainer:
# log device and dtype for each model
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
for t_enc in text_encoders:
logger.info(f"text_encoder dtype: {te_weight_dtype}, device: {t_enc.device}")
logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}")
clean_memory_on_device(accelerator.device)
@@ -1073,12 +1084,17 @@ class NetworkTrainer:
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
else:
if (
text_encoder_conds is None
or len(text_encoder_conds) == 0
or text_encoder_conds[0] is None
or train_text_encoder
):
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
# SD only
text_encoder_conds = get_weighted_text_embeddings(
encoded_text_encoder_conds = get_weighted_text_embeddings(
tokenizers[0],
text_encoder,
batch["captions"],
@@ -1088,13 +1104,18 @@ class NetworkTrainer:
)
else:
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
text_encoder_conds = text_encoding_strategy.encode_tokens(
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids,
)
if args.full_fp16:
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
# sample noise, call unet, get target
noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
@@ -1257,6 +1278,12 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
parser.add_argument(
"--fp8_base_unet",
action="store_true",
help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16"
" / U-NetまたはDiTにfp8を使用する。Text Encoderはfp16またはbf16",
)
parser.add_argument(
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"