Merge branch 'sdxl' into sdxl

Isotr0py
2023-07-28 10:19:37 +08:00
committed by GitHub
9 changed files with 282 additions and 289 deletions

View File

@@ -3,7 +3,7 @@ from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors.torch import load_file, save_file
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import AutoencoderKL, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from library import model_util
from library import sdxl_original_unet
@@ -492,6 +492,8 @@ def save_stable_diffusion_checkpoint(
def save_diffusers_checkpoint(
output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
):
from diffusers import StableDiffusionXLPipeline
# convert U-Net
unet_sd = unet.state_dict()
du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
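The hunks above move the StableDiffusionXLPipeline import out of the module-level imports and into save_diffusers_checkpoint. A minimal sketch of that deferred-import pattern, assuming the intent is that importing this module should not require a diffusers release that ships the SDXL pipeline; build_sdxl_pipeline is a hypothetical helper, not code from this commit:

def build_sdxl_pipeline(text_encoder1, text_encoder2, tokenizer1, tokenizer2, unet, vae, scheduler):
    # imported only when a Diffusers-format pipeline is actually constructed,
    # so importing the module itself works on older diffusers versions
    from diffusers import StableDiffusionXLPipeline

    return StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder1,
        text_encoder_2=text_encoder2,
        tokenizer=tokenizer1,
        tokenizer_2=tokenizer2,
        unet=unet,
        scheduler=scheduler,
    )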

View File

@@ -2,15 +2,12 @@ import argparse
import gc
import math
import os
from types import SimpleNamespace
from typing import Any
from typing import Optional
import torch
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from tqdm import tqdm
from transformers import CLIPTokenizer
import open_clip
from diffusers import StableDiffusionXLPipeline
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
@@ -20,7 +17,6 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
DEFAULT_NOISE_OFFSET = 0.0357
# TODO: separate checkpoint for each U-Net/Text Encoder/VAE
def load_target_model(args, accelerator, model_version: str, weight_dtype):
# load models for each process
for pi in range(accelerator.state.num_processes):
@@ -35,7 +31,13 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
unet,
logit_scale,
ckpt_info,
) = _load_target_model(args, model_version, weight_dtype, accelerator.device if args.lowram else "cpu")
) = _load_target_model(
args.pretrained_model_name_or_path,
args.vae,
model_version,
weight_dtype,
accelerator.device if args.lowram else "cpu",
)
# work on low-ram device
if args.lowram:
@@ -53,9 +55,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
# TODO: integrate full fp16/bf16 to model loading
name_or_path = args.pretrained_model_name_or_path
def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu"):
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
@@ -71,6 +71,8 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, weight_dtype)
else:
# Diffusers model is loaded to CPU
from diffusers import StableDiffusionXLPipeline
variant = "fp16" if weight_dtype == torch.float16 else None
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
try:
@@ -106,8 +108,8 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
ckpt_info = None
# load the VAE
if args.vae is not None:
vae = model_util.load_vae(args.vae, weight_dtype)
if vae_path is not None:
vae = model_util.load_vae(vae_path, weight_dtype)
print("additional VAE loaded")
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
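The refactor above replaces the argparse.Namespace parameter of _load_target_model with explicit name_or_path and vae_path arguments. A hypothetical standalone call under the new signature (not part of the diff; the model-version constant name is an assumption):

import torch
from library import sdxl_model_util, sdxl_train_util

(
    load_sd_format,
    text_encoder1,
    text_encoder2,
    vae,
    unet,
    logit_scale,
    ckpt_info,
) = sdxl_train_util._load_target_model(
    "stabilityai/stable-diffusion-xl-base-1.0",    # checkpoint file or Diffusers model
    None,                                          # optional separate VAE path
    sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9,  # assumed constant name
    torch.float16,
    device="cpu",
)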
@@ -287,54 +289,6 @@ def save_sd_model_on_epoch_end_or_stepwise(
)
# cache the Text Encoder outputs
# if weight_dtype is specified, the Text Encoders themselves and their outputs are cast to weight_dtype
def cache_text_encoder_outputs(args, accelerator, tokenizers, text_encoders, dataset, weight_dtype):
print("caching text encoder outputs")
tokenizer1, tokenizer2 = tokenizers
text_encoder1, text_encoder2 = text_encoders
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
if weight_dtype is not None:
text_encoder1.to(dtype=weight_dtype)
text_encoder2.to(dtype=weight_dtype)
text_encoder1_cache = {}
text_encoder2_cache = {}
for batch in tqdm(dataset):
input_ids1_batch = batch["input_ids"].to(accelerator.device)
input_ids2_batch = batch["input_ids2"].to(accelerator.device)
# split batch to avoid OOM
# TODO specify batch size by args
for input_id1, input_id2 in zip(input_ids1_batch.split(1), input_ids2_batch.split(1)):
# remove input_ids already in cache
input_id1_cache_key = tuple(input_id1.flatten().tolist())
input_id2_cache_key = tuple(input_id2.flatten().tolist())
if input_id1_cache_key in text_encoder1_cache:
assert input_id2_cache_key in text_encoder2_cache
continue
with torch.no_grad():
encoder_hidden_states1, encoder_hidden_states2, pool2 = get_hidden_states(
args,
input_id1,
input_id2,
tokenizer1,
tokenizer2,
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
)
encoder_hidden_states1 = encoder_hidden_states1.detach().to("cpu").squeeze(0) # n*75+2,768
encoder_hidden_states2 = encoder_hidden_states2.detach().to("cpu").squeeze(0) # n*75+2,1280
pool2 = pool2.detach().to("cpu").squeeze(0) # 1280
text_encoder1_cache[input_id1_cache_key] = encoder_hidden_states1
text_encoder2_cache[input_id2_cache_key] = (encoder_hidden_states2, pool2)
return text_encoder1_cache, text_encoder2_cache
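The removed helper above builds two dictionaries keyed by the flattened token-id tuple of each prompt. A small sketch of how such caches would be consulted in place of a Text Encoder forward pass; get_cached_states is hypothetical and not part of the repository:

def get_cached_states(input_id1, input_id2, text_encoder1_cache, text_encoder2_cache):
    # keys are the flattened token-id tuples, exactly as used when the caches were built
    key1 = tuple(input_id1.flatten().tolist())
    key2 = tuple(input_id2.flatten().tolist())
    hidden_states1 = text_encoder1_cache[key1]          # (n*75+2, 768)
    hidden_states2, pool2 = text_encoder2_cache[key2]   # (n*75+2, 1280) and (1280,)
    # restore the batch dimension that was squeezed out before caching
    return hidden_states1.unsqueeze(0), hidden_states2.unsqueeze(0), pool2.unsqueeze(0)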
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"

View File

@@ -65,8 +65,8 @@ import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import library.model_util as model_util
import library.huggingface_util as huggingface_util
from library.attention_processors import FlashAttnProcessor
from library.hypernetwork import replace_attentions_for_hypernetwork
# from library.attention_processors import FlashAttnProcessor
# from library.hypernetwork import replace_attentions_for_hypernetwork
from library.original_unet import UNet2DConditionModel
# Tokenizer: use the pre-provided tokenizer instead of loading it from the checkpoint
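The hunk above disables the FlashAttnProcessor and hypernetwork imports by commenting them out. An alternative sketch, not what this diff does, that keeps the helpers importable where available by guarding the import instead:

try:
    from library.attention_processors import FlashAttnProcessor
    from library.hypernetwork import replace_attentions_for_hypernetwork
except ImportError:
    # the optional modules are absent; the related features simply stay disabled
    FlashAttnProcessor = None
    replace_attentions_for_hypernetwork = None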
@@ -1884,8 +1884,7 @@ def load_latents_from_disk(
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]:
npz = np.load(npz_path)
if "latents" not in npz:
print(f"error: npz is old format. please re-generate {npz_path}")
return None, None, None, None
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
latents = npz["latents"]
original_size = npz["original_size"].tolist()
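With the change above, load_latents_from_disk raises a ValueError for old-format npz caches instead of returning four Nones. A hypothetical caller-side adaptation (try_load_latents is not part of the repository):

def try_load_latents(npz_path):
    try:
        # load_latents_from_disk is the function shown above
        return load_latents_from_disk(npz_path)
    except ValueError:
        # old-format cache: signal the caller to re-encode the latents from the image
        return None, None, None, None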