support Diffusers format

This commit is contained in:
Kohya S
2023-07-23 13:33:14 +09:00
parent 50b53e183e
commit 7ec9a7af79

View File

@@ -1288,38 +1288,9 @@ def main(args):
if len(files) == 1: if len(files) == 1:
args.ckpt = files[0] args.ckpt = files[0]
use_stable_diffusion_format = os.path.isfile(args.ckpt) (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
assert use_stable_diffusion_format, "Diffusers pretrained models are not supported yet" args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, dtype
print("load StableDiffusion checkpoint")
text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt, "cpu"
) )
# else:
# print("load Diffusers pretrained models")
# TODO use Diffusers 0.18.1 and support SDXL pipeline
# raise NotImplementedError("Diffusers pretrained models are not supported yet")
# loading_pipe = StableDiffusionXLPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
# text_encoder = loading_pipe.text_encoder
# vae = loading_pipe.vae
# unet = loading_pipe.unet
# tokenizer = loading_pipe.tokenizer
# del loading_pipe
# # Diffusers U-Net to original U-Net
# original_unet = SdxlUNet2DConditionModel(
# unet.config.sample_size,
# unet.config.attention_head_dim,
# unet.config.cross_attention_dim,
# unet.config.use_linear_projection,
# unet.config.upcast_attention,
# )
# original_unet.load_state_dict(unet.state_dict())
# unet = original_unet
# VAEを読み込む
if args.vae is not None:
vae = model_util.load_vae(args.vae, dtype)
print("additional VAE loaded")
# xformers、Hypernetwork対応 # xformers、Hypernetwork対応
if not args.diffusers_xformers: if not args.diffusers_xformers:
@@ -1329,8 +1300,7 @@ def main(args):
# tokenizerを読み込む # tokenizerを読み込む
print("loading tokenizer") print("loading tokenizer")
if use_stable_diffusion_format: tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
# schedulerを用意する # schedulerを用意する
sched_init_args = {} sched_init_args = {}