mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
re-organize import
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
from diffusers import AutoencoderKL, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
|
||||
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
|
||||
from library import model_util
|
||||
from library import sdxl_original_unet
|
||||
|
||||
@@ -486,6 +486,8 @@ def save_stable_diffusion_checkpoint(
|
||||
def save_diffusers_checkpoint(
|
||||
output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
|
||||
):
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
# convert U-Net
|
||||
unet_sd = unet.state_dict()
|
||||
du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
|
||||
|
||||
@@ -2,13 +2,10 @@ import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer
|
||||
import open_clip
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
||||
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
||||
|
||||
@@ -18,7 +15,6 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
DEFAULT_NOISE_OFFSET = 0.0357
|
||||
|
||||
|
||||
# TODO: separate checkpoint for each U-Net/Text Encoder/VAE
|
||||
def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
# load models for each process
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
@@ -33,7 +29,13 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
unet,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = _load_target_model(args, model_version, weight_dtype, accelerator.device if args.lowram else "cpu")
|
||||
) = _load_target_model(
|
||||
args.pretrained_model_name_or_path,
|
||||
args.vae,
|
||||
model_version,
|
||||
weight_dtype,
|
||||
accelerator.device if args.lowram else "cpu",
|
||||
)
|
||||
|
||||
# work on low-ram device
|
||||
if args.lowram:
|
||||
@@ -51,8 +53,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||
|
||||
|
||||
def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
|
||||
name_or_path = args.pretrained_model_name_or_path
|
||||
def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu"):
|
||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||
|
||||
@@ -68,6 +69,8 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
|
||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)
|
||||
else:
|
||||
# Diffusers model is loaded to CPU
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
variant = "fp16" if weight_dtype == torch.float16 else None
|
||||
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
|
||||
try:
|
||||
@@ -102,8 +105,8 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
|
||||
ckpt_info = None
|
||||
|
||||
# VAEを読み込む
|
||||
if args.vae is not None:
|
||||
vae = model_util.load_vae(args.vae, weight_dtype)
|
||||
if vae_path is not None:
|
||||
vae = model_util.load_vae(vae_path, weight_dtype)
|
||||
print("additional VAE loaded")
|
||||
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||
|
||||
@@ -65,8 +65,8 @@ import safetensors.torch
|
||||
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||
import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
from library.attention_processors import FlashAttnProcessor
|
||||
from library.hypernetwork import replace_attentions_for_hypernetwork
|
||||
# from library.attention_processors import FlashAttnProcessor
|
||||
# from library.hypernetwork import replace_attentions_for_hypernetwork
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
|
||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||
@@ -1884,8 +1884,7 @@ def load_latents_from_disk(
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]:
|
||||
npz = np.load(npz_path)
|
||||
if "latents" not in npz:
|
||||
print(f"error: npz is old format. please re-generate {npz_path}")
|
||||
return None, None, None, None
|
||||
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
|
||||
|
||||
latents = npz["latents"]
|
||||
original_size = npz["original_size"].tolist()
|
||||
|
||||
Reference in New Issue
Block a user