mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 22:35:09 +00:00
support tokenizer caching for offline training/gen
gen_img_diffusers.py
@@ -80,6 +80,7 @@ from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 
 import library.model_util as model_util
+import library.train_util as train_util
 import tools.original_control_net as original_control_net
 from tools.original_control_net import ControlNetInfo
 
@@ -1963,10 +1964,7 @@ def main(args):
   # load the tokenizer
   print("loading tokenizer")
   if use_stable_diffusion_format:
-    if args.v2:
-      tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
-    else:
-      tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)  # , model_max_length=max_token_length + 2)
+    tokenizer = train_util.load_tokenizer(args)
 
   # prepare the scheduler
   sched_init_args = {}
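
For reference, load_tokenizer() only reads a few fields from the parsed argument namespace, so the call above works for any script whose parser goes through add_sd_models_arguments. A minimal sketch of the fields it consults; the values below are hypothetical examples, not defaults:

import argparse

# Fields consulted by load_tokenizer(); names match the diff, values are examples.
args = argparse.Namespace(
    v2=False,                           # False -> TOKENIZER_PATH, True -> V2_STABLE_DIFFUSION_PATH
    tokenizer_cache_dir="./tok_cache",  # None disables caching entirely
    max_token_length=None,              # only read if present on the namespace (hasattr check)
)
# tokenizer = train_util.load_tokenizer(args)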
@@ -2715,6 +2713,8 @@ if __name__ == '__main__':
   parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
   parser.add_argument("--vae", type=str, default=None,
                       help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
+  parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
+                      help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
   # parser.add_argument("--replace_clip_l14_336", action='store_true',
   #                     help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
   parser.add_argument("--seed", type=int, default=None,
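
The cache entry is named after the Hugging Face repo ID with '/' replaced by '_'. A short sketch of the derivation, assuming TOKENIZER_PATH and V2_STABLE_DIFFUSION_PATH keep their usual values in this repo ("openai/clip-vit-large-patch14" and "stabilityai/stable-diffusion-2"); the cache directory name is hypothetical:

import os

for original_path in ("openai/clip-vit-large-patch14", "stabilityai/stable-diffusion-2"):
    # same transformation as local_tokenizer_path in the load_tokenizer diff below
    local_tokenizer_path = os.path.join("./tokenizer_cache", original_path.replace("/", "_"))
    print(local_tokenizer_path)
# ./tokenizer_cache/openai_clip-vit-large-patch14
# ./tokenizer_cache/stabilityai_stable-diffusion-2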
library/train_util.py
@@ -1369,6 +1369,8 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
                       help='enable v-parameterization training / v-parameterization学習を有効にする')
   parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
                       help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
+  parser.add_argument("--tokenizer_cache_dir", type=str, default=None,
+                      help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)")
 
 
 def add_optimizer_arguments(parser: argparse.ArgumentParser):
@@ -1796,12 +1798,28 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
 
 def load_tokenizer(args: argparse.Namespace):
   print("prepare tokenizer")
-  if args.v2:
-    tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
-  else:
-    tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
-  if args.max_token_length is not None:
+  original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
+
+  tokenizer: CLIPTokenizer = None
+  if args.tokenizer_cache_dir:
+    local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace('/', '_'))
+    if os.path.exists(local_tokenizer_path):
+      print(f"load tokenizer from cache: {local_tokenizer_path}")
+      tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)  # same for v1 and v2
+
+  if tokenizer is None:
+    if args.v2:
+      tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
+    else:
+      tokenizer = CLIPTokenizer.from_pretrained(original_path)
+
+  if hasattr(args, "max_token_length") and args.max_token_length is not None:
     print(f"update token length: {args.max_token_length}")
 
+  if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
+    print(f"save Tokenizer to cache: {local_tokenizer_path}")
+    tokenizer.save_pretrained(local_tokenizer_path)
+
   return tokenizer
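
Taken together, the new load_tokenizer() tries the local cache first, falls back to downloading from the hub, and writes the cache after the first online run, so subsequent runs need no network access. A self-contained sketch of the same pattern, assuming transformers is installed; load_cached_tokenizer and the paths are illustrative only (the real function additionally handles the v2 subfolder="tokenizer" case and max_token_length):

import os
from transformers import CLIPTokenizer

def load_cached_tokenizer(original_path: str, cache_dir: str) -> CLIPTokenizer:
    # cache entry named after the repo ID, as in the diff above
    local_path = os.path.join(cache_dir, original_path.replace("/", "_"))
    if os.path.exists(local_path):
        return CLIPTokenizer.from_pretrained(local_path)      # offline: cache hit
    tokenizer = CLIPTokenizer.from_pretrained(original_path)  # first run: downloads
    tokenizer.save_pretrained(local_path)                     # populate the cache
    return tokenizer

tokenizer = load_cached_tokenizer("openai/clip-vit-large-patch14", "./tokenizer_cache")
print(len(tokenizer))  # vocabulary size: 49408 for the CLIP tokenizer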