mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
remove unused func
This commit is contained in:
@@ -286,54 +286,6 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dataset, weight_dtype):
    """Cache the outputs of the two text encoders for every entry of *dataset*.

    Both text encoders are moved to ``accelerator.device``. When
    ``weight_dtype`` is not None the encoders themselves are also cast to
    that dtype. ``weight_dtype`` is forwarded to ``get_hidden_states`` only
    when ``args.full_fp16`` is truthy; otherwise ``None`` is passed.
    NOTE(review): the original comment says the outputs also end up in
    ``weight_dtype`` -- that depends on ``get_hidden_states``; confirm there.

    Args:
        args: Options object; ``args.full_fp16`` is read here and the whole
            object is passed through to ``get_hidden_states``.
        accelerator: Object exposing a ``device`` attribute.
        tokenizers: Pair ``(tokenizer1, tokenizer2)``.
        text_encoders: Pair ``(text_encoder1, text_encoder2)``.
        dataset: Iterable of batches; each batch is indexed with
            ``"input_ids"`` and ``"input_ids2"``.
        weight_dtype: dtype to cast the encoders to, or None to leave them.

    Returns:
        ``(text_encoder1_cache, text_encoder2_cache)``. Both are dicts keyed
        by the flattened token ids as a tuple. The first maps to the hidden
        states of encoder 1, the second to a tuple of
        ``(hidden states of encoder 2, pooled output)``. All cached tensors
        have been detached and moved to the CPU.
    """
    print("caching text encoder outputs")

    tokenizer1, tokenizer2 = tokenizers
    text_encoder1, text_encoder2 = text_encoders

    # Device placement for both encoders first, then the optional dtype cast.
    for encoder in (text_encoder1, text_encoder2):
        encoder.to(accelerator.device)
    if weight_dtype is not None:
        for encoder in (text_encoder1, text_encoder2):
            encoder.to(dtype=weight_dtype)

    def _to_cpu(tensor):
        # Drop the autograd link, move to host memory, remove the leading dim.
        return tensor.detach().to("cpu").squeeze(0)

    cache1 = {}
    cache2 = {}
    for batch in tqdm(dataset):
        ids1_batch = batch["input_ids"].to(accelerator.device)
        ids2_batch = batch["input_ids2"].to(accelerator.device)

        # Encode one sample at a time to avoid running out of memory.
        # TODO specify batch size by args
        for ids1, ids2 in zip(ids1_batch.split(1), ids2_batch.split(1)):
            key1 = tuple(ids1.flatten().tolist())
            key2 = tuple(ids2.flatten().tolist())

            # Skip token sequences that have already been encoded.
            if key1 in cache1:
                assert key2 in cache2
                continue

            hidden_dtype = weight_dtype if args.full_fp16 else None
            with torch.no_grad():
                hidden1, hidden2, pooled2 = get_hidden_states(
                    args,
                    ids1,
                    ids2,
                    tokenizer1,
                    tokenizer2,
                    text_encoder1,
                    text_encoder2,
                    hidden_dtype,
                )

            # Shape notes carried over from the original author (unverified
            # here): hidden1 is n*75+2,768; hidden2 is n*75+2,1280; pooled2
            # is 1280.
            cache1[key1] = _to_cpu(hidden1)
            cache2[key2] = (_to_cpu(hidden2), _to_cpu(pooled2))

    return cache1, cache2
|
|
||||||
|
|
||||||
|
|
||||||
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
||||||
|
|||||||
Reference in New Issue
Block a user