Add dropout rate arguments for CLIP-L, CLIP-G, and T5, fix Text Encoders LoRA not trained

This commit is contained in:
Kohya S
2024-10-27 16:42:58 +09:00
parent b649bbf2b6
commit db2b4d41b9
5 changed files with 138 additions and 17 deletions

View File

@@ -120,7 +120,13 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
return latents_caching_strategy
def get_text_encoding_strategy(self, args):
return strategy_sd3.Sd3TextEncodingStrategy(args.apply_lg_attn_mask, args.apply_t5_attn_mask)
return strategy_sd3.Sd3TextEncodingStrategy(
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
args.clip_l_dropout_rate,
args.clip_g_dropout_rate,
args.t5xxl_dropout_rate,
)
def post_process_network(self, args, accelerator, network, text_encoders, unet):
# check t5xxl is trained or not
@@ -408,6 +414,14 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
# drop cached text encoder outputs
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(text_encoder_outputs_list)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()