support tokenizer caching for offline training/gen

This commit is contained in:
Kohya S
2023-02-25 18:46:59 +09:00
parent 9993792656
commit a28f9ae7a3
2 changed files with 27 additions and 9 deletions

View File

@@ -1369,6 +1369,8 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
help='enable v-parameterization training / v-parameterization学習を有効にする')
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリネット接続なしでの学習のため")
def add_optimizer_arguments(parser: argparse.ArgumentParser):
@@ -1796,12 +1798,28 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
def load_tokenizer(args: argparse.Namespace):
print("prepare tokenizer")
if args.v2:
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
else:
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
if args.max_token_length is not None:
original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
tokenizer: CLIPTokenizer = None
if args.tokenizer_cache_dir:
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace('/', '_'))
if os.path.exists(local_tokenizer_path):
print(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2
if tokenizer is None:
if args.v2:
tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
else:
tokenizer = CLIPTokenizer.from_pretrained(original_path)
if hasattr(args, "max_token_length") and args.max_token_length is not None:
print(f"update token length: {args.max_token_length}")
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
print(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer.save_pretrained(local_tokenizer_path)
return tokenizer