diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index f049e8a2..6bab0bb8 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -80,6 +80,7 @@ from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 
 import library.model_util as model_util
+import library.train_util as train_util
 import tools.original_control_net as original_control_net
 from tools.original_control_net import ControlNetInfo
 
@@ -1963,10 +1964,7 @@ def main(args):
   # tokenizerを読み込む
   print("loading tokenizer")
   if use_stable_diffusion_format:
-    if args.v2:
-      tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
-    else:
-      tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)  # , model_max_length=max_token_length + 2)
+    tokenizer = train_util.load_tokenizer(args)
 
   # schedulerを用意する
   sched_init_args = {}
@@ -2715,6 +2713,8 @@ if __name__ == '__main__':
   parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
   parser.add_argument("--vae", type=str, default=None,
                       help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
+  parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
+                      help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
   # parser.add_argument("--replace_clip_l14_336", action='store_true',
   #                     help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
   parser.add_argument("--seed", type=int, default=None,
diff --git a/library/train_util.py b/library/train_util.py
index a02207b4..9f13baf2 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1369,6 +1369,8 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
                       help='enable v-parameterization training / v-parameterization学習を有効にする')
   parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
                       help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
+  parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
+                      help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
 
 
 def add_optimizer_arguments(parser: argparse.ArgumentParser):
@@ -1796,12 +1798,28 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
 
 
 def load_tokenizer(args: argparse.Namespace):
   print("prepare tokenizer")
-  if args.v2:
-    tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
-  else:
-    tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
-  if args.max_token_length is not None:
+  original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
+
+  tokenizer: CLIPTokenizer = None
+  if args.tokenizer_cache_dir:
+    local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace('/', '_'))
+    if os.path.exists(local_tokenizer_path):
+      print(f"load tokenizer from cache: {local_tokenizer_path}")
+      tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)  # same for v1 and v2
+
+  if tokenizer is None:
+    if args.v2:
+      tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
+    else:
+      tokenizer = CLIPTokenizer.from_pretrained(original_path)
+
+  if hasattr(args, "max_token_length") and args.max_token_length is not None:
     print(f"update token length: {args.max_token_length}")
+
+  if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
+    print(f"save Tokenizer to cache: {local_tokenizer_path}")
+    tokenizer.save_pretrained(local_tokenizer_path)
+
   return tokenizer
 
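
Usage sketch, not part of the patch: a minimal illustration of how the new --tokenizer_cache_dir flag is meant to behave, assuming add_sd_models_arguments also registers --v2 (so args.v2 defaults to False) and that TOKENIZER_PATH is the Hugging Face id "openai/clip-vit-large-patch14" defined elsewhere in library/train_util.py; the cache directory below is a hypothetical example.

import argparse

import library.train_util as train_util

parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)  # registers --tokenizer_cache_dir among the model args
args = parser.parse_args(["--tokenizer_cache_dir", "./tokenizer_cache"])  # hypothetical directory

# First call (online): downloads the tokenizer from Hugging Face, then saves it under
# ./tokenizer_cache/openai_clip-vit-large-patch14 ('/' in the model id is replaced
# with '_' so the whole id becomes a single directory name).
tokenizer = train_util.load_tokenizer(args)

# Later calls find that directory and load from it, so they need no network access.
tokenizer = train_util.load_tokenizer(args)

Note the save step only runs when the cache directory for the model id does not exist yet, so a cached tokenizer is never overwritten on subsequent runs.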