mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add dropout rate arguments for CLIP-L, CLIP-G, and T5, fix Text Encoders LoRA not trained
This commit is contained in:
@@ -272,6 +272,9 @@ class NetworkTrainer:
|
||||
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
||||
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
|
||||
|
||||
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
pass
|
||||
|
||||
# endregion
|
||||
|
||||
def train(self, args):
|
||||
@@ -1030,9 +1033,9 @@ class NetworkTrainer:
|
||||
|
||||
# callback for step start
|
||||
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
|
||||
on_step_start = accelerator.unwrap_model(network).on_step_start
|
||||
on_step_start_for_network = accelerator.unwrap_model(network).on_step_start
|
||||
else:
|
||||
on_step_start = lambda *args, **kwargs: None
|
||||
on_step_start_for_network = lambda *args, **kwargs: None
|
||||
|
||||
# function for saving/removing
|
||||
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
|
||||
@@ -1113,7 +1116,10 @@ class NetworkTrainer:
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(training_model):
|
||||
on_step_start(text_encoder, unet)
|
||||
on_step_start_for_network(text_encoder, unet)
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
@@ -1146,6 +1152,7 @@ class NetworkTrainer:
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
||||
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
|
||||
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
|
||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
|
||||
Reference in New Issue
Block a user