From b78c0e2a69e52ce6c79abc6c8c82d1a9cabcf05c Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Tue, 25 Jul 2023 19:07:26 +0900
Subject: [PATCH] remove unused func

---
 library/sdxl_train_util.py | 48 --------------------------------------
 1 file changed, 48 deletions(-)

diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py
index 54774f88..6ff0d48f 100644
--- a/library/sdxl_train_util.py
+++ b/library/sdxl_train_util.py
@@ -286,54 +286,6 @@ def save_sd_model_on_epoch_end_or_stepwise(
     )
 
 
-# Cache the Text Encoder outputs
-# If weight_dtype is specified, the Text Encoders themselves and their outputs are cast to weight_dtype
-def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dataset, weight_dtype):
-    print("caching text encoder outputs")
-
-    tokenizer1, tokenizer2 = tokenizers
-    text_encoder1, text_encoder2 = text_encoders
-    text_encoder1.to(accelerator.device)
-    text_encoder2.to(accelerator.device)
-    if weight_dtype is not None:
-        text_encoder1.to(dtype=weight_dtype)
-        text_encoder2.to(dtype=weight_dtype)
-
-    text_encoder1_cache = {}
-    text_encoder2_cache = {}
-    for batch in tqdm(dataset):
-        input_ids1_batch = batch["input_ids"].to(accelerator.device)
-        input_ids2_batch = batch["input_ids2"].to(accelerator.device)
-
-        # split batch to avoid OOM
-        # TODO specify batch size by args
-        for input_id1, input_id2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
-            # skip input_ids already in cache
-            input_id1_cache_key = tuple(input_id1.flatten().tolist())
-            input_id2_cache_key = tuple(input_id2.flatten().tolist())
-            if input_id1_cache_key in text_encoder1_cache:
-                assert input_id2_cache_key in text_encoder2_cache
-                continue
-
-            with torch.no_grad():
-                encoder_hidden_states1, encoder_hidden_states2, pool2 = get_hidden_states(
-                    args,
-                    input_id1,
-                    input_id2,
-                    tokenizer1,
-                    tokenizer2,
-                    text_encoder1,
-                    text_encoder2,
-                    None if not args.full_fp16 else weight_dtype,
-                )
-            encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu").squeeze(0)  # n*75+2,768
-            encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu").squeeze(0)  # n*75+2,1280
-            pool2 = pool2.detach().to("cpu").squeeze(0)  # 1280
-            text_encoder1_cache[input_id1_cache_key] = encoder_hidden_states1
-            text_encoder2_cache[input_id2_cache_key] = (encoder_hidden_states2, pool2)
-    return text_encoder1_cache, text_encoder2_cache
-
-
 def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
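
Note: for reference, the core pattern of the removed helper is memoizing text-encoder outputs per unique token-id sequence: each batch row is split off and encoded once, the result is parked on the CPU, and repeated prompts hit the cache. A minimal, self-contained sketch of that pattern follows. It is illustrative only, not code from this repository: cache_encoder_outputs, encode_fn, and dummy_encode are hypothetical names standing in for the removed cache_text_encoder_outputs and its call to get_hidden_states.

# Illustrative sketch only -- not part of this patch or repository.
from typing import Callable, Dict, List, Tuple

import torch


def cache_encoder_outputs(
    batches: List[torch.Tensor],
    encode_fn: Callable[[torch.Tensor], torch.Tensor],
    device: torch.device,
) -> Dict[Tuple[int, ...], torch.Tensor]:
    # Memoize encoder outputs keyed by the exact token-id sequence.
    cache: Dict[Tuple[int, ...], torch.Tensor] = {}
    for batch in batches:
        # Encode one row at a time to keep peak memory low
        # (mirrors the removed function's .split(1)).
        for input_ids in batch.split(1):
            key = tuple(input_ids.flatten().tolist())
            if key in cache:
                continue  # this prompt was already encoded
            with torch.no_grad():
                hidden = encode_fn(input_ids.to(device))
            # Park results on the CPU so the cache does not hold GPU memory.
            cache[key] = hidden.detach().to("cpu").squeeze(0)
    return cache


# Toy usage with a stand-in "encoder" (hypothetical, for illustration).
def dummy_encode(ids: torch.Tensor) -> torch.Tensor:
    return torch.randn(1, ids.shape[1], 768)


batches = [torch.randint(0, 1000, (4, 77)) for _ in range(2)]
cache = cache_encoder_outputs(batches, dummy_encode, torch.device("cpu"))
print(len(cache), "unique prompts cached")

Splitting the batch into single rows trades throughput for a lower memory peak, which is why the removed function carried a TODO to make the split size configurable via args.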