This commit is contained in:
alex choi
2025-09-30 09:57:02 +05:30
committed by GitHub
4 changed files with 39 additions and 2 deletions

View File

@@ -137,6 +137,20 @@ def _load_target_model(
def load_tokenizers(args: argparse.Namespace):
logger.info("prepare tokenizers")
# load diffusers tokenizers if available
name_or_path = args.pretrained_model_name_or_path
if os.path.isdir(name_or_path):
tokenizer_path = os.path.join(name_or_path, "tokenizer")
tokenizer_2_path = os.path.join(name_or_path, "tokenizer_2")
if os.path.exists(tokenizer_path) \
and os.path.exists(tokenizer_2_path):
logger.info(f"load tokenizers from pretrained_model_name_or_path: {name_or_path}")
tokeniers = [
CLIPTokenizer.from_pretrained(tokenizer_path),
CLIPTokenizer.from_pretrained(tokenizer_2_path),
]
return tokeniers
original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
tokeniers = []
for i, original_path in enumerate(original_paths):

View File

@@ -1108,7 +1108,7 @@ class BaseDataset(torch.utils.data.Dataset):
logger.info("caching text encoder outputs.")
image_infos = list(self.image_data.values())
logger.info("checking cache existence...")
logger.info("checking cache validity...")
image_infos_to_cache = []
for info in tqdm(image_infos):
# subset = self.image_to_subset[info.image_key]
@@ -1119,7 +1119,9 @@ class BaseDataset(torch.utils.data.Dataset):
if not is_main_process: # store to info only
continue
if os.path.exists(te_out_npz):
is_cache_valid = is_disk_cached_text_encoder_output_valid(te_out_npz)
if is_cache_valid:
continue
image_infos_to_cache.append(info)
@@ -2270,6 +2272,23 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
return True
def is_disk_cached_text_encoder_output_valid(npz_path: str):
    """Return True if the cached text encoder outputs at ``npz_path`` are usable.

    A cache entry is valid only when the file exists, loads without error, and
    all three components (hidden_state1, hidden_state2, pool2) are present.
    Any failure is logged at debug level and reported as invalid rather than
    raised, so callers can simply re-cache the entry.
    """
    if not os.path.exists(npz_path):
        logger.debug(f'is_disk_cached_text_encoder_output_valid file not found: {npz_path}')
        return False

    try:
        hs1, hs2, p2 = load_text_encoder_outputs_from_disk(npz_path)
        # a file that loads but is missing any component counts as invalid
        if hs1 is None or hs2 is None or p2 is None:
            logger.debug(f'is_disk_cached_text_encoder_output_valid None value found: hidden_state1 {hs1}, hidden_state2 {hs2}, pool2 {p2}')
            return False
        return True
    except Exception as e:
        # unreadable / corrupt npz -> treat as a cache miss, not a fatal error
        logger.debug(f"is_disk_cached_text_encoder_output_valid failed to load text encoder outputs from {npz_path}. {e}")
        return False
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
def load_latents_from_disk(
npz_path,

View File

@@ -12,6 +12,7 @@ from tqdm import tqdm
from library import config_util
from library import train_util
from library import sdxl_train_util
from library import deepspeed_utils
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@@ -181,6 +182,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
deepspeed_utils.add_deepspeed_arguments(parser)
config_util.add_config_arguments(parser)
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
parser.add_argument(

View File

@@ -11,6 +11,7 @@ from tqdm import tqdm
from library import config_util
from library import train_util
from library import deepspeed_utils
from library import sdxl_train_util
from library.config_util import (
ConfigSanitizer,
@@ -178,6 +179,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
config_util.add_config_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
parser.add_argument(