re-organize import

Kohya S
2023-07-23 13:33:02 +09:00
parent d1864e2430
commit 50b53e183e
3 changed files with 19 additions and 15 deletions
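
The common thread across all three files: `StableDiffusionXLPipeline` is no longer imported at module scope but inside the functions that actually build a pipeline. A minimal sketch of the deferred-import pattern, using a hypothetical helper (the name and arguments are illustrative, not from the diff):

def save_pipeline_sketch(pipeline_components, output_dir, use_safetensors=False):
    # Deferred import: resolving the pipeline class only when a checkpoint is
    # actually saved keeps `import library.sdxl_model_util` cheap and drops the
    # hard module-level dependency on diffusers' SDXL pipeline.
    from diffusers import StableDiffusionXLPipeline

    pipeline = StableDiffusionXLPipeline(**pipeline_components)
    pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)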

View File

@@ -1,7 +1,7 @@
 import torch
 from safetensors.torch import load_file, save_file
 from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
-from diffusers import AutoencoderKL, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
 from library import model_util
 from library import sdxl_original_unet
@@ -486,6 +486,8 @@ def save_stable_diffusion_checkpoint(
 def save_diffusers_checkpoint(
     output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
 ):
+    from diffusers import StableDiffusionXLPipeline
+
     # convert U-Net
     unet_sd = unet.state_dict()
     du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)

View File

@@ -2,13 +2,10 @@ import argparse
 import gc
 import math
 import os
-from types import SimpleNamespace
-from typing import Any
+from typing import Optional
 import torch
 from tqdm import tqdm
 from transformers import CLIPTokenizer
-import open_clip
-from diffusers import StableDiffusionXLPipeline
 from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
 from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
@@ -18,7 +15,6 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
 DEFAULT_NOISE_OFFSET = 0.0357

 # TODO: separate checkpoint for each U-Net/Text Encoder/VAE
 def load_target_model(args, accelerator, model_version: str, weight_dtype):
     # load models for each process
     for pi in range(accelerator.state.num_processes):
@@ -33,7 +29,13 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
                 unet,
                 logit_scale,
                 ckpt_info,
-            ) = _load_target_model(args, model_version, weight_dtype, accelerator.device if args.lowram else "cpu")
+            ) = _load_target_model(
+                args.pretrained_model_name_or_path,
+                args.vae,
+                model_version,
+                weight_dtype,
+                accelerator.device if args.lowram else "cpu",
+            )

             # work on low-ram device
             if args.lowram:
@@ -51,8 +53,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
     return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info

-def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
-    name_or_path = args.pretrained_model_name_or_path
+def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu"):
     name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
     load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
@@ -68,6 +69,8 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
         ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)
     else:
         # Diffusers model is loaded to CPU
+        from diffusers import StableDiffusionXLPipeline
+
         variant = "fp16" if weight_dtype == torch.float16 else None
         print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
         try:
@@ -102,8 +105,8 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
         ckpt_info = None

     # load the additional VAE
-    if args.vae is not None:
-        vae = model_util.load_vae(args.vae, weight_dtype)
+    if vae_path is not None:
+        vae = model_util.load_vae(vae_path, weight_dtype)
         print("additional VAE loaded")

     return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
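
The signature change makes `_load_target_model` callable without an `argparse.Namespace`: callers pass the checkpoint path and optional VAE path explicitly, so the helper can be reused outside the training CLI. A hedged sketch of a direct call, assuming placeholder paths and a placeholder model-version string:

import torch
from library import sdxl_train_util

(
    load_stable_diffusion_format,
    text_encoder1,
    text_encoder2,
    vae,
    unet,
    logit_scale,
    ckpt_info,
) = sdxl_train_util._load_target_model(
    "models/sd_xl_base.safetensors",  # placeholder checkpoint path
    None,                             # vae_path: no separate VAE file
    "sdxl_base_v1-0",                 # placeholder model-version string
    torch.float16,
    device="cpu",
)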

View File

@@ -65,8 +65,8 @@ import safetensors.torch
 from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
 import library.model_util as model_util
 import library.huggingface_util as huggingface_util
-from library.attention_processors import FlashAttnProcessor
-from library.hypernetwork import replace_attentions_for_hypernetwork
+# from library.attention_processors import FlashAttnProcessor
+# from library.hypernetwork import replace_attentions_for_hypernetwork
 from library.original_unet import UNet2DConditionModel

 # Tokenizer: use the one provided in advance rather than loading it from the checkpoint
@@ -1884,8 +1884,7 @@ def load_latents_from_disk(
 ) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]:
     npz = np.load(npz_path)
     if "latents" not in npz:
-        print(f"error: npz is old format. please re-generate {npz_path}")
-        return None, None, None, None
+        raise ValueError(f"error: npz is old format. please re-generate {npz_path}")

     latents = npz["latents"]
     original_size = npz["original_size"].tolist()
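
Since `load_latents_from_disk` now raises `ValueError` on an old-format npz instead of returning `(None, None, None, None)`, callers that tested for the `None` sentinel must handle the exception. One way a caller might adapt, assuming a hypothetical `regenerate_npz` hook (the unpacked variable names are also assumptions):

try:
    latents, original_size, crop_info, flipped_latents = load_latents_from_disk(npz_path)
except ValueError:
    # old-format cache: re-generate it instead of silently receiving Nones
    regenerate_npz(npz_path)
    latents, original_size, crop_info, flipped_latents = load_latents_from_disk(npz_path)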