mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'sdxl' into sdxl
This commit is contained in:
@@ -3,7 +3,7 @@ from accelerate import init_empty_weights
|
||||
from accelerate.utils.modeling import set_module_tensor_to_device
|
||||
from safetensors.torch import load_file, save_file
|
||||
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
from diffusers import AutoencoderKL, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
|
||||
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
|
||||
from library import model_util
|
||||
from library import sdxl_original_unet
|
||||
|
||||
@@ -492,6 +492,8 @@ def save_stable_diffusion_checkpoint(
|
||||
def save_diffusers_checkpoint(
|
||||
output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
|
||||
):
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
# convert U-Net
|
||||
unet_sd = unet.state_dict()
|
||||
du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
|
||||
|
||||
@@ -2,15 +2,12 @@ import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils.modeling import set_module_tensor_to_device
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer
|
||||
import open_clip
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
||||
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
||||
|
||||
@@ -20,7 +17,6 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
DEFAULT_NOISE_OFFSET = 0.0357
|
||||
|
||||
|
||||
# TODO: separate checkpoint for each U-Net/Text Encoder/VAE
|
||||
def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
# load models for each process
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
@@ -35,7 +31,13 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
unet,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = _load_target_model(args, model_version, weight_dtype, accelerator.device if args.lowram else "cpu")
|
||||
) = _load_target_model(
|
||||
args.pretrained_model_name_or_path,
|
||||
args.vae,
|
||||
model_version,
|
||||
weight_dtype,
|
||||
accelerator.device if args.lowram else "cpu",
|
||||
)
|
||||
|
||||
# work on low-ram device
|
||||
if args.lowram:
|
||||
@@ -53,9 +55,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||
|
||||
|
||||
def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
|
||||
# TODO: integrate full fp16/bf16 to model loading
|
||||
name_or_path = args.pretrained_model_name_or_path
|
||||
def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu"):
|
||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||
|
||||
@@ -71,6 +71,8 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
|
||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, weight_dtype)
|
||||
else:
|
||||
# Diffusers model is loaded to CPU
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
variant = "fp16" if weight_dtype == torch.float16 else None
|
||||
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
|
||||
try:
|
||||
@@ -106,8 +108,8 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
|
||||
ckpt_info = None
|
||||
|
||||
# VAEを読み込む
|
||||
if args.vae is not None:
|
||||
vae = model_util.load_vae(args.vae, weight_dtype)
|
||||
if vae_path is not None:
|
||||
vae = model_util.load_vae(vae_path, weight_dtype)
|
||||
print("additional VAE loaded")
|
||||
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||
@@ -287,54 +289,6 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
)
|
||||
|
||||
|
||||
# TextEncoderの出力をキャッシュする
|
||||
# weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる
|
||||
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dataset, weight_dtype):
|
||||
print("caching text encoder outputs")
|
||||
|
||||
tokenizer1, tokenizer2 = tokenizers
|
||||
text_encoder1, text_encoder2 = text_encoders
|
||||
text_encoder1.to(accelerator.device)
|
||||
text_encoder2.to(accelerator.device)
|
||||
if weight_dtype is not None:
|
||||
text_encoder1.to(dtype=weight_dtype)
|
||||
text_encoder2.to(dtype=weight_dtype)
|
||||
|
||||
text_encoder1_cache = {}
|
||||
text_encoder2_cache = {}
|
||||
for batch in tqdm(dataset):
|
||||
input_ids1_batch = batch["input_ids"].to(accelerator.device)
|
||||
input_ids2_batch = batch["input_ids2"].to(accelerator.device)
|
||||
|
||||
# split batch to avoid OOM
|
||||
# TODO specify batch size by args
|
||||
for input_id1, input_id2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
|
||||
# remove input_ids already in cache
|
||||
input_id1_cache_key = tuple(input_id1.flatten().tolist())
|
||||
input_id2_cache_key = tuple(input_id2.flatten().tolist())
|
||||
if input_id1_cache_key in text_encoder1_cache:
|
||||
assert input_id2_cache_key in text_encoder2_cache
|
||||
continue
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = get_hidden_states(
|
||||
args,
|
||||
input_id1,
|
||||
input_id2,
|
||||
tokenizer1,
|
||||
tokenizer2,
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
)
|
||||
encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu").squeeze(0) # n*75+2,768
|
||||
encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu").squeeze(0) # n*75+2,1280
|
||||
pool2 = pool2.detach().to("cpu").squeeze(0) # 1280
|
||||
text_encoder1_cache[input_id1_cache_key] = encoder_hidden_states1
|
||||
text_encoder2_cache[input_id2_cache_key] = (encoder_hidden_states2, pool2)
|
||||
return text_encoder1_cache, text_encoder2_cache
|
||||
|
||||
|
||||
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
||||
|
||||
@@ -65,8 +65,8 @@ import safetensors.torch
|
||||
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
|
||||
import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
from library.attention_processors import FlashAttnProcessor
|
||||
from library.hypernetwork import replace_attentions_for_hypernetwork
|
||||
# from library.attention_processors import FlashAttnProcessor
|
||||
# from library.hypernetwork import replace_attentions_for_hypernetwork
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
|
||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||
@@ -1884,8 +1884,7 @@ def load_latents_from_disk(
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]:
|
||||
npz = np.load(npz_path)
|
||||
if "latents" not in npz:
|
||||
print(f"error: npz is old format. please re-generate {npz_path}")
|
||||
return None, None, None, None
|
||||
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
|
||||
|
||||
latents = npz["latents"]
|
||||
original_size = npz["original_size"].tolist()
|
||||
|
||||
Reference in New Issue
Block a user