mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
FLUX.1 LoRA supports CLIP-L
This commit is contained in:
@@ -127,8 +127,15 @@ class NetworkTrainer:
|
||||
return None
|
||||
|
||||
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
    """
    Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models.

    This base implementation returns ``text_encoders`` unchanged; ``args`` and
    ``accelerator`` are accepted but not used here.
    """
    return text_encoders
|
||||
|
||||
def get_text_encoders_train_flags(self, args, text_encoders):
    """
    Returns a list of bool values indicating whether each text encoder should be trained.

    The list has one entry per item in ``text_encoders``. In this implementation
    every entry carries the same value: the truthiness of
    ``self.is_train_text_encoder(args)``.
    """
    # bool() keeps the entries as real True/False even if an override of
    # is_train_text_encoder returns a merely truthy/falsy object.
    should_train = bool(self.is_train_text_encoder(args))
    return [should_train] * len(text_encoders)
|
||||
|
||||
def is_train_text_encoder(self, args):
    """
    Returns True when the text encoder should be trained.

    This is the negation of ``args.network_train_unet_only``.
    """
    return not args.network_train_unet_only
|
||||
|
||||
@@ -136,11 +143,6 @@ class NetworkTrainer:
|
||||
for t_enc in text_encoders:
|
||||
t_enc.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
    """
    Returns the text-encoder hidden states for the batch.

    Reads ``batch["input_ids"]``, moves it to ``accelerator.device`` and passes it to
    ``train_util.get_hidden_states`` together with the first tokenizer and the first
    text encoder. Only index 0 of ``tokenizers`` and ``text_encoders`` is used.
    """
    input_ids = batch["input_ids"].to(accelerator.device)
    encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], weight_dtype)
    return encoder_hidden_states
|
||||
|
||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
    """
    Calls ``unet`` and returns the ``sample`` attribute of its output.

    ``unet`` is invoked with ``noisy_latents``, ``timesteps`` and the first entry of
    ``text_conds``. ``args``, ``accelerator``, ``batch`` and ``weight_dtype`` are part
    of the signature but are not used by this implementation.
    """
    noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
    return noise_pred
|
||||
@@ -313,7 +315,7 @@ class NetworkTrainer:
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
|
||||
train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
@@ -437,8 +439,10 @@ class NetworkTrainer:
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
for t_enc in text_encoders:
|
||||
t_enc.gradient_checkpointing_enable()
|
||||
for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
|
||||
if flag:
|
||||
if t_enc.supports_gradient_checkpointing:
|
||||
t_enc.gradient_checkpointing_enable()
|
||||
del t_enc
|
||||
network.enable_gradient_checkpointing() # may have no effect
|
||||
|
||||
@@ -522,14 +526,17 @@ class NetworkTrainer:
|
||||
|
||||
unet_weight_dtype = te_weight_dtype = weight_dtype
|
||||
# Experimental Feature: Put base model into fp8 to save vram
|
||||
if args.fp8_base:
|
||||
if args.fp8_base or args.fp8_base_unet:
|
||||
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
|
||||
assert (
|
||||
args.mixed_precision != "no"
|
||||
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
|
||||
accelerator.print("enable fp8 training.")
|
||||
accelerator.print("enable fp8 training for U-Net.")
|
||||
unet_weight_dtype = torch.float8_e4m3fn
|
||||
te_weight_dtype = torch.float8_e4m3fn
|
||||
|
||||
if not args.fp8_base_unet:
|
||||
accelerator.print("enable fp8 training for Text Encoder.")
|
||||
te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
|
||||
|
||||
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
|
||||
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
|
||||
@@ -546,19 +553,18 @@ class NetworkTrainer:
|
||||
t_enc.to(dtype=te_weight_dtype)
|
||||
if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
|
||||
# nn.Embedding not support FP8
|
||||
t_enc.text_model.embeddings.to(
|
||||
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||
elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
|
||||
t_enc.encoder.embeddings.to(
|
||||
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||
t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
||||
if args.deepspeed:
|
||||
flags = self.get_text_encoders_train_flags(args, text_encoders)
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
||||
args,
|
||||
unet=unet if train_unet else None,
|
||||
text_encoder1=text_encoders[0] if train_text_encoder else None,
|
||||
text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
|
||||
text_encoder1=text_encoders[0] if flags[0] else None,
|
||||
text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None,
|
||||
network=network,
|
||||
)
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
@@ -571,11 +577,14 @@ class NetworkTrainer:
|
||||
else:
|
||||
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
|
||||
if train_text_encoder:
|
||||
text_encoders = [
|
||||
(accelerator.prepare(t_enc) if flag else t_enc)
|
||||
for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))
|
||||
]
|
||||
if len(text_encoders) > 1:
|
||||
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
|
||||
text_encoder = text_encoders
|
||||
else:
|
||||
text_encoder = accelerator.prepare(text_encoder)
|
||||
text_encoders = [text_encoder]
|
||||
text_encoder = text_encoders[0]
|
||||
else:
|
||||
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
||||
|
||||
@@ -587,11 +596,11 @@ class NetworkTrainer:
|
||||
if args.gradient_checkpointing:
|
||||
# according to TI example in Diffusers, train is required
|
||||
unet.train()
|
||||
for t_enc in text_encoders:
|
||||
for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
|
||||
t_enc.train()
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
if train_text_encoder:
|
||||
if frag:
|
||||
t_enc.text_model.embeddings.requires_grad_(True)
|
||||
|
||||
else:
|
||||
@@ -736,6 +745,7 @@ class NetworkTrainer:
|
||||
"ss_huber_schedule": args.huber_schedule,
|
||||
"ss_huber_c": args.huber_c,
|
||||
"ss_fp8_base": args.fp8_base,
|
||||
"ss_fp8_base_unet": args.fp8_base_unet,
|
||||
}
|
||||
|
||||
self.update_metadata(metadata, args) # architecture specific metadata
|
||||
@@ -1004,6 +1014,7 @@ class NetworkTrainer:
|
||||
for t_enc in text_encoders:
|
||||
del t_enc
|
||||
text_encoders = []
|
||||
text_encoder = None
|
||||
|
||||
# For --sample_at_first
|
||||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
||||
@@ -1018,7 +1029,7 @@ class NetworkTrainer:
|
||||
# log device and dtype for each model
|
||||
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
|
||||
for t_enc in text_encoders:
|
||||
logger.info(f"text_encoder dtype: {te_weight_dtype}, device: {t_enc.device}")
|
||||
logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}")
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -1073,12 +1084,17 @@ class NetworkTrainer:
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
||||
else:
|
||||
if (
|
||||
text_encoder_conds is None
|
||||
or len(text_encoder_conds) == 0
|
||||
or text_encoder_conds[0] is None
|
||||
or train_text_encoder
|
||||
):
|
||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
# SD only
|
||||
text_encoder_conds = get_weighted_text_embeddings(
|
||||
encoded_text_encoder_conds = get_weighted_text_embeddings(
|
||||
tokenizers[0],
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
@@ -1088,13 +1104,18 @@ class NetworkTrainer:
|
||||
)
|
||||
else:
|
||||
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||
text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids,
|
||||
)
|
||||
if args.full_fp16:
|
||||
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
||||
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
||||
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
|
||||
@@ -1257,6 +1278,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
||||
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
||||
parser.add_argument(
|
||||
"--fp8_base_unet",
|
||||
action="store_true",
|
||||
help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16"
|
||||
" / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
|
||||
|
||||
Reference in New Issue
Block a user