mirror of https://github.com/kohya-ss/sd-scripts.git
T5XXL LoRA training, fp8 T5XXL support
@@ -157,6 +157,9 @@ class NetworkTrainer:
 
     # region SD/SDXL
 
+    def post_process_network(self, args, accelerator, network, text_encoders, unet):
+        pass
+
     def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
         noise_scheduler = DDPMScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
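Note: post_process_network is a new no-op hook, called after the network is created but before it is applied to the U-Net and text encoders, so architecture-specific trainers can adjust things at that point. A minimal sketch of a hypothetical subclass using it (the class name and the logging are illustrative only, not taken from this commit; it assumes train_network.NetworkTrainer is importable):

    # hypothetical subclass, for illustration only
    from train_network import NetworkTrainer

    class T5XXLAwareTrainer(NetworkTrainer):
        def post_process_network(self, args, accelerator, network, text_encoders, unet):
            # the network exists here, but apply_to() has not been called yet
            accelerator.print(f"LoRA network built; {len(text_encoders)} text encoder(s) present")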
@@ -237,6 +240,13 @@ class NetworkTrainer:
     def is_text_encoder_not_needed_for_training(self, args):
         return False  # use for sample images
 
+    def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
+        # set top parameter requires_grad = True so that gradient checkpointing works
+        text_encoder.text_model.embeddings.requires_grad_(True)
+
+    def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
+        text_encoder.text_model.embeddings.to(dtype=weight_dtype)
+
     # endregion
 
     def train(self, args):
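Note: both new hooks assume the CLIP layout (text_model.embeddings). A trainer whose second text encoder is T5XXL would override them, because a Hugging Face T5EncoderModel exposes its token embeddings as shared / encoder.embed_tokens instead. A hedged sketch of such an override (the index convention and attribute access are assumptions for illustration, not code from this commit):

    from train_network import NetworkTrainer

    class T5XXLAwareTrainer(NetworkTrainer):
        # assumed ordering: index 0 = CLIP, index 1 = T5XXL

        def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
            if index == 0:
                super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)  # CLIP
            else:
                # T5: make the embedding output require grad so checkpointed blocks receive grads
                text_encoder.encoder.embed_tokens.requires_grad_(True)

        def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
            if index == 0:
                super().prepare_text_encoder_fp8(index, text_encoder, te_weight_dtype, weight_dtype)  # CLIP
            else:
                # keep the shared embedding table out of fp8; nn.Embedding has no float8 kernel
                text_encoder.shared.to(dtype=weight_dtype)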
@@ -329,7 +339,7 @@ class NetworkTrainer:
             train_dataset_group.is_latent_cacheable()
         ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
 
-        self.assert_extra_args(args, train_dataset_group)
+        self.assert_extra_args(args, train_dataset_group)  # may change some args
 
         # prepare accelerator / acceleratorを準備する
         logger.info("preparing accelerator")
@@ -428,12 +438,15 @@ class NetworkTrainer:
             )
             args.scale_weight_norms = False
 
+        self.post_process_network(args, accelerator, network, text_encoders, unet)
+
         # apply network to unet and text_encoder
         train_unet = not args.network_train_text_encoder_only
         train_text_encoder = self.is_train_text_encoder(args)
         network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
 
         if args.network_weights is not None:
-            # FIXME consider alpha of weights
+            # FIXME consider alpha of weights: this assumes that the alpha is not changed
             info = network.load_weights(args.network_weights)
             accelerator.print(f"load network weights from {args.network_weights}: {info}")
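Note: the extended FIXME refers to the LoRA scaling factor. The effective weight update is scaled by alpha / rank, so loading weights under a different alpha than they were trained with silently changes their strength. A toy illustration of that relationship (not repository code):

    import torch

    rank, alpha = 4, 1.0
    down = torch.randn(rank, 16)  # lora_down: (rank, in_features)
    up = torch.randn(16, rank)    # lora_up:   (out_features, rank)

    delta_w = (alpha / rank) * (up @ down)  # halving alpha halves the update's magnitude
    print(delta_w.shape)  # torch.Size([16, 16])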
@@ -533,7 +546,7 @@ class NetworkTrainer:
             ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
             accelerator.print("enable fp8 training for U-Net.")
             unet_weight_dtype = torch.float8_e4m3fn
 
             if not args.fp8_base_unet:
                 accelerator.print("enable fp8 training for Text Encoder.")
             te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
@@ -545,17 +558,16 @@ class NetworkTrainer:
 
             unet.requires_grad_(False)
             unet.to(dtype=unet_weight_dtype)
-            for t_enc in text_encoders:
+            for i, t_enc in enumerate(text_encoders):
                 t_enc.requires_grad_(False)
 
                 # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
                 if t_enc.device.type != "cpu":
                     t_enc.to(dtype=te_weight_dtype)
-                    if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
-                        # nn.Embedding not support FP8
-                        t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
-                    elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
-                        t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
+
+                    # nn.Embedding does not support FP8
+                    if te_weight_dtype != weight_dtype:
+                        self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype)
 
             # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
             if args.deepspeed:
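Note: the refactor keeps the same fp8 behavior: the whole text encoder is stored in float8_e4m3fn, but the embedding tables stay in the compute dtype because nn.Embedding has no FP8 kernel; the per-encoder details now live in prepare_text_encoder_fp8 so subclasses can handle T5XXL. A minimal standalone sketch of the dtype layout (toy module, no forward pass; assumes a PyTorch build where casting to float8_e4m3fn is available):

    import torch
    import torch.nn as nn

    class ToyTextEncoder(nn.Module):  # stand-in, not the real CLIP/T5 class
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(1000, 64)
            self.proj = nn.Linear(64, 64)

    compute_dtype = torch.bfloat16
    te = ToyTextEncoder()
    te.to(dtype=torch.float8_e4m3fn)  # fp8 storage for the bulk of the weights
    te.embed.to(dtype=compute_dtype)  # ...but the embedding table must not stay in fp8

    print(te.proj.weight.dtype, te.embed.weight.dtype)  # torch.float8_e4m3fn torch.bfloat16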
@@ -596,12 +608,12 @@ class NetworkTrainer:
         if args.gradient_checkpointing:
             # according to TI example in Diffusers, train is required
             unet.train()
-            for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
+            for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))):
                 t_enc.train()
 
                 # set top parameter requires_grad = True so that gradient checkpointing works
                 if frag:
-                    t_enc.text_model.embeddings.requires_grad_(True)
+                    self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc)
 
         else:
             unet.eval()
@@ -1028,8 +1040,12 @@ class NetworkTrainer:
 
         # log device and dtype for each model
         logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
-        for t_enc in text_encoders:
-            logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}")
+        for i, t_enc in enumerate(text_encoders):
+            params_itr = t_enc.parameters()
+            params_itr.__next__()  # skip the first parameter
+            params_itr.__next__()  # skip the second parameter, because CLIP's first two parameters are embeddings
+            param_3rd = params_itr.__next__()
+            logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}")
 
         clean_memory_on_device(accelerator.device)
 
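Note: the logging skips the first two parameters because, for a CLIP text encoder, those are the token and position embedding tables, which are deliberately kept out of fp8; the third parameter therefore reflects the dtype of the rest of the (possibly fp8) weights. A slightly more idiomatic form of the same idea (a sketch using the surrounding names, not repository code):

    import itertools

    for i, t_enc in enumerate(text_encoders):
        # with CLIP, parameters 0 and 1 are embeddings kept in the compute dtype,
        # so the third parameter shows whether the rest of the encoder is really fp8
        param_3rd = next(itertools.islice(t_enc.parameters(), 2, None))
        logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}")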
@@ -1085,11 +1101,7 @@ class NetworkTrainer:
                 text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
                 if text_encoder_outputs_list is not None:
                     text_encoder_conds = text_encoder_outputs_list  # List of text encoder outputs
-                if (
-                    len(text_encoder_conds) == 0
-                    or text_encoder_conds[0] is None
-                    or train_text_encoder
-                ):
+                if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
                     with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
                         # Get the text embedding for conditioning
                         if args.weighted_captions:
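Note: the collapsed condition preserves the logic: cached text encoder outputs are used only when something usable was cached and the text encoder is frozen; otherwise the prompts are re-encoded, with gradients enabled only if the text encoder is being trained. The decision in isolation (placeholder names, illustration only):

    def should_encode_prompts(text_encoder_conds: list, train_text_encoder: bool) -> bool:
        # re-run the text encoder when nothing usable was cached, or when it is being trained
        return len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder

    assert should_encode_prompts([], False)
    assert should_encode_prompts([None], False)
    assert should_encode_prompts([object()], True)
    assert not should_encode_prompts([object()], False)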