mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix to work dropout_rate for TEs
This commit is contained in:
@@ -300,7 +300,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
if args.gradient_checkpointing:
|
||||
noisy_model_input.requires_grad_(True)
|
||||
for t in text_encoder_conds:
|
||||
if t.dtype.is_floating_point:
|
||||
if t is not None and t.dtype.is_floating_point:
|
||||
t.requires_grad_(True)
|
||||
|
||||
# Predict the noise residual
|
||||
@@ -415,13 +415,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
prepare_fp8(text_encoder, weight_dtype)
|
||||
|
||||
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
# # drop cached text encoder outputs
|
||||
# text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
# if text_encoder_outputs_list is not None:
|
||||
# text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
# text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
|
||||
# batch["text_encoder_outputs_list"] = text_encoder_outputs_list
|
||||
pass
|
||||
# drop cached text encoder outputs
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
|
||||
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
Reference in New Issue
Block a user