T5XXL LoRA training, fp8 T5XXL support

This commit is contained in:
Kohya S
2024-09-04 21:33:17 +09:00
parent 6abacf04da
commit b65ae9b439
7 changed files with 222 additions and 67 deletions

View File

@@ -157,6 +157,9 @@ class NetworkTrainer:
# region SD/SDXL
def post_process_network(self, args, accelerator, network, text_encoders, unet):
pass
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
@@ -237,6 +240,13 @@ class NetworkTrainer:
def is_text_encoder_not_needed_for_training(self, args):
return False # use for sample images
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
# set top parameter requires_grad = True for gradient checkpointing works
text_encoder.text_model.embeddings.requires_grad_(True)
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
# endregion
def train(self, args):
@@ -329,7 +339,7 @@ class NetworkTrainer:
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
self.assert_extra_args(args, train_dataset_group)
self.assert_extra_args(args, train_dataset_group) # may change some args
# acceleratorを準備する
logger.info("preparing accelerator")
@@ -428,12 +438,15 @@ class NetworkTrainer:
)
args.scale_weight_norms = False
self.post_process_network(args, accelerator, network, text_encoders, unet)
# apply network to unet and text_encoder
train_unet = not args.network_train_text_encoder_only
train_text_encoder = self.is_train_text_encoder(args)
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
if args.network_weights is not None:
# FIXME consider alpha of weights
# FIXME consider alpha of weights: this assumes that the alpha is not changed
info = network.load_weights(args.network_weights)
accelerator.print(f"load network weights from {args.network_weights}: {info}")
@@ -533,7 +546,7 @@ class NetworkTrainer:
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
accelerator.print("enable fp8 training for U-Net.")
unet_weight_dtype = torch.float8_e4m3fn
if not args.fp8_base_unet:
accelerator.print("enable fp8 training for Text Encoder.")
te_weight_dtype = weight_dtype if args.fp8_base_unet else torch.float8_e4m3fn
@@ -545,17 +558,16 @@ class NetworkTrainer:
unet.requires_grad_(False)
unet.to(dtype=unet_weight_dtype)
for t_enc in text_encoders:
for i, t_enc in enumerate(text_encoders):
t_enc.requires_grad_(False)
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
if t_enc.device.type != "cpu":
t_enc.to(dtype=te_weight_dtype)
if hasattr(t_enc, "text_model") and hasattr(t_enc.text_model, "embeddings"):
# nn.Embedding not support FP8
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
elif hasattr(t_enc, "encoder") and hasattr(t_enc.encoder, "embeddings"):
t_enc.encoder.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# nn.Embedding not support FP8
if te_weight_dtype != weight_dtype:
self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if args.deepspeed:
@@ -596,12 +608,12 @@ class NetworkTrainer:
if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
unet.train()
for t_enc, frag in zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders)):
for i, (t_enc, frag) in enumerate(zip(text_encoders, self.get_text_encoders_train_flags(args, text_encoders))):
t_enc.train()
# set top parameter requires_grad = True for gradient checkpointing works
if frag:
t_enc.text_model.embeddings.requires_grad_(True)
self.prepare_text_encoder_grad_ckpt_workaround(i, t_enc)
else:
unet.eval()
@@ -1028,8 +1040,12 @@ class NetworkTrainer:
# log device and dtype for each model
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
for t_enc in text_encoders:
logger.info(f"text_encoder dtype: {t_enc.dtype}, device: {t_enc.device}")
for i, t_enc in enumerate(text_encoders):
params_itr = t_enc.parameters()
params_itr.__next__() # skip the first parameter
params_itr.__next__() # skip the second parameter. because CLIP first two parameters are embeddings
param_3rd = params_itr.__next__()
logger.info(f"text_encoder [{i}] dtype: {param_3rd.dtype}, device: {t_enc.device}")
clean_memory_on_device(accelerator.device)
@@ -1085,11 +1101,7 @@ class NetworkTrainer:
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
if (
len(text_encoder_conds) == 0
or text_encoder_conds[0] is None
or train_text_encoder
):
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions: