Support multi-GPU in caching text encoder outputs

This commit is contained in:
Kohya S
2023-07-09 16:02:56 +09:00
parent 3579b4570f
commit 0416f26a76
5 changed files with 32 additions and 22 deletions

View File

@@ -319,7 +319,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
# TextEncoderの出力をキャッシュする
# weight_dtypeを指定するとText Encoderそのもの、および出力がweight_dtypeになる
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype):
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dataset, weight_dtype):
print("caching text encoder outputs")
tokenizer1, tokenizer2 = tokenizers
@@ -332,9 +332,9 @@ def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dat
text_encoder1_cache = {}
text_encoder2_cache = {}
for batch in tqdm(data_loader):
input_ids1_batch = batch["input_ids"]
input_ids2_batch = batch["input_ids2"]
for batch in tqdm(dataset):
input_ids1_batch = batch["input_ids"].to(accelerator.device)
input_ids2_batch = batch["input_ids2"].to(accelerator.device)
# split batch to avoid OOM
# TODO specify batch size by args