Better implementation for te autocast (#895)

* Better implementation for te

* Fix some misunderstanding

* as same as unet, add explicit convert

* Better cache TE and TE lr

* Fix with list

* Add timeout settings

* Fix arg style
This commit is contained in:
Kohaku-Blueleaf
2023-10-28 14:49:59 +08:00
committed by GitHub
parent 202f2c3292
commit 1cefb2a753
4 changed files with 41 additions and 30 deletions

View File

@@ -70,14 +70,16 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
if torch.cuda.is_available():
torch.cuda.empty_cache()
dataset.cache_text_encoder_outputs(
tokenizers,
text_encoders,
accelerator.device,
weight_dtype,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
# When TE is not being trained, it will not be prepared so we need to use explicit autocast
with accelerator.autocast():
dataset.cache_text_encoder_outputs(
tokenizers,
text_encoders,
accelerator.device,
weight_dtype,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu", dtype=torch.float32)