diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 41a05e95..a7f03afe 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -1,7 +1,7 @@ import torch from safetensors.torch import load_file, save_file from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer -from diffusers import AutoencoderKL, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel +from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel from library import model_util from library import sdxl_original_unet @@ -486,6 +486,8 @@ def save_stable_diffusion_checkpoint( def save_diffusers_checkpoint( output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None ): + from diffusers import StableDiffusionXLPipeline + # convert U-Net unet_sd = unet.state_dict() du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 34312afc..54774f88 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -2,13 +2,10 @@ import argparse import gc import math import os -from types import SimpleNamespace -from typing import Any +from typing import Optional import torch from tqdm import tqdm from transformers import CLIPTokenizer -import open_clip -from diffusers import StableDiffusionXLPipeline from library import model_util, sdxl_model_util, train_util, sdxl_original_unet from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline @@ -18,7 +15,6 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" DEFAULT_NOISE_OFFSET = 0.0357 -# TODO: separate checkpoint for each U-Net/Text Encoder/VAE def load_target_model(args, accelerator, model_version: str, weight_dtype): # load models for each process for pi in range(accelerator.state.num_processes): @@ -33,7 +29,13 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): unet, logit_scale, ckpt_info, - ) = _load_target_model(args, model_version, weight_dtype, accelerator.device if args.lowram else "cpu") + ) = _load_target_model( + args.pretrained_model_name_or_path, + args.vae, + model_version, + weight_dtype, + accelerator.device if args.lowram else "cpu", + ) # work on low-ram device if args.lowram: @@ -51,8 +53,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info -def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"): - name_or_path = args.pretrained_model_name_or_path +def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu"): name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers @@ -68,6 +69,8 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device) else: # Diffusers model is loaded to CPU + from diffusers import StableDiffusionXLPipeline + variant = "fp16" if weight_dtype == torch.float16 else None print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") try: @@ -102,8 +105,8 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp ckpt_info = None # VAEを読み込む - if args.vae is not None: - vae = model_util.load_vae(args.vae, weight_dtype) + if vae_path is not None: + vae = model_util.load_vae(vae_path, weight_dtype) print("additional VAE loaded") return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info diff --git a/library/train_util.py b/library/train_util.py index ce4b5959..b918e56b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -65,8 +65,8 @@ import safetensors.torch from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline import library.model_util as model_util import library.huggingface_util as huggingface_util -from library.attention_processors import FlashAttnProcessor -from library.hypernetwork import replace_attentions_for_hypernetwork +# from library.attention_processors import FlashAttnProcessor +# from library.hypernetwork import replace_attentions_for_hypernetwork from library.original_unet import UNet2DConditionModel # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う @@ -1884,8 +1884,7 @@ def load_latents_from_disk( ) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]: npz = np.load(npz_path) if "latents" not in npz: - print(f"error: npz is old format. please re-generate {npz_path}") - return None, None, None, None + raise ValueError(f"error: npz is old format. please re-generate {npz_path}") latents = npz["latents"] original_size = npz["original_size"].tolist()