remove unused func

This commit is contained in:
Kohya S
2023-07-25 19:07:26 +09:00
parent 2b969e9c42
commit b78c0e2a69

View File

@@ -286,54 +286,6 @@ def save_sd_model_on_epoch_end_or_stepwise(
) )
# TextEncoderの出力をキャッシュする
# weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dataset, weight_dtype):
print("caching text encoder outputs")
tokenizer1, tokenizer2 = tokenizers
text_encoder1, text_encoder2 = text_encoders
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
if weight_dtype is not None:
text_encoder1.to(dtype=weight_dtype)
text_encoder2.to(dtype=weight_dtype)
text_encoder1_cache = {}
text_encoder2_cache = {}
for batch in tqdm(dataset):
input_ids1_batch = batch["input_ids"].to(accelerator.device)
input_ids2_batch = batch["input_ids2"].to(accelerator.device)
# split batch to avoid OOM
# TODO specify batch size by args
for input_id1, input_id2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
# remove input_ids already in cache
input_id1_cache_key = tuple(input_id1.flatten().tolist())
input_id2_cache_key = tuple(input_id2.flatten().tolist())
if input_id1_cache_key in text_encoder1_cache:
assert input_id2_cache_key in text_encoder2_cache
continue
with torch.no_grad():
encoder_hidden_states1, encoder_hidden_states2, pool2 = get_hidden_states(
args,
input_id1,
input_id2,
tokenizer1,
tokenizer2,
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
)
encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu").squeeze(0) # n*75+2,768
encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu").squeeze(0) # n*75+2,1280
pool2 = pool2.detach().to("cpu").squeeze(0) # 1280
text_encoder1_cache[input_id1_cache_key] = encoder_hidden_states1
text_encoder2_cache[input_id2_cache_key] = (encoder_hidden_states2, pool2)
return text_encoder1_cache, text_encoder2_cache
def add_sdxl_training_arguments(parser: argparse.ArgumentParser): def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする" "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"