From 7ec9a7af798b71631281363bd80520559b2f26f3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 23 Jul 2023 13:33:14 +0900 Subject: [PATCH] support Diffusers format --- sdxl_gen_img.py | 36 +++--------------------------------- 1 file changed, 3 insertions(+), 33 deletions(-) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 6770c720..b5e26105 100644 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -1288,38 +1288,9 @@ def main(args): if len(files) == 1: args.ckpt = files[0] - use_stable_diffusion_format = os.path.isfile(args.ckpt) - assert use_stable_diffusion_format, "Diffusers pretrained models are not supported yet" - print("load StableDiffusion checkpoint") - text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt, "cpu" + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, dtype ) - # else: - # print("load Diffusers pretrained models") - # TODO use Diffusers 0.18.1 and support SDXL pipeline - # raise NotImplementedError("Diffusers pretrained models are not supported yet") - # loading_pipe = StableDiffusionXLPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) - # text_encoder = loading_pipe.text_encoder - # vae = loading_pipe.vae - # unet = loading_pipe.unet - # tokenizer = loading_pipe.tokenizer - # del loading_pipe - - # # Diffusers U-Net to original U-Net - # original_unet = SdxlUNet2DConditionModel( - # unet.config.sample_size, - # unet.config.attention_head_dim, - # unet.config.cross_attention_dim, - # unet.config.use_linear_projection, - # unet.config.upcast_attention, - # ) - # original_unet.load_state_dict(unet.state_dict()) - # unet = original_unet - - # VAEを読み込む - if args.vae is not None: - vae = model_util.load_vae(args.vae, dtype) - print("additional VAE loaded") # xformers、Hypernetwork対応 if not args.diffusers_xformers: @@ -1329,8 +1300,7 @@ def main(args): # tokenizerを読み込む print("loading tokenizer") - if use_stable_diffusion_format: - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) # schedulerを用意する sched_init_args = {}