Better implementation for te autocast (#895)

* Better implementation for te

* Fix some misunderstanding

* Same as unet: add explicit conversion

* Better cache TE and TE lr

* Fix with list

* Add timeout settings

* Fix arg style
Author: Kohaku-Blueleaf
Date: 2023-10-28 14:49:59 +08:00
Committed by: GitHub
Parent: 202f2c3292
Commit: 1cefb2a753

4 changed files with 41 additions and 30 deletions

@@ -109,6 +109,9 @@ class NetworkTrainer:
     def is_text_encoder_outputs_cached(self, args):
         return False
 
+    def is_train_text_encoder(self, args):
+        return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
+
     def cache_text_encoder_outputs_if_needed(
         self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
     ):
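The new is_train_text_encoder helper centralizes the check for whether the text encoder should be trained, so later call sites (and subclasses that cache TE outputs) no longer repeat the two-flag expression. Below is a minimal sketch of how the helper behaves; the import path, the subclass, and the Namespace flags are illustrative assumptions, not part of this commit:

from argparse import Namespace

from train_network import NetworkTrainer  # assumed import path for the class shown above

# Illustrative subclass: a trainer that reports its TE outputs as cached.
class HypotheticalCachedTETrainer(NetworkTrainer):
    def is_text_encoder_outputs_cached(self, args):
        # When TE outputs are precomputed, the text encoder must stay frozen.
        return getattr(args, "cache_text_encoder_outputs", False)

args = Namespace(network_train_unet_only=False, cache_text_encoder_outputs=True)
trainer = HypotheticalCachedTETrainer()
print(trainer.is_train_text_encoder(args))  # False: cached TE outputs imply no TE training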
@@ -310,7 +313,7 @@ class NetworkTrainer:
             args.scale_weight_norms = False
 
         train_unet = not args.network_train_text_encoder_only
-        train_text_encoder = not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
+        train_text_encoder = self.is_train_text_encoder(args)
         network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
 
         if args.network_weights is not None:
@@ -403,6 +406,8 @@ class NetworkTrainer:
             unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                 unet, network, optimizer, train_dataloader, lr_scheduler
             )
+            for t_enc in text_encoders:
+                t_enc.to(accelerator.device, dtype=weight_dtype)
         elif train_text_encoder:
             if len(text_encoders) > 1:
                 t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
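With this change, frozen text encoders that are not passed through accelerator.prepare() are still moved to the accelerator device and cast to the training weight dtype explicitly, matching the existing unet handling. A standalone sketch of the same pattern, using stand-in modules instead of the real text encoders (the module sizes and the fp16 dtype are assumptions):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
weight_dtype = torch.float16  # assumed; in the trainer this follows args.mixed_precision

# Stand-ins for frozen text encoders that skip accelerator.prepare().
text_encoders = [torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)]
for t_enc in text_encoders:
    # Explicitly place the frozen TE on the right device and cast it to the
    # training weight dtype, as the hunk above does for the real encoders.
    t_enc.to(accelerator.device, dtype=weight_dtype)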
@@ -767,7 +772,7 @@ class NetworkTrainer:
                 latents = latents * self.vae_scale_factor
                 b_size = latents.shape[0]
 
-                with torch.set_grad_enabled(train_text_encoder):
+                with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
                     # Get the text embedding for conditioning
                     if args.weighted_captions:
                         text_encoder_conds = get_weighted_text_embeddings(
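The core of the commit is this hunk: the text encoder forward pass now runs under accelerator.autocast() as well as torch.set_grad_enabled, so TE activations follow the configured mixed precision even when the encoder itself is frozen. A self-contained sketch of the same context-manager combination; the tiny linear encoder and random inputs are stand-ins for the real TE path, not the trainer's actual code:

import torch
from accelerate import Accelerator

accelerator = Accelerator()
encoder = torch.nn.Linear(16, 16).to(accelerator.device)   # stand-in for the text encoder
captions = torch.randn(2, 16, device=accelerator.device)   # stand-in for embedded captions
train_text_encoder = False

# Gradients stay off for a frozen TE, while autocast picks the compute dtype
# from the accelerate configuration (a no-op when mixed precision is disabled).
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
    text_encoder_conds = encoder(captions)

print(text_encoder_conds.requires_grad)  # False when the TE is frozen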