mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 01:12:41 +00:00
Merge 6789561e0c into 206adb6438
This commit is contained in:
@@ -137,6 +137,20 @@ def _load_target_model(
|
||||
def load_tokenizers(args: argparse.Namespace):
|
||||
logger.info("prepare tokenizers")
|
||||
|
||||
# load diffusers tokenizers if available
|
||||
name_or_path = args.pretrained_model_name_or_path
|
||||
if os.path.isdir(name_or_path):
|
||||
tokenizer_path = os.path.join(name_or_path, "tokenizer")
|
||||
tokenizer_2_path = os.path.join(name_or_path, "tokenizer_2")
|
||||
if os.path.exists(tokenizer_path) \
|
||||
and os.path.exists(tokenizer_2_path):
|
||||
logger.info(f"load tokenizers from pretrained_model_name_or_path: {name_or_path}")
|
||||
tokeniers = [
|
||||
CLIPTokenizer.from_pretrained(tokenizer_path),
|
||||
CLIPTokenizer.from_pretrained(tokenizer_2_path),
|
||||
]
|
||||
return tokeniers
|
||||
|
||||
original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
|
||||
tokeniers = []
|
||||
for i, original_path in enumerate(original_paths):
|
||||
|
||||
@@ -1108,7 +1108,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
logger.info("caching text encoder outputs.")
|
||||
image_infos = list(self.image_data.values())
|
||||
|
||||
logger.info("checking cache existence...")
|
||||
logger.info("checking cache validity...")
|
||||
image_infos_to_cache = []
|
||||
for info in tqdm(image_infos):
|
||||
# subset = self.image_to_subset[info.image_key]
|
||||
@@ -1119,7 +1119,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if not is_main_process: # store to info only
|
||||
continue
|
||||
|
||||
if os.path.exists(te_out_npz):
|
||||
is_cache_valid = is_disk_cached_text_encoder_output_valid(te_out_npz)
|
||||
|
||||
if is_cache_valid:
|
||||
continue
|
||||
|
||||
image_infos_to_cache.append(info)
|
||||
@@ -2270,6 +2272,23 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
|
||||
return True
|
||||
|
||||
|
||||
def is_disk_cached_text_encoder_output_valid(npz_path: str):
    """Return True when ``npz_path`` holds a loadable, complete text encoder output cache.

    The cache counts as invalid when the file does not exist, loading it raises
    any exception, or any of the three stored outputs (two hidden states and the
    pooled output) is None. Failures are logged at debug level only.
    """
    # Guard clause: a missing file can never be a valid cache.
    if not os.path.exists(npz_path):
        logger.debug(f'is_disk_cached_text_encoder_output_valid file not found: {npz_path}')
        return False

    try:
        hidden_state1, hidden_state2, pool2 = load_text_encoder_outputs_from_disk(npz_path)
        # A partially written cache may contain None entries; treat it as invalid.
        complete = not (hidden_state1 is None or hidden_state2 is None or pool2 is None)
        if not complete:
            logger.debug(f'is_disk_cached_text_encoder_output_valid None value found: hidden_state1 {hidden_state1}, hidden_state2 {hidden_state2}, pool2 {pool2}')
        return complete
    except Exception as e:
        # Corrupt or unreadable npz files are reported as invalid, not raised.
        logger.debug(f"is_disk_cached_text_encoder_output_valid failed to load text encoder outputs from {npz_path}. {e}")
        return False
|
||||
|
||||
|
||||
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
|
||||
def load_latents_from_disk(
|
||||
npz_path,
|
||||
|
||||
@@ -12,6 +12,7 @@ from tqdm import tqdm
|
||||
from library import config_util
|
||||
from library import train_util
|
||||
from library import sdxl_train_util
|
||||
from library import deepspeed_utils
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
@@ -181,6 +182,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
|
||||
parser.add_argument(
|
||||
|
||||
@@ -11,6 +11,7 @@ from tqdm import tqdm
|
||||
|
||||
from library import config_util
|
||||
from library import train_util
|
||||
from library import deepspeed_utils
|
||||
from library import sdxl_train_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
@@ -178,6 +179,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
config_util.add_config_arguments(parser)
|
||||
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||
sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user