mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support Diffusers format
This commit is contained in:
@@ -1288,38 +1288,9 @@ def main(args):
|
||||
if len(files) == 1:
|
||||
args.ckpt = files[0]
|
||||
|
||||
use_stable_diffusion_format = os.path.isfile(args.ckpt)
|
||||
assert use_stable_diffusion_format, "Diffusers pretrained models are not supported yet"
|
||||
print("load StableDiffusion checkpoint")
|
||||
text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt, "cpu"
|
||||
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
|
||||
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, dtype
|
||||
)
|
||||
# else:
|
||||
# print("load Diffusers pretrained models")
|
||||
# TODO use Diffusers 0.18.1 and support SDXL pipeline
|
||||
# raise NotImplementedError("Diffusers pretrained models are not supported yet")
|
||||
# loading_pipe = StableDiffusionXLPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
|
||||
# text_encoder = loading_pipe.text_encoder
|
||||
# vae = loading_pipe.vae
|
||||
# unet = loading_pipe.unet
|
||||
# tokenizer = loading_pipe.tokenizer
|
||||
# del loading_pipe
|
||||
|
||||
# # Diffusers U-Net to original U-Net
|
||||
# original_unet = SdxlUNet2DConditionModel(
|
||||
# unet.config.sample_size,
|
||||
# unet.config.attention_head_dim,
|
||||
# unet.config.cross_attention_dim,
|
||||
# unet.config.use_linear_projection,
|
||||
# unet.config.upcast_attention,
|
||||
# )
|
||||
# original_unet.load_state_dict(unet.state_dict())
|
||||
# unet = original_unet
|
||||
|
||||
# VAEを読み込む
|
||||
if args.vae is not None:
|
||||
vae = model_util.load_vae(args.vae, dtype)
|
||||
print("additional VAE loaded")
|
||||
|
||||
# xformers、Hypernetwork対応
|
||||
if not args.diffusers_xformers:
|
||||
@@ -1329,7 +1300,6 @@ def main(args):
|
||||
|
||||
# tokenizerを読み込む
|
||||
print("loading tokenizer")
|
||||
if use_stable_diffusion_format:
|
||||
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
||||
|
||||
# schedulerを用意する
|
||||
|
||||
Reference in New Issue
Block a user