mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support Diffusers format
This commit is contained in:
@@ -1288,38 +1288,9 @@ def main(args):
|
|||||||
if len(files) == 1:
|
if len(files) == 1:
|
||||||
args.ckpt = files[0]
|
args.ckpt = files[0]
|
||||||
|
|
||||||
use_stable_diffusion_format = os.path.isfile(args.ckpt)
|
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
|
||||||
assert use_stable_diffusion_format, "Diffusers pretrained models are not supported yet"
|
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, dtype
|
||||||
print("load StableDiffusion checkpoint")
|
|
||||||
text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
|
||||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt, "cpu"
|
|
||||||
)
|
)
|
||||||
# else:
|
|
||||||
# print("load Diffusers pretrained models")
|
|
||||||
# TODO use Diffusers 0.18.1 and support SDXL pipeline
|
|
||||||
# raise NotImplementedError("Diffusers pretrained models are not supported yet")
|
|
||||||
# loading_pipe = StableDiffusionXLPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
|
|
||||||
# text_encoder = loading_pipe.text_encoder
|
|
||||||
# vae = loading_pipe.vae
|
|
||||||
# unet = loading_pipe.unet
|
|
||||||
# tokenizer = loading_pipe.tokenizer
|
|
||||||
# del loading_pipe
|
|
||||||
|
|
||||||
# # Diffusers U-Net to original U-Net
|
|
||||||
# original_unet = SdxlUNet2DConditionModel(
|
|
||||||
# unet.config.sample_size,
|
|
||||||
# unet.config.attention_head_dim,
|
|
||||||
# unet.config.cross_attention_dim,
|
|
||||||
# unet.config.use_linear_projection,
|
|
||||||
# unet.config.upcast_attention,
|
|
||||||
# )
|
|
||||||
# original_unet.load_state_dict(unet.state_dict())
|
|
||||||
# unet = original_unet
|
|
||||||
|
|
||||||
# VAEを読み込む
|
|
||||||
if args.vae is not None:
|
|
||||||
vae = model_util.load_vae(args.vae, dtype)
|
|
||||||
print("additional VAE loaded")
|
|
||||||
|
|
||||||
# xformers、Hypernetwork対応
|
# xformers、Hypernetwork対応
|
||||||
if not args.diffusers_xformers:
|
if not args.diffusers_xformers:
|
||||||
@@ -1329,7 +1300,6 @@ def main(args):
|
|||||||
|
|
||||||
# tokenizerを読み込む
|
# tokenizerを読み込む
|
||||||
print("loading tokenizer")
|
print("loading tokenizer")
|
||||||
if use_stable_diffusion_format:
|
|
||||||
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
||||||
|
|
||||||
# schedulerを用意する
|
# schedulerを用意する
|
||||||
|
|||||||
Reference in New Issue
Block a user