mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support multi gpu in caching text encoder outputs
This commit is contained in:
@@ -319,7 +319,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
|
||||
# TextEncoderの出力をキャッシュする
|
||||
# weight_dtypeを指定するとText Encoderそのもの、および出力がweight_dtypeになる
|
||||
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype):
|
||||
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dataset, weight_dtype):
|
||||
print("caching text encoder outputs")
|
||||
|
||||
tokenizer1, tokenizer2 = tokenizers
|
||||
@@ -332,9 +332,9 @@ def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dat
|
||||
|
||||
text_encoder1_cache = {}
|
||||
text_encoder2_cache = {}
|
||||
for batch in tqdm(data_loader):
|
||||
input_ids1_batch = batch["input_ids"]
|
||||
input_ids2_batch = batch["input_ids2"]
|
||||
for batch in tqdm(dataset):
|
||||
input_ids1_batch = batch["input_ids"].to(accelerator.device)
|
||||
input_ids2_batch = batch["input_ids2"].to(accelerator.device)
|
||||
|
||||
# split batch to avoid OOM
|
||||
# TODO specify batch size by args
|
||||
|
||||
Reference in New Issue
Block a user