diff --git a/gen_img.py b/gen_img.py index eba47805..13d49c33 100644 --- a/gen_img.py +++ b/gen_img.py @@ -22,6 +22,7 @@ import numpy as np import torch from library.device_utils import init_ipex +from library.strategy_sd import SdTokenizeStrategy init_ipex() @@ -1698,7 +1699,8 @@ def main(args): tokenizers = [tokenizer1, tokenizer2] else: if use_stable_diffusion_format: - tokenizer = train_util.load_tokenizer(args) + tokenize_strategy = SdTokenizeStrategy(args.v2, max_length=None, tokenizer_cache_dir=args.tokenizer_cache_dir) + tokenizer = tokenize_strategy.tokenizer tokenizers = [tokenizer] # schedulerを用意する @@ -1960,7 +1962,7 @@ def main(args): if not is_sdxl: for i, model in enumerate(args.control_net_models): prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] - weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + weight = 1.0 if not args.control_net_multipliers or len(args.control_net_multipliers) <= i else args.control_net_multipliers[i] ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)