mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 00:17:18 +00:00
505 lines
18 KiB
Python
505 lines
18 KiB
Python
import argparse
|
|
import os
|
|
import time
|
|
from typing import List
|
|
import numpy as np
|
|
|
|
import torch
|
|
import torchvision
|
|
from safetensors.torch import load_file, save_file
|
|
from tqdm import tqdm
|
|
from transformers import CLIPTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
|
|
from accelerate import init_empty_weights
|
|
|
|
from library import stable_cascade as sc
|
|
from library.train_util import (
|
|
ImageInfo,
|
|
load_image,
|
|
trim_and_resize_if_required,
|
|
save_latents_to_disk,
|
|
HIGH_VRAM,
|
|
save_text_encoder_outputs_to_disk,
|
|
)
|
|
from library.sdxl_model_util import _load_state_dict_on_device
|
|
from library.device_utils import clean_memory_on_device
|
|
from library.train_util import save_sd_model_on_epoch_end_or_stepwise_common, save_sd_model_on_train_end_common
|
|
from library import sai_model_spec
|
|
|
|
|
|
from library.utils import setup_logging
|
|
|
|
setup_logging()
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Hugging Face model id used both for the CLIP tokenizer and (when no local checkpoint is given) the text encoder.
CLIP_TEXT_MODEL_NAME: str = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
|
|
|
# Per-channel normalization statistics applied before EfficientNet encoding
# (these match the commonly used ImageNet mean/std values).
_EFFNET_MEAN = (0.485, 0.456, 0.406)
_EFFNET_STD = (0.229, 0.224, 0.225)

# Image -> normalized float tensor, the input format expected by the EfficientNet encoder.
EFFNET_PREPROCESS = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=_EFFNET_MEAN, std=_EFFNET_STD),
    ]
)
|
|
|
|
# Suffixes appended to an image path (with its extension stripped) to name the on-disk cache files.
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_sc_te_outputs.npz"  # cached text encoder outputs
LATENTS_CACHE_SUFFIX = "_sc_latents.npz"  # cached EfficientNet latents
|
|
|
|
|
|
def load_effnet(effnet_checkpoint_path, loading_device="cpu") -> sc.EfficientNetEncoder:
    """Build an EfficientNet encoder and load its weights from a safetensors file.

    Args:
        effnet_checkpoint_path: path to the safetensors checkpoint. If it contains a
            "state_dict" entry, that entry is used as the state dict.
        loading_device: currently unused; the checkpoint is read with load_file's default device.

    Returns:
        The encoder with weights loaded (the load_state_dict result is logged).
    """
    logger.info(f"Loading EfficientNet encoder from {effnet_checkpoint_path}")
    encoder = sc.EfficientNetEncoder()

    checkpoint = load_file(effnet_checkpoint_path)
    weights = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    load_result = encoder.load_state_dict(weights)
    logger.info(load_result)

    # drop our references to the loaded tensors as soon as they have been copied into the model
    del checkpoint, weights
    return encoder
|
|
|
|
|
|
def load_tokenizer(args: argparse.Namespace):
    """Load the CLIP tokenizer, using ``args.tokenizer_cache_dir`` as a local cache when set.

    If the cache directory is set and already holds the tokenizer, it is loaded from there.
    Otherwise the tokenizer is loaded from CLIP_TEXT_MODEL_NAME and, when a cache
    directory is set, saved into it for next time.

    Returns:
        The loaded CLIPTokenizer.
    """
    # TODO commonize with sdxl_train_util.load_tokenizers
    logger.info("prepare tokenizers")

    cache_dir = args.tokenizer_cache_dir
    loaded = []
    for model_name in [CLIP_TEXT_MODEL_NAME]:
        cached_path = os.path.join(cache_dir, model_name.replace("/", "_")) if cache_dir else None

        if cached_path is not None and os.path.exists(cached_path):
            logger.info(f"load tokenizer from cache: {cached_path}")
            tokenizer = CLIPTokenizer.from_pretrained(cached_path)
        else:
            tokenizer = CLIPTokenizer.from_pretrained(model_name)
            if cached_path is not None:
                logger.info(f"save Tokenizer to cache: {cached_path}")
                tokenizer.save_pretrained(cached_path)

        loaded.append(tokenizer)

    # only logged here; the tokenizer itself is not modified
    max_token_length = getattr(args, "max_token_length", None)
    if max_token_length is not None:
        logger.info(f"update token length: {max_token_length}")

    return loaded[0]
|
|
|
|
|
|
def load_stage_c_model(stage_c_checkpoint_path, dtype=None, device="cpu") -> sc.StageC:
    """Create a Stage C generator without allocating weights, then load a safetensors checkpoint into it.

    Args:
        stage_c_checkpoint_path: path to the Stage C safetensors file.
        dtype: optional dtype forwarded to _load_state_dict_on_device.
        device: device forwarded to _load_state_dict_on_device.

    Returns:
        The Stage C model (the load result is logged).
    """
    logger.info("Instantiating Stage C generator")
    # empty (meta) weights: the real tensors are materialized by _load_state_dict_on_device
    with init_empty_weights():
        model = sc.StageC()

    logger.info(f"Loading Stage C generator from {stage_c_checkpoint_path}")
    weights = load_file(stage_c_checkpoint_path)

    logger.info("Loading state dict")
    load_result = _load_state_dict_on_device(model, weights, device, dtype=dtype)
    logger.info(load_result)

    return model
|
|
|
|
|
|
def load_stage_b_model(stage_b_checkpoint_path, dtype=None, device="cpu") -> sc.StageB:
    """Create a Stage B generator without allocating weights, then load a safetensors checkpoint into it.

    Args:
        stage_b_checkpoint_path: path to the Stage B safetensors file.
        dtype: optional dtype forwarded to _load_state_dict_on_device.
        device: device forwarded to _load_state_dict_on_device.

    Returns:
        The Stage B model (the load result is logged).
    """
    logger.info("Instantiating Stage B generator")
    # empty (meta) weights: the real tensors are materialized by _load_state_dict_on_device
    with init_empty_weights():
        model = sc.StageB()

    logger.info(f"Loading Stage B generator from {stage_b_checkpoint_path}")
    weights = load_file(stage_b_checkpoint_path)

    logger.info("Loading state dict")
    load_result = _load_state_dict_on_device(model, weights, device, dtype=dtype)
    logger.info(load_result)

    return model
|
|
|
|
|
|
def load_clip_text_model(text_model_checkpoint_path, dtype=None, device="cpu", save_text_model=False):
    """Load the CLIP text encoder (CLIPTextModelWithProjection).

    If ``save_text_model`` is True or no checkpoint path is given, the model is downloaded via
    ``from_pretrained(CLIP_TEXT_MODEL_NAME)``. When ``save_text_model`` is True, the downloaded weights
    are also written, unconverted, to ``text_model_checkpoint_path``. Otherwise the model is created
    with empty weights from a fixed config and the checkpoint is loaded with _load_state_dict_on_device.

    In both paths the returned model is placed on ``device`` and, if given, cast to ``dtype``.

    Args:
        text_model_checkpoint_path: safetensors path to load from (or to save to when save_text_model).
        dtype: optional target dtype.
        device: target device.
        save_text_model: download the model and save its weights to text_model_checkpoint_path.

    Returns:
        The CLIPTextModelWithProjection instance.

    Raises:
        ValueError: if save_text_model is True but text_model_checkpoint_path is None.
    """
    # fail fast instead of downloading the model and then crashing inside save_file(..., None)
    if save_text_model and text_model_checkpoint_path is None:
        raise ValueError("save_text_model requires text_model_checkpoint_path to be specified")

    # CLIP encoders
    logger.info("Loading CLIP text model")
    if save_text_model or text_model_checkpoint_path is None:
        logger.info(f"Loading CLIP text model from {CLIP_TEXT_MODEL_NAME}")
        text_model = CLIPTextModelWithProjection.from_pretrained(CLIP_TEXT_MODEL_NAME)

        if save_text_model:
            # save the weights exactly as downloaded, before any dtype/device conversion
            sd = text_model.state_dict()
            logger.info(f"Saving CLIP text model to {text_model_checkpoint_path}")
            save_file(sd, text_model_checkpoint_path)

        # honor dtype/device the same way the checkpoint branch below does
        if dtype is not None:
            text_model.to(dtype=dtype)
        text_model.to(device)
    else:
        logger.info(f"Loading CLIP text model from {text_model_checkpoint_path}")

        # copy from sdxl_model_util.py
        text_model2_cfg = CLIPTextConfig(
            vocab_size=49408,
            hidden_size=1280,
            intermediate_size=5120,
            num_hidden_layers=32,
            num_attention_heads=20,
            max_position_embeddings=77,
            hidden_act="gelu",
            layer_norm_eps=1e-05,
            dropout=0.0,
            attention_dropout=0.0,
            initializer_range=0.02,
            initializer_factor=1.0,
            pad_token_id=1,
            bos_token_id=0,
            eos_token_id=2,
            model_type="clip_text_model",
            projection_dim=1280,
            # torch_dtype="float32",
            # transformers_version="4.25.0.dev0",
        )
        # empty (meta) weights: the real tensors are materialized by _load_state_dict_on_device
        with init_empty_weights():
            text_model = CLIPTextModelWithProjection(text_model2_cfg)

        text_model_checkpoint = load_file(text_model_checkpoint_path)
        info = _load_state_dict_on_device(text_model, text_model_checkpoint, device, dtype=dtype)
        logger.info(info)

    return text_model
|
|
|
|
|
|
def load_stage_a_model(stage_a_checkpoint_path, dtype=None, device="cpu") -> sc.StageA:
    """Build the Stage A VQGAN on ``device`` and load weights from a safetensors file.

    Args:
        stage_a_checkpoint_path: path to the checkpoint. If it contains a "state_dict" entry,
            that entry is used as the state dict.
        dtype: currently unused; the model keeps its default dtype.
        device: device the model is moved to before loading.

    Returns:
        The Stage A model (the load_state_dict result is logged).
    """
    logger.info(f"Loading Stage A vqGAN from {stage_a_checkpoint_path}")
    vqgan = sc.StageA()
    vqgan = vqgan.to(device)

    checkpoint = load_file(stage_a_checkpoint_path)
    weights = checkpoint.get("state_dict", checkpoint)
    load_result = vqgan.load_state_dict(weights)
    logger.info(load_result)

    return vqgan
|
|
|
|
|
|
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
    """Return True if ``npz_path`` holds a usable latents cache for bucket resolution ``reso``.

    The cache is usable when the file exists; contains "latents", "original_size" and "crop_ltrb"
    (files without these keys are treated as an older format); and has latents whose axes 1-2 equal
    (height // 32, width // 32). When ``flip_aug`` is True, "latents_flipped" must also be present
    with the same spatial size.

    Args:
        reso: bucket resolution as (width, height). Note it is W x H, while the latent axes are H, W.
        npz_path: path to the cache file.
        flip_aug: whether flipped latents are required.
    """
    # bucket reso is WxH, latents are indexed as (..., H, W)
    expected_latents_size = (reso[1] // 32, reso[0] // 32)

    if not os.path.exists(npz_path):
        return False

    # use a context manager so the underlying file handle is closed (np.load on .npz keeps it open)
    with np.load(npz_path) as npz:
        if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz:  # old ver?
            return False
        if npz["latents"].shape[1:3] != expected_latents_size:
            return False

        if flip_aug:
            if "latents_flipped" not in npz:
                return False
            if npz["latents_flipped"].shape[1:3] != expected_latents_size:
                return False

    return True
|
|
|
|
|
|
def cache_batch_latents(
    effnet: sc.EfficientNetEncoder,
    cache_to_disk: bool,
    image_infos: List[ImageInfo],
    flip_aug: bool,
    random_crop: bool,
    device,
    dtype,
) -> None:
    r"""
    Encode a batch of images with the EfficientNet encoder and cache the resulting latents.

    requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
    optionally requires image_infos to have: image (used instead of loading from absolute_path)
    if cache_to_disk is True, latents are written to info.latents_npz
        flipped latents are also saved if flip_aug is True
    if cache_to_disk is False, set info.latents
        info.latents_flipped is also set if flip_aug is True
    info.latents_original_size and info.latents_crop_ltrb are always set

    Args:
        effnet: encoder applied to the preprocessed image batch.
        cache_to_disk: save to disk instead of storing on the ImageInfo objects.
        image_infos: images of one batch; they are stacked, so they must share the same bucket size.
        flip_aug: also encode the horizontally flipped images.
        random_crop: forwarded to trim_and_resize_if_required.
        device: device the image batch is moved to before encoding.
        dtype: dtype the image batch is cast to before encoding.

    Raises:
        RuntimeError: if an image's latents (or flipped latents) contain NaN.
    """
    images = []
    for info in image_infos:
        image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
        # TODO: image metadata can be broken, so the bucket assigned from the metadata may not match
        # the actual image size; a check should be added here.
        image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
        images.append(EFFNET_PREPROCESS(image))

        info.latents_original_size = original_size
        info.latents_crop_ltrb = crop_ltrb

    img_tensors = torch.stack(images, dim=0).to(device=device, dtype=dtype)

    with torch.no_grad():
        latents = effnet(img_tensors).to("cpu")
        if flip_aug:
            # ToTensor yields CHW, so the stacked batch is NCHW and dims=[3] flips the width axis
            flipped_latents = effnet(torch.flip(img_tensors, dims=[3])).to("cpu")
        else:
            flipped_latents = [None] * len(latents)

    for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
        # check NaN per image so the error names the image that actually produced it
        if torch.isnan(latent).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
            raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")

        if cache_to_disk:
            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
        else:
            info.latents = latent
            if flip_aug:
                info.latents_flipped = flipped_latent

    if not HIGH_VRAM:
        clean_memory_on_device(device)
|
|
|
|
|
|
def cache_batch_text_encoder_outputs(image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids, dtype):
    """Run the first text encoder on a batch of token ids and cache the outputs per image.

    Outputs come from sc.get_clip_conditions and are moved to CPU. With ``cache_to_disk`` they are
    written to ``info.text_encoder_outputs_npz``; otherwise they are stored as
    ``info.text_encoder_outputs1`` (hidden state) and ``info.text_encoder_pool2`` (pooled output).

    ``max_token_length`` and ``dtype`` are accepted for interface compatibility but not used here.
    """
    # prompts longer than 75 tokens are not supported
    encoder = text_encoders[0]
    ids_on_device = input_ids.to(encoder.device)

    with torch.no_grad():
        hidden_states, pooled_outputs = sc.get_clip_conditions(None, ids_on_device, tokenizers[0], encoder)

    hidden_states = hidden_states.detach().to("cpu")
    pooled_outputs = pooled_outputs.detach().to("cpu")

    for info, hidden_state, pool in zip(image_infos, hidden_states, pooled_outputs):
        if not cache_to_disk:
            info.text_encoder_outputs1 = hidden_state
            info.text_encoder_pool2 = pool
            continue
        save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, None, hidden_state, pool)
|
|
|
|
|
|
def add_effnet_arguments(parser):
    """Register the required ``--effnet_checkpoint_path`` option on ``parser`` and return the parser."""
    option_spec = dict(
        type=str,
        required=True,
        help="path to EfficientNet checkpoint / EfficientNetのチェックポイントのパス",
    )
    parser.add_argument("--effnet_checkpoint_path", **option_spec)
    return parser
|
|
|
|
|
|
def add_text_model_arguments(parser):
    """Register the CLIP text model options on ``parser`` and return the parser.

    Adds the required ``--text_model_checkpoint_path`` and the ``--save_text_model`` flag.
    """
    checkpoint_spec = dict(
        type=str,
        required=True,
        help="path to CLIP text model checkpoint / CLIPテキストモデルのチェックポイントのパス",
    )
    parser.add_argument("--text_model_checkpoint_path", **checkpoint_spec)
    parser.add_argument(
        "--save_text_model",
        action="store_true",
        help="if specified, save text model to corresponding path",
    )
    return parser
|
|
|
|
|
|
def add_stage_a_arguments(parser):
    """Register the required ``--stage_a_checkpoint_path`` option on ``parser`` and return the parser."""
    option_spec = dict(
        type=str,
        required=True,
        help="path to Stage A checkpoint / Stage Aのチェックポイントのパス",
    )
    parser.add_argument("--stage_a_checkpoint_path", **option_spec)
    return parser
|
|
|
|
|
|
def add_stage_b_arguments(parser):
    """Register the required ``--stage_b_checkpoint_path`` option on ``parser`` and return the parser."""
    option_spec = dict(
        type=str,
        required=True,
        help="path to Stage B checkpoint / Stage Bのチェックポイントのパス",
    )
    parser.add_argument("--stage_b_checkpoint_path", **option_spec)
    return parser
|
|
|
|
|
|
def add_stage_c_arguments(parser):
    """Register the required ``--stage_c_checkpoint_path`` option on ``parser`` and return the parser."""
    option_spec = dict(
        type=str,
        required=True,
        help="path to Stage C checkpoint / Stage Cのチェックポイントのパス",
    )
    parser.add_argument("--stage_c_checkpoint_path", **option_spec)
    return parser
|
|
|
|
|
|
def get_sai_model_spec(args):
    """Build SAI model-spec metadata for a Stable Cascade checkpoint from training arguments.

    The title falls back to ``args.output_name`` when ``args.metadata_title`` is None. If either timestep
    bound is set, ``timesteps`` becomes (min, max), with missing bounds defaulting to 0 and 1000;
    otherwise it is None.
    """
    timestamp = time.time()

    resolution = args.resolution
    title = args.output_name if args.metadata_title is None else args.metadata_title

    timesteps = None
    if not (args.min_timestep is None and args.max_timestep is None):
        lower = 0 if args.min_timestep is None else args.min_timestep
        upper = 1000 if args.max_timestep is None else args.max_timestep
        timesteps = (lower, upper)

    # leading positional args: no state dict plus five False model-type flags, then the timestamp
    # (NOTE(review): meanings per sai_model_spec.build_metadata's signature — confirm there)
    return sai_model_spec.build_metadata(
        None,
        False,
        False,
        False,
        False,
        False,
        timestamp,
        title=title,
        reso=resolution,
        is_stable_diffusion_ckpt=False,
        author=args.metadata_author,
        description=args.metadata_description,
        license=args.metadata_license,
        tags=args.metadata_tags,
        timesteps=timesteps,
        clip_skip=args.clip_skip,  # None or int
        stable_cascade=True,
    )
|
|
|
|
|
|
def save_stage_c_model_on_epoch_end_or_stepwise(
    args: argparse.Namespace,
    on_epoch_end: bool,
    accelerator,
    save_dtype: torch.dtype,
    epoch: int,
    num_train_epochs: int,
    global_step: int,
    stage_c,
):
    """Save Stage C weights at an epoch end or save step, delegating scheduling to the common helper.

    The weights are written as safetensors with SAI model-spec metadata and, if ``save_dtype`` is given,
    cast to that dtype first.
    """

    def stage_c_saver(ckpt_file, epoch_no, global_step):
        metadata = get_sai_model_spec(args)

        tensors = stage_c.state_dict()
        if save_dtype is not None:
            tensors = {name: tensor.to(save_dtype) for name, tensor in tensors.items()}

        save_file(tensors, ckpt_file, metadata=metadata)

    save_sd_model_on_epoch_end_or_stepwise_common(
        args,
        on_epoch_end,
        accelerator,
        True,
        True,
        epoch,
        num_train_epochs,
        global_step,
        stage_c_saver,
        None,  # no diffusers-format saver (presumably; confirm against the common helper)
    )
|
|
|
|
|
|
def save_stage_c_model_on_end(
    args: argparse.Namespace,
    save_dtype: torch.dtype,
    epoch: int,
    global_step: int,
    stage_c,
):
    """Save the final Stage C weights at the end of training via the common end-of-training helper.

    The weights are written as safetensors with SAI model-spec metadata and, if ``save_dtype`` is given,
    cast to that dtype first.
    """

    def stage_c_saver(ckpt_file, epoch_no, global_step):
        metadata = get_sai_model_spec(args)

        tensors = stage_c.state_dict()
        if save_dtype is not None:
            tensors = {name: tensor.to(save_dtype) for name, tensor in tensors.items()}

        save_file(tensors, ckpt_file, metadata=metadata)

    save_sd_model_on_train_end_common(
        args,
        True,
        True,
        epoch,
        global_step,
        stage_c_saver,
        None,  # no diffusers-format saver (presumably; confirm against the common helper)
    )
|
|
|
|
|
|
def cache_latents(self, effnet, vae_batch_size=1, cache_to_disk=False, is_main_process=True, device=None, dtype=None):
    """Cache EfficientNet latents for every image of a dataset (``self``).

    Images are grouped into batches of at most ``vae_batch_size`` that share bucket resolution and the
    owning subset's flip_aug/random_crop settings, then encoded with cache_batch_latents.

    With ``cache_to_disk``, each image's ``latents_npz`` path is set and images whose cache file is
    already valid are skipped. Non-main processes only set the paths and return.

    Multi-GPU is not supported here; use tools/cache_latents.py for that.

    Args:
        self: dataset exposing image_data, image_to_subset.
        effnet: EfficientNet encoder.
        vae_batch_size: maximum number of images per encoding batch.
        cache_to_disk: write latents to disk instead of keeping them in memory.
        is_main_process: whether this process performs the encoding.
        device: device for the image batches; defaults to the device of effnet's parameters.
        dtype: dtype for the image batches; defaults to the dtype of effnet's parameters.
    """
    logger.info("caching latents.")

    image_infos = list(self.image_data.values())

    # sort by resolution (area, then exact bucket) so identical buckets end up adjacent
    image_infos.sort(key=lambda info: (info.bucket_reso[0] * info.bucket_reso[1], info.bucket_reso[0], info.bucket_reso[1]))

    # split by resolution and per-subset settings: each entry is ((reso, flip_aug, random_crop), [infos])
    batches = []
    batch = []
    batch_condition = None
    logger.info("checking cache validity...")
    for info in tqdm(image_infos):
        subset = self.image_to_subset[info.image_key]

        if info.latents_npz is not None:  # fine tuning dataset
            continue

        # check disk cache exists and size of latents
        if cache_to_disk:
            info.latents_npz = os.path.splitext(info.absolute_path)[0] + LATENTS_CACHE_SUFFIX
            if not is_main_process:  # store to info only
                continue

            if is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug):
                continue  # valid cache: do not add to batch

        condition = (tuple(info.bucket_reso), subset.flip_aug, subset.random_crop)

        # if the resolution or subset settings differ from the current batch, flush the batch
        if batch and condition != batch_condition:
            batches.append((batch_condition, batch))
            batch = []

        batch.append(info)
        batch_condition = condition

        # if number of data in batch is enough, flush the batch
        if len(batch) >= vae_batch_size:
            batches.append((batch_condition, batch))
            batch = []

    if batch:
        batches.append((batch_condition, batch))

    if cache_to_disk and not is_main_process:  # if cache to disk, don't cache latents in non-main process, set to info only
        return

    # cache_batch_latents needs explicit device/dtype; default to where the encoder lives
    if batches and (device is None or dtype is None):
        first_param = next(effnet.parameters())
        device = first_param.device if device is None else device
        dtype = first_param.dtype if dtype is None else dtype

    # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
    logger.info("caching latents...")
    for (_, flip_aug, random_crop), batch in tqdm(batches, smoothing=1, total=len(batches)):
        cache_batch_latents(effnet, cache_to_disk, batch, flip_aug, random_crop, device, dtype)
|
|
|
|
|
|
def cache_text_encoder_outputs(self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True):
    """Cache text encoder outputs for every image of a dataset (``self``).

    If ``weight_dtype`` is given, the text encoders themselves are cast to it (and so are their outputs).
    Like latent caching, outputs can be cached to disk: with ``cache_to_disk`` each image's
    ``text_encoder_outputs_npz`` path is set and images whose file already exists are skipped.
    Non-main processes only set the paths and return.

    Multi-GPU is not supported here; use tools/cache_latents.py for that.
    """
    logger.info("caching text encoder outputs.")
    all_infos = list(self.image_data.values())

    logger.info("checking cache existence...")
    pending = []
    for info in tqdm(all_infos):
        if cache_to_disk:
            npz_path = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
            info.text_encoder_outputs_npz = npz_path

            # non-main processes only record the path; existing cache files need no work
            if not is_main_process or os.path.exists(npz_path):
                continue

        pending.append(info)

    if cache_to_disk and not is_main_process:  # non-main process with disk cache: paths were set above, nothing to encode
        return

    # prepare tokenizers and text encoders
    for text_encoder in text_encoders:
        text_encoder.to(device)
        if weight_dtype is not None:
            text_encoder.to(dtype=weight_dtype)

    # tokenize captions and group them into batches of at most self.batch_size
    batches = []
    current = []
    for info in pending:
        current.append((info, self.get_input_ids(info.caption, tokenizers[0])))
        if len(current) >= self.batch_size:
            batches.append(current)
            current = []
    if current:
        batches.append(current)

    # iterate batches: call text encoder and cache outputs for memory or disk
    logger.info("caching text encoder outputs...")
    for batch in tqdm(batches):
        batch_infos = tuple(info for info, _ in batch)
        batch_ids = torch.stack([ids for _, ids in batch], dim=0)
        cache_batch_text_encoder_outputs(
            batch_infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, batch_ids, weight_dtype
        )
|