Fix SD3 LoRA training to work (WIP)

This commit is contained in:
kohya-ss
2024-10-27 17:03:36 +09:00
parent db2b4d41b9
commit a1255d637f
3 changed files with 38 additions and 17 deletions

View File

@@ -125,7 +125,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
args.apply_t5_attn_mask,
args.clip_l_dropout_rate,
args.clip_g_dropout_rate,
args.t5xxl_dropout_rate,
args.t5_dropout_rate,
)
def post_process_network(self, args, accelerator, network, text_encoders, unet):
@@ -415,12 +415,13 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
prepare_fp8(text_encoder, weight_dtype)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
# drop cached text encoder outputs
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(text_encoder_outputs_list)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
# # drop cached text encoder outputs
# text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
# if text_encoder_outputs_list is not None:
# text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
# text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
# batch["text_encoder_outputs_list"] = text_encoder_outputs_list
pass
def setup_parser() -> argparse.ArgumentParser: