mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
FLUX.1 LoRA supports CLIP-L
This commit is contained in:
@@ -9,6 +9,14 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
||||
The command to install PyTorch is as follows:
|
||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||
|
||||
Aug 27, 2024:
|
||||
|
||||
- FLUX.1 LoRA training now supports CLIP-L LoRA. Please remove `--network_train_unet_only`. T5XXL is not trained. The output of T5XXL is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI.
|
||||
- `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA.
|
||||
- `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the value before shift closer to a uniform distribution.
|
||||
|
||||
- __Experimental__ `--fp8_base_unet` option is added to `flux_train_network.py`. Flux can be trained with fp8, and CLIP-L can be trained with bf16/fp16. When specifying this option, the `--fp8_base` option is not required (Flux is fp8, and CLIP-L is bf16/fp16, regardless of the `--fp8_base` option).
|
||||
|
||||
Aug 25, 2024:
|
||||
Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`.
|
||||
Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0`
|
||||
|
||||
@@ -40,9 +40,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
assert (
|
||||
args.network_train_unet_only or not args.cache_text_encoder_outputs
|
||||
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
|
||||
# assert (
|
||||
# args.network_train_unet_only or not args.cache_text_encoder_outputs
|
||||
# ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
|
||||
if not args.network_train_unet_only:
|
||||
logger.info(
|
||||
"network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません"
|
||||
)
|
||||
|
||||
if args.max_token_length is not None:
|
||||
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
|
||||
@@ -137,12 +141,25 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
|
||||
|
||||
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
    """Return the text encoder models to run for encoding in FLUX training.

    NOTE(review): the original text carried a stale unconditional
    `return text_encoders` before the conditional logic, making everything
    below it unreachable (leftover pre-diff line); it is removed here.

    text_encoders is presumably [CLIP-L, T5XXL] — TODO confirm against caller.
    """
    if args.cache_text_encoder_outputs:
        if self.is_train_text_encoder(args):
            # only CLIP-L is needed for encoding because T5XXL is cached
            return text_encoders[0:1]
        else:
            return text_encoders  # ignored
    else:
        return text_encoders  # both CLIP-L and T5XXL are needed for encoding
|
||||
|
||||
def get_text_encoders_train_flags(self, args, text_encoders):
    """Per-encoder train flags: CLIP-L may be trained, T5XXL never is."""
    if self.is_train_text_encoder(args):
        return [True, False]
    return [False, False]
|
||||
|
||||
def get_text_encoder_outputs_caching_strategy(self, args):
    """Build the FLUX text-encoder output caching strategy, or None when caching is off.

    NOTE(review): the original text fused an old one-line argument list with
    the new multi-line one, duplicating the first three positional arguments;
    resolved here to the coherent keyword form. The cache is marked partial
    when CLIP-L is being trained (its outputs cannot be fully cached).
    """
    if args.cache_text_encoder_outputs:
        return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
            args.cache_text_encoder_outputs_to_disk,
            None,
            False,
            is_partial=self.is_train_text_encoder(args),
            apply_t5_attn_mask=args.apply_t5_attn_mask,
        )
    else:
        return None
|
||||
@@ -190,9 +207,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# move back to cpu
|
||||
logger.info("move text encoders back to cpu")
|
||||
text_encoders[0].to("cpu") # , dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||
text_encoders[1].to("cpu") # , dtype=torch.float32)
|
||||
if not self.is_train_text_encoder(args):
|
||||
logger.info("move CLIP-L back to cpu")
|
||||
text_encoders[0].to("cpu")
|
||||
logger.info("move t5XXL back to cpu")
|
||||
text_encoders[1].to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
if not args.lowram:
|
||||
@@ -297,7 +316,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
if args.gradient_checkpointing:
|
||||
noisy_model_input.requires_grad_(True)
|
||||
for t in text_encoder_conds:
|
||||
t.requires_grad_(True)
|
||||
if t.dtype.is_floating_point:
|
||||
t.requires_grad_(True)
|
||||
img_ids.requires_grad_(True)
|
||||
guidance_vec.requires_grad_(True)
|
||||
|
||||
@@ -384,7 +404,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
||||
|
||||
def is_text_encoder_not_needed_for_training(self, args):
    """True when the text encoders can be discarded during training.

    NOTE(review): the original text carried two fused return statements
    (old and new diff lines); the first shadowed the second. Kept the
    post-commit behavior: encoders are droppable only when their outputs
    are cached AND no text encoder is being trained.
    """
    return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
@@ -58,7 +58,7 @@ def sample_images(
|
||||
|
||||
logger.info("")
|
||||
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
||||
if not os.path.isfile(args.sample_prompts):
|
||||
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
|
||||
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
||||
return
|
||||
|
||||
@@ -66,7 +66,8 @@ def sample_images(
|
||||
|
||||
# unwrap unet and text_encoder(s)
|
||||
flux = accelerator.unwrap_model(flux)
|
||||
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
||||
if text_encoders is not None:
|
||||
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
||||
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
||||
|
||||
prompts = load_prompts(args.sample_prompts)
|
||||
@@ -134,7 +135,7 @@ def sample_image_inference(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
flux: flux_models.Flux,
|
||||
text_encoders: List[CLIPTextModel],
|
||||
text_encoders: Optional[List[CLIPTextModel]],
|
||||
ae: flux_models.AutoEncoder,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
@@ -387,6 +388,7 @@ def get_noisy_model_input_and_timesteps(
|
||||
elif args.timestep_sampling == "shift":
|
||||
shift = args.discrete_flow_shift
|
||||
logits_norm = torch.randn(bsz, device=device)
|
||||
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
timesteps = logits_norm.sigmoid()
|
||||
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
|
||||
if apply_t5_attn_mask is None:
|
||||
apply_t5_attn_mask = self.apply_t5_attn_mask
|
||||
|
||||
clip_l, t5xxl = models
|
||||
clip_l, t5xxl = models if len(models) == 2 else (models[0], None)
|
||||
l_tokens, t5_tokens = tokens[:2]
|
||||
t5_attn_mask = tokens[2] if len(tokens) > 2 else None
|
||||
|
||||
@@ -81,6 +81,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
|
||||
else:
|
||||
t5_out = None
|
||||
txt_ids = None
|
||||
t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one
|
||||
|
||||
return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer
|
||||
|
||||
|
||||
@@ -401,7 +401,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
|
||||
# split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or (
|
||||
# single_qkv_rank is not None and single_qkv_rank != rank
|
||||
# )
|
||||
split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined
|
||||
split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined
|
||||
|
||||
module_class = LoRAInfModule if for_inference else LoRAModule
|
||||
|
||||
@@ -421,7 +421,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
# FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
|
||||
FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"]
|
||||
FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible
|
||||
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
|
||||
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2"
|
||||
|
||||
@@ -127,8 +127,15 @@ class NetworkTrainer:
|
||||
return None
|
||||
|
||||
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
    """Return the models used for text encoding.

    Base implementation passes the encoders through unchanged; subclasses
    may override (SDXL uses wrapped and unwrapped models).
    """
    return text_encoders
|
||||
|
||||
# returns a list of bool values indicating whether each text encoder should be trained
def get_text_encoders_train_flags(self, args, text_encoders):
    """One train flag per text encoder; all flags share the same value."""
    flag = self.is_train_text_encoder(args)
    return [flag] * len(text_encoders)
|
||||
|
||||
def is_train_text_encoder(self, args):
    """Text encoder is trained unless training is restricted to the U-Net."""
    if args.network_train_unet_only:
        return False
    return True
|
||||
|
||||
@@ -136,11 +143,6 @@ class NetworkTrainer:
|
||||
for t_enc in text_encoders:
|
||||
t_enc.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
    """Encode the batch's token ids into hidden states using the first text encoder."""
    ids = batch["input_ids"].to(accelerator.device)
    hidden_states = train_util.get_hidden_states(args, ids, tokenizers[0], text_encoders[0], weight_dtype)
    return hidden_states
|
||||
|
||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
    """Run the U-Net on noisy latents with the first text condition; return its sample."""
    return unet(noisy_latents, timesteps, text_conds[0]).sample
|
||||
@@ -313,7 +315,7 @@ class NetworkTrainer:
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
|
||||
train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
@@ -437,8 +439,10 @@ class NetworkTrainer:
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
for t_enc in text_encoders:
|
||||
t_enc.gradient_checkpointing_enable()
|
||||
for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
|
||||
if flag:
|
||||
if t_enc.supports_gradient_checkpointing:
|
||||
t_enc.gradient_checkpointing_enable()
|
||||
del t_enc
|
||||
network.enable_gradient_checkpointing() # may have no effect
|
||||
|
||||
@@ -522,14 +526,17 @@ class NetworkTrainer:
|
||||
|
||||
unet_weight_dtype = te_weight_dtype = weight_dtype
|
||||
# Experimental Feature: Put base model into fp8 to save vram
|
||||
if args.fp8_base:
|
||||
if args.fp8_base or args.fp8_base_unet:
|
||||
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
|
||||
assert (
|
||||
args.mixed_precision != "no"
|
||||
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
|
||||
accelerator.print("enable fp8 training.")
|
||||
accelerator.print("enable fp8 training for U-Net.")
|
||||
unet_weight_dtype = torch.float8_e4m3fn
|
||||
te_weight_dtype = torch.float8_e4m3fn
|
||||
|
||||
if not args.fp8_base_unet:
|
||||
accelerator.print("enable fp8 training for Text Encoder.")
|
||||
te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
|
||||
|
||||
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
|
||||
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
|
||||
@@ -546,19 +553,18 @@ class NetworkTrainer:
|
||||
t_enc.to(dtype=te_weight_dtype)
|
||||
if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
|
||||
# nn.Embedding not support FP8
|
||||
t_enc.text_model.embeddings.to(
|
||||
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||
elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
|
||||
t_enc.encoder.embeddings.to(
|
||||
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||
t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
||||
if args.deepspeed:
|
||||
flags = self.get_text_encoders_train_flags(args, text_encoders)
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
||||
args,
|
||||
unet=unet if train_unet else None,
|
||||
text_encoder1=text_encoders[0] if train_text_encoder else None,
|
||||
text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
|
||||
text_encoder1=text_encoders[0] if flags[0] else None,
|
||||
text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None,
|
||||
network=network,
|
||||
)
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
@@ -571,11 +577,14 @@ class NetworkTrainer:
|
||||
else:
|
||||
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
|
||||
if train_text_encoder:
|
||||
text_encoders = [
|
||||
(accelerator.prepare(t_enc) if flag else t_enc)
|
||||
for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))
|
||||
]
|
||||
if len(text_encoders) > 1:
|
||||
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
|
||||
text_encoder = text_encoders
|
||||
else:
|
||||
text_encoder = accelerator.prepare(text_encoder)
|
||||
text_encoders = [text_encoder]
|
||||
text_encoder = text_encoders[0]
|
||||
else:
|
||||
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
||||
|
||||
@@ -587,11 +596,11 @@ class NetworkTrainer:
|
||||
if args.gradient_checkpointing:
|
||||
# according to TI example in Diffusers, train is required
|
||||
unet.train()
|
||||
for t_enc in text_encoders:
|
||||
for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
|
||||
t_enc.train()
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
if train_text_encoder:
|
||||
if frag:
|
||||
t_enc.text_model.embeddings.requires_grad_(True)
|
||||
|
||||
else:
|
||||
@@ -736,6 +745,7 @@ class NetworkTrainer:
|
||||
"ss_huber_schedule": args.huber_schedule,
|
||||
"ss_huber_c": args.huber_c,
|
||||
"ss_fp8_base": args.fp8_base,
|
||||
"ss_fp8_base_unet": args.fp8_base_unet,
|
||||
}
|
||||
|
||||
self.update_metadata(metadata, args) # architecture specific metadata
|
||||
@@ -1004,6 +1014,7 @@ class NetworkTrainer:
|
||||
for t_enc in text_encoders:
|
||||
del t_enc
|
||||
text_encoders = []
|
||||
text_encoder = None
|
||||
|
||||
# For --sample_at_first
|
||||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
||||
@@ -1018,7 +1029,7 @@ class NetworkTrainer:
|
||||
# log device and dtype for each model
|
||||
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
|
||||
for t_enc in text_encoders:
|
||||
logger.info(f"text_encoder dtype: {te_weight_dtype}, device: {t_enc.device}")
|
||||
logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}")
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -1073,12 +1084,17 @@ class NetworkTrainer:
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
||||
else:
|
||||
if (
|
||||
text_encoder_conds is None
|
||||
or len(text_encoder_conds) == 0
|
||||
or text_encoder_conds[0] is None
|
||||
or train_text_encoder
|
||||
):
|
||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
# SD only
|
||||
text_encoder_conds = get_weighted_text_embeddings(
|
||||
encoded_text_encoder_conds = get_weighted_text_embeddings(
|
||||
tokenizers[0],
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
@@ -1088,13 +1104,18 @@ class NetworkTrainer:
|
||||
)
|
||||
else:
|
||||
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||
text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids,
|
||||
)
|
||||
if args.full_fp16:
|
||||
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
||||
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
||||
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
|
||||
@@ -1257,6 +1278,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
||||
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
||||
parser.add_argument(
|
||||
"--fp8_base_unet",
|
||||
action="store_true",
|
||||
help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16"
|
||||
" / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
|
||||
|
||||
Reference in New Issue
Block a user