Support tokenizer caching for offline training/generation

This commit is contained in:
Kohya S
2023-02-25 18:46:59 +09:00
parent 9993792656
commit a28f9ae7a3
2 changed files with 27 additions and 9 deletions

View File

@@ -80,6 +80,7 @@ from PIL import Image
from PIL.PngImagePlugin import PngInfo
import library.model_util as model_util
import library.train_util as train_util
import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
@@ -1963,10 +1964,7 @@ def main(args):
# tokenizerを読み込む
print("loading tokenizer")
if use_stable_diffusion_format:
if args.v2:
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
else:
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
tokenizer = train_util.load_tokenizer(args)
# schedulerを用意する
sched_init_args = {}
@@ -2715,6 +2713,8 @@ if __name__ == '__main__':
parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
parser.add_argument("--vae", type=str, default=None,
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリネット接続なしでの学習のため")
# parser.add_argument("--replace_clip_l14_336", action='store_true',
# help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
parser.add_argument("--seed", type=int, default=None,