mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
FLUX.1 LoRA supports CLIP-L
This commit is contained in:
@@ -9,6 +9,14 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
|||||||
The command to install PyTorch is as follows:
|
The command to install PyTorch is as follows:
|
||||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||||
|
|
||||||
|
Aug 27, 2024:
|
||||||
|
|
||||||
|
- FLUX.1 LoRA training now supports CLIP-L LoRA. To train CLIP-L, remove `--network_train_unet_only`. T5XXL is not trained; its output is still cached, so `--cache_text_encoder_outputs` or `--cache_text_encoder_outputs_to_disk` is still required. The trained LoRA can be used with ComfyUI.
|
||||||
|
- `flux_extract_lora.py` and `convert_flux_lora.py` do not support CLIP-L LoRA.
|
||||||
|
- `--sigmoid_scale` is now effective even when `--timestep_sampling shift` is specified. Normally, leave it at 1.0. Larger values make the sampled timesteps (before the shift is applied) closer to a uniform distribution.
|
||||||
|
|
||||||
|
- __Experimental__ The `--fp8_base_unet` option has been added to `flux_train_network.py`. With this option, Flux is trained in fp8 while CLIP-L is trained in bf16/fp16. When this option is specified, `--fp8_base` is not required (Flux is fp8 and CLIP-L is bf16/fp16, regardless of the `--fp8_base` option).
|
||||||
|
|
||||||
Aug 25, 2024:
|
Aug 25, 2024:
|
||||||
Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`.
|
Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (shifts the value of sigmoid of normal distribution random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`.
|
||||||
Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0`
|
Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0`
|
||||||
|
|||||||
@@ -40,9 +40,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
train_dataset_group.is_text_encoder_output_cacheable()
|
train_dataset_group.is_text_encoder_output_cacheable()
|
||||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||||
|
|
||||||
assert (
|
# assert (
|
||||||
args.network_train_unet_only or not args.cache_text_encoder_outputs
|
# args.network_train_unet_only or not args.cache_text_encoder_outputs
|
||||||
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
|
# ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
|
||||||
|
if not args.network_train_unet_only:
|
||||||
|
logger.info(
|
||||||
|
"network for CLIP-L only will be trained. T5XXL will not be trained / CLIP-Lのネットワークのみが学習されます。T5XXLは学習されません"
|
||||||
|
)
|
||||||
|
|
||||||
if args.max_token_length is not None:
|
if args.max_token_length is not None:
|
||||||
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
|
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
|
||||||
@@ -137,12 +141,25 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
|
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
|
||||||
|
|
||||||
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
||||||
return text_encoders # + [accelerator.unwrap_model(text_encoders[-1])]
|
if args.cache_text_encoder_outputs:
|
||||||
|
if self.is_train_text_encoder(args):
|
||||||
|
return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
|
||||||
|
else:
|
||||||
|
return text_encoders # ignored
|
||||||
|
else:
|
||||||
|
return text_encoders # both CLIP-L and T5XXL are needed for encoding
|
||||||
|
|
||||||
|
def get_text_encoders_train_flags(self, args, text_encoders):
|
||||||
|
return [True, False] if self.is_train_text_encoder(args) else [False, False]
|
||||||
|
|
||||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||||
args.cache_text_encoder_outputs_to_disk, None, False, apply_t5_attn_mask=args.apply_t5_attn_mask
|
args.cache_text_encoder_outputs_to_disk,
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
is_partial=self.is_train_text_encoder(args),
|
||||||
|
apply_t5_attn_mask=args.apply_t5_attn_mask,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
@@ -190,9 +207,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# move back to cpu
|
# move back to cpu
|
||||||
logger.info("move text encoders back to cpu")
|
if not self.is_train_text_encoder(args):
|
||||||
text_encoders[0].to("cpu") # , dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
logger.info("move CLIP-L back to cpu")
|
||||||
text_encoders[1].to("cpu") # , dtype=torch.float32)
|
text_encoders[0].to("cpu")
|
||||||
|
logger.info("move t5XXL back to cpu")
|
||||||
|
text_encoders[1].to("cpu")
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
if not args.lowram:
|
if not args.lowram:
|
||||||
@@ -297,6 +316,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
noisy_model_input.requires_grad_(True)
|
noisy_model_input.requires_grad_(True)
|
||||||
for t in text_encoder_conds:
|
for t in text_encoder_conds:
|
||||||
|
if t.dtype.is_floating_point:
|
||||||
t.requires_grad_(True)
|
t.requires_grad_(True)
|
||||||
img_ids.requires_grad_(True)
|
img_ids.requires_grad_(True)
|
||||||
guidance_vec.requires_grad_(True)
|
guidance_vec.requires_grad_(True)
|
||||||
@@ -384,7 +404,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
||||||
|
|
||||||
def is_text_encoder_not_needed_for_training(self, args):
|
def is_text_encoder_not_needed_for_training(self, args):
|
||||||
return args.cache_text_encoder_outputs
|
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ def sample_images(
|
|||||||
|
|
||||||
logger.info("")
|
logger.info("")
|
||||||
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
||||||
if not os.path.isfile(args.sample_prompts):
|
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
|
||||||
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -66,6 +66,7 @@ def sample_images(
|
|||||||
|
|
||||||
# unwrap unet and text_encoder(s)
|
# unwrap unet and text_encoder(s)
|
||||||
flux = accelerator.unwrap_model(flux)
|
flux = accelerator.unwrap_model(flux)
|
||||||
|
if text_encoders is not None:
|
||||||
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
||||||
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
||||||
|
|
||||||
@@ -134,7 +135,7 @@ def sample_image_inference(
|
|||||||
accelerator: Accelerator,
|
accelerator: Accelerator,
|
||||||
args: argparse.Namespace,
|
args: argparse.Namespace,
|
||||||
flux: flux_models.Flux,
|
flux: flux_models.Flux,
|
||||||
text_encoders: List[CLIPTextModel],
|
text_encoders: Optional[List[CLIPTextModel]],
|
||||||
ae: flux_models.AutoEncoder,
|
ae: flux_models.AutoEncoder,
|
||||||
save_dir,
|
save_dir,
|
||||||
prompt_dict,
|
prompt_dict,
|
||||||
@@ -387,6 +388,7 @@ def get_noisy_model_input_and_timesteps(
|
|||||||
elif args.timestep_sampling == "shift":
|
elif args.timestep_sampling == "shift":
|
||||||
shift = args.discrete_flow_shift
|
shift = args.discrete_flow_shift
|
||||||
logits_norm = torch.randn(bsz, device=device)
|
logits_norm = torch.randn(bsz, device=device)
|
||||||
|
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
||||||
timesteps = logits_norm.sigmoid()
|
timesteps = logits_norm.sigmoid()
|
||||||
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
|
|||||||
if apply_t5_attn_mask is None:
|
if apply_t5_attn_mask is None:
|
||||||
apply_t5_attn_mask = self.apply_t5_attn_mask
|
apply_t5_attn_mask = self.apply_t5_attn_mask
|
||||||
|
|
||||||
clip_l, t5xxl = models
|
clip_l, t5xxl = models if len(models) == 2 else (models[0], None)
|
||||||
l_tokens, t5_tokens = tokens[:2]
|
l_tokens, t5_tokens = tokens[:2]
|
||||||
t5_attn_mask = tokens[2] if len(tokens) > 2 else None
|
t5_attn_mask = tokens[2] if len(tokens) > 2 else None
|
||||||
|
|
||||||
@@ -81,6 +81,7 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
|
|||||||
else:
|
else:
|
||||||
t5_out = None
|
t5_out = None
|
||||||
txt_ids = None
|
txt_ids = None
|
||||||
|
t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one
|
||||||
|
|
||||||
return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer
|
return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer
|
||||||
|
|
||||||
|
|||||||
@@ -421,7 +421,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
# FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
|
# FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
|
||||||
FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"]
|
FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"]
|
||||||
FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"]
|
FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"]
|
||||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
|
||||||
LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible
|
LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible
|
||||||
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
|
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
|
||||||
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2"
|
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2"
|
||||||
|
|||||||
@@ -127,8 +127,15 @@ class NetworkTrainer:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
||||||
|
"""
|
||||||
|
Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models.
|
||||||
|
"""
|
||||||
return text_encoders
|
return text_encoders
|
||||||
|
|
||||||
|
# returns a list of bool values indicating whether each text encoder should be trained
|
||||||
|
def get_text_encoders_train_flags(self, args, text_encoders):
|
||||||
|
return [True] * len(text_encoders) if self.is_train_text_encoder(args) else [False] * len(text_encoders)
|
||||||
|
|
||||||
def is_train_text_encoder(self, args):
|
def is_train_text_encoder(self, args):
|
||||||
return not args.network_train_unet_only
|
return not args.network_train_unet_only
|
||||||
|
|
||||||
@@ -136,11 +143,6 @@ class NetworkTrainer:
|
|||||||
for t_enc in text_encoders:
|
for t_enc in text_encoders:
|
||||||
t_enc.to(accelerator.device, dtype=weight_dtype)
|
t_enc.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
|
||||||
input_ids = batch["input_ids"].to(accelerator.device)
|
|
||||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], weight_dtype)
|
|
||||||
return encoder_hidden_states
|
|
||||||
|
|
||||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||||
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
|
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
|
||||||
return noise_pred
|
return noise_pred
|
||||||
@@ -437,7 +439,9 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
for t_enc in text_encoders:
|
for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
|
||||||
|
if flag:
|
||||||
|
if t_enc.supports_gradient_checkpointing:
|
||||||
t_enc.gradient_checkpointing_enable()
|
t_enc.gradient_checkpointing_enable()
|
||||||
del t_enc
|
del t_enc
|
||||||
network.enable_gradient_checkpointing() # may have no effect
|
network.enable_gradient_checkpointing() # may have no effect
|
||||||
@@ -522,14 +526,17 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
unet_weight_dtype = te_weight_dtype = weight_dtype
|
unet_weight_dtype = te_weight_dtype = weight_dtype
|
||||||
# Experimental Feature: Put base model into fp8 to save vram
|
# Experimental Feature: Put base model into fp8 to save vram
|
||||||
if args.fp8_base:
|
if args.fp8_base or args.fp8_base_unet:
|
||||||
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
|
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
|
||||||
assert (
|
assert (
|
||||||
args.mixed_precision != "no"
|
args.mixed_precision != "no"
|
||||||
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
|
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
|
||||||
accelerator.print("enable fp8 training.")
|
accelerator.print("enable fp8 training for U-Net.")
|
||||||
unet_weight_dtype = torch.float8_e4m3fn
|
unet_weight_dtype = torch.float8_e4m3fn
|
||||||
te_weight_dtype = torch.float8_e4m3fn
|
|
||||||
|
if not args.fp8_base_unet:
|
||||||
|
accelerator.print("enable fp8 training for Text Encoder.")
|
||||||
|
te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
|
||||||
|
|
||||||
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
|
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
|
||||||
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
|
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
|
||||||
@@ -546,19 +553,18 @@ class NetworkTrainer:
|
|||||||
t_enc.to(dtype=te_weight_dtype)
|
t_enc.to(dtype=te_weight_dtype)
|
||||||
if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
|
if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
|
||||||
# nn.Embedding not support FP8
|
# nn.Embedding not support FP8
|
||||||
t_enc.text_model.embeddings.to(
|
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||||
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
|
||||||
elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
|
elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
|
||||||
t_enc.encoder.embeddings.to(
|
t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||||
dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
||||||
if args.deepspeed:
|
if args.deepspeed:
|
||||||
|
flags = self.get_text_encoders_train_flags(args, text_encoders)
|
||||||
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
||||||
args,
|
args,
|
||||||
unet=unet if train_unet else None,
|
unet=unet if train_unet else None,
|
||||||
text_encoder1=text_encoders[0] if train_text_encoder else None,
|
text_encoder1=text_encoders[0] if flags[0] else None,
|
||||||
text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
|
text_encoder2=(text_encoders[1] if flags[1] else None) if len(text_encoders) > 1 else None,
|
||||||
network=network,
|
network=network,
|
||||||
)
|
)
|
||||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
@@ -571,11 +577,14 @@ class NetworkTrainer:
|
|||||||
else:
|
else:
|
||||||
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
|
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
|
||||||
if train_text_encoder:
|
if train_text_encoder:
|
||||||
|
text_encoders = [
|
||||||
|
(accelerator.prepare(t_enc) if flag else t_enc)
|
||||||
|
for t_enc, flag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))
|
||||||
|
]
|
||||||
if len(text_encoders) > 1:
|
if len(text_encoders) > 1:
|
||||||
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
|
text_encoder = text_encoders
|
||||||
else:
|
else:
|
||||||
text_encoder = accelerator.prepare(text_encoder)
|
text_encoder = text_encoders[0]
|
||||||
text_encoders = [text_encoder]
|
|
||||||
else:
|
else:
|
||||||
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
||||||
|
|
||||||
@@ -587,11 +596,11 @@ class NetworkTrainer:
|
|||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
# according to TI example in Diffusers, train is required
|
# according to TI example in Diffusers, train is required
|
||||||
unet.train()
|
unet.train()
|
||||||
for t_enc in text_encoders:
|
for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
|
||||||
t_enc.train()
|
t_enc.train()
|
||||||
|
|
||||||
# set top parameter requires_grad = True for gradient checkpointing works
|
# set top parameter requires_grad = True for gradient checkpointing works
|
||||||
if train_text_encoder:
|
if frag:
|
||||||
t_enc.text_model.embeddings.requires_grad_(True)
|
t_enc.text_model.embeddings.requires_grad_(True)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -736,6 +745,7 @@ class NetworkTrainer:
|
|||||||
"ss_huber_schedule": args.huber_schedule,
|
"ss_huber_schedule": args.huber_schedule,
|
||||||
"ss_huber_c": args.huber_c,
|
"ss_huber_c": args.huber_c,
|
||||||
"ss_fp8_base": args.fp8_base,
|
"ss_fp8_base": args.fp8_base,
|
||||||
|
"ss_fp8_base_unet": args.fp8_base_unet,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.update_metadata(metadata, args) # architecture specific metadata
|
self.update_metadata(metadata, args) # architecture specific metadata
|
||||||
@@ -1004,6 +1014,7 @@ class NetworkTrainer:
|
|||||||
for t_enc in text_encoders:
|
for t_enc in text_encoders:
|
||||||
del t_enc
|
del t_enc
|
||||||
text_encoders = []
|
text_encoders = []
|
||||||
|
text_encoder = None
|
||||||
|
|
||||||
# For --sample_at_first
|
# For --sample_at_first
|
||||||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
||||||
@@ -1018,7 +1029,7 @@ class NetworkTrainer:
|
|||||||
# log device and dtype for each model
|
# log device and dtype for each model
|
||||||
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
|
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
|
||||||
for t_enc in text_encoders:
|
for t_enc in text_encoders:
|
||||||
logger.info(f"text_encoder dtype: {te_weight_dtype}, device: {t_enc.device}")
|
logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}")
|
||||||
|
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
@@ -1073,12 +1084,17 @@ class NetworkTrainer:
|
|||||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||||
if text_encoder_outputs_list is not None:
|
if text_encoder_outputs_list is not None:
|
||||||
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
||||||
else:
|
if (
|
||||||
|
text_encoder_conds is None
|
||||||
|
or len(text_encoder_conds) == 0
|
||||||
|
or text_encoder_conds[0] is None
|
||||||
|
or train_text_encoder
|
||||||
|
):
|
||||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
if args.weighted_captions:
|
if args.weighted_captions:
|
||||||
# SD only
|
# SD only
|
||||||
text_encoder_conds = get_weighted_text_embeddings(
|
encoded_text_encoder_conds = get_weighted_text_embeddings(
|
||||||
tokenizers[0],
|
tokenizers[0],
|
||||||
text_encoder,
|
text_encoder,
|
||||||
batch["captions"],
|
batch["captions"],
|
||||||
@@ -1088,13 +1104,18 @@ class NetworkTrainer:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||||
text_encoder_conds = text_encoding_strategy.encode_tokens(
|
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy,
|
tokenize_strategy,
|
||||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||||
input_ids,
|
input_ids,
|
||||||
)
|
)
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
||||||
|
|
||||||
|
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||||
|
for i in range(len(encoded_text_encoder_conds)):
|
||||||
|
if encoded_text_encoder_conds[i] is not None:
|
||||||
|
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||||
|
|
||||||
# sample noise, call unet, get target
|
# sample noise, call unet, get target
|
||||||
noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
|
noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
|
||||||
@@ -1257,6 +1278,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
|
|
||||||
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
||||||
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp8_base_unet",
|
||||||
|
action="store_true",
|
||||||
|
help="use fp8 for U-Net (or DiT), Text Encoder is fp16 or bf16"
|
||||||
|
" / U-Net(またはDiT)にfp8を使用する。Text Encoderはfp16またはbf16",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
|
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
|
||||||
|
|||||||
Reference in New Issue
Block a user