diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index a63bd82e..bf546a1b 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -6,8 +6,10 @@ import os from typing import List, Optional, Tuple, Union import safetensors from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) r""" @@ -55,11 +57,13 @@ ARCH_SD_V1 = "stable-diffusion-v1" ARCH_SD_V2_512 = "stable-diffusion-v2-512" ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v" ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base" +ARCH_STABLE_CASCADE = "stable-cascade" ADAPTER_LORA = "lora" ADAPTER_TEXTUAL_INVERSION = "textual-inversion" IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models" +IMPL_STABILITY_AI_STABLE_CASCADE = "https://github.com/Stability-AI/StableCascade" IMPL_DIFFUSERS = "diffusers" PRED_TYPE_EPSILON = "epsilon" @@ -113,6 +117,7 @@ def build_metadata( merged_from: Optional[str] = None, timesteps: Optional[Tuple[int, int]] = None, clip_skip: Optional[int] = None, + stable_cascade: Optional[bool] = None, ): # if state_dict is None, hash is not calculated @@ -124,7 +129,9 @@ def build_metadata( # hash = precalculate_safetensors_hashes(state_dict) # metadata["modelspec.hash_sha256"] = hash - if sdxl: + if stable_cascade: + arch = ARCH_STABLE_CASCADE + elif sdxl: arch = ARCH_SD_XL_V1_BASE elif v2: if v_parameterization: @@ -142,9 +149,11 @@ def build_metadata( metadata["modelspec.architecture"] = arch if not lora and not textual_inversion and is_stable_diffusion_ckpt is None: - is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion + is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion - if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: + if stable_cascade: + impl = IMPL_STABILITY_AI_STABLE_CASCADE + elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt: # Stable Diffusion ckpt, TI, SDXL LoRA impl = IMPL_STABILITY_AI else: @@ -236,7 +245,7 @@ def build_metadata( # assert all([v is not None for v in metadata.values()]), metadata if not all([v is not None for v in metadata.values()]): logger.error(f"Internal error: some metadata values are None: {metadata}") - + return metadata @@ -250,7 +259,7 @@ def get_title(metadata: dict) -> Optional[str]: def load_metadata_from_safetensors(model: str) -> dict: if not model.endswith(".safetensors"): return {} - + with safetensors.safe_open(model, framework="pt") as f: metadata = f.metadata() if metadata is None: diff --git a/library/stable_cascade.py b/library/stable_cascade.py index 7e3d9e69..f9c5c629 100644 --- a/library/stable_cascade.py +++ b/library/stable_cascade.py @@ -3,7 +3,7 @@ # https://github.com/Stability-AI/StableCascade import math -from typing import List +from typing import List, Optional import numpy as np import torch import torch.nn as nn @@ -901,15 +901,19 @@ class StageC(nn.Module): self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) -def get_clip_conditions(captions: List[str], tokenizer, text_model): +def get_clip_conditions(captions: Optional[List[str]], input_ids, tokenizer, text_model): # self, batch: dict, tokenizer, text_model, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None # is_eval の処理をここでやるのは微妙なので別のところでやる # is_unconditional もここでやるのは微妙なので別のところでやる # clip_image はとりあえずサポートしない - clip_tokens_unpooled = tokenizer( - captions, 
truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" - ).to(text_model.device) - text_encoder_output = text_model(**clip_tokens_unpooled, output_hidden_states=True) + if captions is not None: + clip_tokens_unpooled = tokenizer( + captions, truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" + ).to(text_model.device) + text_encoder_output = text_model(**clip_tokens_unpooled, output_hidden_states=True) + else: + text_encoder_output = text_model(input_ids, output_hidden_states=True) + text_embeddings = text_encoder_output.hidden_states[-1] text_pooled_embeddings = text_encoder_output.text_embeds.unsqueeze(1) @@ -1262,4 +1266,108 @@ class CosineTNoiseCond(BaseNoiseCond): return t +# --- Loss Weighting +class BaseLossWeight: + def weight(self, logSNR): + raise NotImplementedError("this method needs to be overridden") + + def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs): + clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range + if shift != 1: + logSNR = logSNR.clone() + 2 * np.log(shift) + return self.weight(logSNR, *args, **kwargs).clamp(*clamp_range) + + +# class ComposedLossWeight(BaseLossWeight): +# def __init__(self, div, mul): +# self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul +# self.div = [div] if isinstance(div, BaseLossWeight) else div + +# def weight(self, logSNR): +# prod, div = 1, 1 +# for m in self.mul: +# prod *= m.weight(logSNR) +# for d in self.div: +# div *= d.weight(logSNR) +# return prod/div + +# class ConstantLossWeight(BaseLossWeight): +# def __init__(self, v=1): +# self.v = v + +# def weight(self, logSNR): +# return torch.ones_like(logSNR) * self.v + +# class SNRLossWeight(BaseLossWeight): +# def weight(self, logSNR): +# return logSNR.exp() + + +class P2LossWeight(BaseLossWeight): + def __init__(self, k=1.0, gamma=1.0, s=1.0): + self.k, self.gamma, self.s = k, gamma, s + + def weight(self, logSNR): + return (self.k + (logSNR * self.s).exp()) ** -self.gamma + + +# class SNRPlusOneLossWeight(BaseLossWeight): +# def weight(self, logSNR): +# return logSNR.exp() + 1 + +# class MinSNRLossWeight(BaseLossWeight): +# def __init__(self, max_snr=5): +# self.max_snr = max_snr + +# def weight(self, logSNR): +# return logSNR.exp().clamp(max=self.max_snr) + +# class MinSNRPlusOneLossWeight(BaseLossWeight): +# def __init__(self, max_snr=5): +# self.max_snr = max_snr + +# def weight(self, logSNR): +# return (logSNR.exp() + 1).clamp(max=self.max_snr) + +# class TruncatedSNRLossWeight(BaseLossWeight): +# def __init__(self, min_snr=1): +# self.min_snr = min_snr + +# def weight(self, logSNR): +# return logSNR.exp().clamp(min=self.min_snr) + +# class SechLossWeight(BaseLossWeight): +# def __init__(self, div=2): +# self.div = div + +# def weight(self, logSNR): +# return 1/(logSNR/self.div).cosh() + +# class DebiasedLossWeight(BaseLossWeight): +# def weight(self, logSNR): +# return 1/logSNR.exp().sqrt() + +# class SigmoidLossWeight(BaseLossWeight): +# def __init__(self, s=1): +# self.s = s + +# def weight(self, logSNR): +# return (logSNR * self.s).sigmoid() + + +class AdaptiveLossWeight(BaseLossWeight): + def __init__(self, logsnr_range=[-10, 10], buckets=300, weight_range=[1e-7, 1e7]): + self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets - 1) + self.bucket_losses = torch.ones(buckets) + self.weight_range = weight_range + + def weight(self, logSNR): + indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR) + return 
(1 / self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range) + + def update_buckets(self, logSNR, loss, beta=0.99): + indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu() + self.bucket_losses[indices] = self.bucket_losses[indices] * beta + loss.detach().cpu() * (1 - beta) + + # endregion gdf diff --git a/library/stable_cascade_utils.py b/library/stable_cascade_utils.py new file mode 100644 index 00000000..2406913d --- /dev/null +++ b/library/stable_cascade_utils.py @@ -0,0 +1,504 @@ +import argparse +import os +import time +from typing import List +import numpy as np + +import torch +import torchvision +from safetensors.torch import load_file, save_file +from tqdm import tqdm +from transformers import CLIPTokenizer, CLIPTextModelWithProjection, CLIPTextConfig +from accelerate import init_empty_weights + +from library import stable_cascade as sc +from library.train_util import ( + ImageInfo, + load_image, + trim_and_resize_if_required, + save_latents_to_disk, + HIGH_VRAM, + save_text_encoder_outputs_to_disk, +) +from library.sdxl_model_util import _load_state_dict_on_device +from library.device_utils import clean_memory_on_device +from library.train_util import save_sd_model_on_epoch_end_or_stepwise_common, save_sd_model_on_train_end_common +from library import sai_model_spec + + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +CLIP_TEXT_MODEL_NAME: str = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + +EFFNET_PREPROCESS = torchvision.transforms.Compose( + [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))] +) + +TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_sc_te_outputs.npz" +LATENTS_CACHE_SUFFIX = "_sc_latents.npz" + + +def load_effnet(effnet_checkpoint_path, loading_device="cpu") -> sc.EfficientNetEncoder: + logger.info(f"Loading EfficientNet encoder from {effnet_checkpoint_path}") + effnet = sc.EfficientNetEncoder() + effnet_checkpoint = load_file(effnet_checkpoint_path) + info = effnet.load_state_dict(effnet_checkpoint if "state_dict" not in effnet_checkpoint else effnet_checkpoint["state_dict"]) + logger.info(info) + del effnet_checkpoint + return effnet + + +def load_tokenizer(args: argparse.Namespace): + # TODO commonize with sdxl_train_util.load_tokenizers + logger.info("prepare tokenizers") + + original_paths = [CLIP_TEXT_MODEL_NAME] + tokenizers = [] + for i, original_path in enumerate(original_paths): + tokenizer: CLIPTokenizer = None + if args.tokenizer_cache_dir: + local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_")) + if os.path.exists(local_tokenizer_path): + logger.info(f"load tokenizer from cache: {local_tokenizer_path}") + tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) + + if tokenizer is None: + tokenizer = CLIPTokenizer.from_pretrained(original_path) + + if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path): + logger.info(f"save Tokenizer to cache: {local_tokenizer_path}") + tokenizer.save_pretrained(local_tokenizer_path) + + tokenizers.append(tokenizer) + + if hasattr(args, "max_token_length") and args.max_token_length is not None: + logger.info(f"update token length: {args.max_token_length}") + + return tokenizers[0] + + +def load_stage_c_model(stage_c_checkpoint_path, dtype=None, device="cpu") -> sc.StageC: + # Generator + logger.info(f"Instantiating Stage C generator") + with init_empty_weights(): + generator_c = 
sc.StageC() + logger.info(f"Loading Stage C generator from {stage_c_checkpoint_path}") + stage_c_checkpoint = load_file(stage_c_checkpoint_path) + logger.info(f"Loading state dict") + info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, device, dtype=dtype) + logger.info(info) + return generator_c + + +def load_stage_b_model(stage_b_checkpoint_path, dtype=None, device="cpu") -> sc.StageB: + logger.info(f"Instantiating Stage B generator") + with init_empty_weights(): + generator_b = sc.StageB() + logger.info(f"Loading Stage B generator from {stage_b_checkpoint_path}") + stage_b_checkpoint = load_file(stage_b_checkpoint_path) + logger.info(f"Loading state dict") + info = _load_state_dict_on_device(generator_b, stage_b_checkpoint, device, dtype=dtype) + logger.info(info) + return generator_b + + +def load_clip_text_model(text_model_checkpoint_path, dtype=None, device="cpu", save_text_model=False): + # CLIP encoders + logger.info(f"Loading CLIP text model") + if save_text_model or text_model_checkpoint_path is None: + logger.info(f"Loading CLIP text model from {CLIP_TEXT_MODEL_NAME}") + text_model = CLIPTextModelWithProjection.from_pretrained(CLIP_TEXT_MODEL_NAME) + + if save_text_model: + sd = text_model.state_dict() + logger.info(f"Saving CLIP text model to {text_model_checkpoint_path}") + save_file(sd, text_model_checkpoint_path) + else: + logger.info(f"Loading CLIP text model from {text_model_checkpoint_path}") + + # copy from sdxl_model_util.py + text_model2_cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=1280, + intermediate_size=5120, + num_hidden_layers=32, + num_attention_heads=20, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=1280, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + with init_empty_weights(): + text_model = CLIPTextModelWithProjection(text_model2_cfg) + + text_model_checkpoint = load_file(text_model_checkpoint_path) + info = _load_state_dict_on_device(text_model, text_model_checkpoint, device, dtype=dtype) + logger.info(info) + + return text_model + + +def load_stage_a_model(stage_a_checkpoint_path, dtype=None, device="cpu") -> sc.StageA: + logger.info(f"Loading Stage A vqGAN from {stage_a_checkpoint_path}") + stage_a = sc.StageA().to(device) + stage_a_checkpoint = load_file(stage_a_checkpoint_path) + info = stage_a.load_state_dict( + stage_a_checkpoint if "state_dict" not in stage_a_checkpoint else stage_a_checkpoint["state_dict"] + ) + logger.info(info) + return stage_a + + +def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): + expected_latents_size = (reso[1] // 32, reso[0] // 32) # bucket_resoはWxHなので注意 + + if not os.path.exists(npz_path): + return False + + npz = np.load(npz_path) + if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? 
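+        # npz files written by an older cache format lack these keys; treat them as invalid so they are re-generated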
+        return False
+    if npz["latents"].shape[1:3] != expected_latents_size:
+        return False
+
+    if flip_aug:
+        if "latents_flipped" not in npz:
+            return False
+        if npz["latents_flipped"].shape[1:3] != expected_latents_size:
+            return False
+
+    return True
+
+
+def cache_batch_latents(
+    effnet: sc.EfficientNetEncoder,
+    cache_to_disk: bool,
+    image_infos: List[ImageInfo],
+    flip_aug: bool,
+    random_crop: bool,
+    device,
+    dtype,
+) -> None:
+    r"""
+    requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
+    optionally requires image_infos to have: image
+    if cache_to_disk is True, set info.latents_npz
+        flipped latents is also saved if flip_aug is True
+    if cache_to_disk is False, set info.latents
+        latents_flipped is also set if flip_aug is True
+    latents_original_size and latents_crop_ltrb are also set
+    """
+    images = []
+    for info in image_infos:
+        image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
+        # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
+        image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
+        image = EFFNET_PREPROCESS(image)
+        images.append(image)
+
+        info.latents_original_size = original_size
+        info.latents_crop_ltrb = crop_ltrb
+
+    img_tensors = torch.stack(images, dim=0)
+    img_tensors = img_tensors.to(device=device, dtype=dtype)
+
+    with torch.no_grad():
+        latents = effnet(img_tensors).to("cpu")
+
+    if flip_aug:
+        img_tensors = torch.flip(img_tensors, dims=[3])
+        with torch.no_grad():
+            flipped_latents = effnet(img_tensors).to("cpu")
+    else:
+        flipped_latents = [None] * len(latents)
+
+    for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
+        # check NaN per item (not for the whole batch) so the offending file is reported correctly
+        if torch.isnan(latent).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
+            raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
+
+        if cache_to_disk:
+            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
+        else:
+            info.latents = latent
+            if flip_aug:
+                info.latents_flipped = flipped_latent
+
+    if not HIGH_VRAM:
+        clean_memory_on_device(device)
+
+
+def cache_batch_text_encoder_outputs(image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids, dtype):
+    # 75 トークン越えは未対応
+    input_ids = input_ids.to(text_encoders[0].device)
+
+    with torch.no_grad():
+        b_hidden_state, b_pool = sc.get_clip_conditions(None, input_ids, tokenizers[0], text_encoders[0])
+
+        b_hidden_state = b_hidden_state.detach().to("cpu")  # b,n*75+2,768
+        b_pool = b_pool.detach().to("cpu")  # b,1280
+
+    for info, hidden_state, pool in zip(image_infos, b_hidden_state, b_pool):
+        if cache_to_disk:
+            save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, None, hidden_state, pool)
+        else:
+            info.text_encoder_outputs1 = hidden_state
+            info.text_encoder_pool2 = pool
+
+
+def add_effnet_arguments(parser):
+    parser.add_argument(
+        "--effnet_checkpoint_path",
+        type=str,
+        required=True,
+        help="path to EfficientNet checkpoint / EfficientNetのチェックポイントのパス",
+    )
+    return parser
+
+
+def add_text_model_arguments(parser):
+    parser.add_argument(
+        "--text_model_checkpoint_path",
+        type=str,
+        required=True,
+        help="path to CLIP text model checkpoint / CLIPテキストモデルのチェックポイントのパス",
+    )
+    parser.add_argument("--save_text_model", action="store_true", help="if specified, save text model to corresponding path")
+    return parser
+
+
+def 
add_stage_a_arguments(parser): + parser.add_argument( + "--stage_a_checkpoint_path", + type=str, + required=True, + help="path to Stage A checkpoint / Stage Aのチェックポイントのパス", + ) + return parser + + +def add_stage_b_arguments(parser): + parser.add_argument( + "--stage_b_checkpoint_path", + type=str, + required=True, + help="path to Stage B checkpoint / Stage Bのチェックポイントのパス", + ) + return parser + + +def add_stage_c_arguments(parser): + parser.add_argument( + "--stage_c_checkpoint_path", + type=str, + required=True, + help="path to Stage C checkpoint / Stage Cのチェックポイントのパス", + ) + return parser + + +def get_sai_model_spec(args): + timestamp = time.time() + + reso = args.resolution + + title = args.metadata_title if args.metadata_title is not None else args.output_name + + if args.min_timestep is not None or args.max_timestep is not None: + min_time_step = args.min_timestep if args.min_timestep is not None else 0 + max_time_step = args.max_timestep if args.max_timestep is not None else 1000 + timesteps = (min_time_step, max_time_step) + else: + timesteps = None + + metadata = sai_model_spec.build_metadata( + None, + False, + False, + False, + False, + False, + timestamp, + title=title, + reso=reso, + is_stable_diffusion_ckpt=False, + author=args.metadata_author, + description=args.metadata_description, + license=args.metadata_license, + tags=args.metadata_tags, + timesteps=timesteps, + clip_skip=args.clip_skip, # None or int + stable_cascade=True, + ) + return metadata + + +def save_stage_c_model_on_epoch_end_or_stepwise( + args: argparse.Namespace, + on_epoch_end: bool, + accelerator, + save_dtype: torch.dtype, + epoch: int, + num_train_epochs: int, + global_step: int, + stage_c, +): + def stage_c_saver(ckpt_file, epoch_no, global_step): + sai_metadata = get_sai_model_spec(args) + + state_dict = stage_c.state_dict() + if save_dtype is not None: + state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()} + + save_file(state_dict, ckpt_file, metadata=sai_metadata) + + save_sd_model_on_epoch_end_or_stepwise_common( + args, on_epoch_end, accelerator, True, True, epoch, num_train_epochs, global_step, stage_c_saver, None + ) + + +def save_stage_c_model_on_end( + args: argparse.Namespace, + save_dtype: torch.dtype, + epoch: int, + global_step: int, + stage_c, +): + def stage_c_saver(ckpt_file, epoch_no, global_step): + sai_metadata = get_sai_model_spec(args) + + state_dict = stage_c.state_dict() + if save_dtype is not None: + state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()} + + save_file(state_dict, ckpt_file, metadata=sai_metadata) + + save_sd_model_on_train_end_common(args, True, True, epoch, global_step, stage_c_saver, None) + + +def cache_latents(self, effnet, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと + logger.info("caching latents.") + + image_infos = list(self.image_data.values()) + + # sort by resolution + image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) + + # split by resolution + batches = [] + batch = [] + logger.info("checking cache validity...") + for info in tqdm(image_infos): + subset = self.image_to_subset[info.image_key] + + if info.latents_npz is not None: # fine tuning dataset + continue + + # check disk cache exists and size of latents + if cache_to_disk: + info.latents_npz = os.path.splitext(info.absolute_path)[0] + LATENTS_CACHE_SUFFIX + if not is_main_process: # store to info only + continue + + cache_available = 
is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug)
+
+            if cache_available:  # do not add to batch
+                continue
+
+        # if last member of batch has different resolution, flush the batch
+        if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
+            batches.append(batch)
+            batch = []
+
+        batch.append(info)
+
+        # if number of data in batch is enough, flush the batch
+        if len(batch) >= vae_batch_size:
+            batches.append(batch)
+            batch = []
+
+    if len(batch) > 0:
+        batches.append(batch)
+
+    if cache_to_disk and not is_main_process:  # if cache to disk, don't cache latents in non-main process, set to info only
+        return
+
+    # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
+    logger.info("caching latents...")
+    # cache_batch_latents requires device and dtype, so derive them from effnet itself
+    effnet_device = next(effnet.parameters()).device
+    effnet_dtype = next(effnet.parameters()).dtype
+    for batch in tqdm(batches, smoothing=1, total=len(batches)):
+        cache_batch_latents(effnet, cache_to_disk, batch, subset.flip_aug, subset.random_crop, effnet_device, effnet_dtype)
+
+
+# weight_dtypeを指定するとText Encoderそのもの、および出力がweight_dtypeになる
+def cache_text_encoder_outputs(self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True):
+    # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する
+    # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
+    logger.info("caching text encoder outputs.")
+    image_infos = list(self.image_data.values())
+
+    logger.info("checking cache existence...")
+    image_infos_to_cache = []
+    for info in tqdm(image_infos):
+        # subset = self.image_to_subset[info.image_key]
+        if cache_to_disk:
+            te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
+            info.text_encoder_outputs_npz = te_out_npz
+
+            if not is_main_process:  # store to info only
+                continue
+
+            if os.path.exists(te_out_npz):
+                continue
+
+        image_infos_to_cache.append(info)
+
+    if cache_to_disk and not is_main_process:  # if cache to disk, don't cache in non-main process, set to info only
+        return
+
+    # prepare tokenizers and text encoders
+    for text_encoder in text_encoders:
+        text_encoder.to(device)
+        if weight_dtype is not None:
+            text_encoder.to(dtype=weight_dtype)
+
+    # create batch
+    batch = []
+    batches = []
+    for info in image_infos_to_cache:
+        input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
+        batch.append((info, input_ids1, None))
+
+        if len(batch) >= self.batch_size:
+            batches.append(batch)
+            batch = []
+
+    if len(batch) > 0:
+        batches.append(batch)
+
+    # iterate batches: call text encoder and cache outputs for memory or disk
+    logger.info("caching text encoder outputs...")
+    for batch in tqdm(batches):
+        infos, input_ids1, input_ids2 = zip(*batch)
+        input_ids1 = torch.stack(input_ids1, dim=0)
+        input_ids2 = torch.stack(input_ids2, dim=0) if input_ids2[0] is not None else None
+        cache_batch_text_encoder_outputs(
+            infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, weight_dtype
+        )
diff --git a/library/train_util.py b/library/train_util.py
index 60af6e89..66ca81cc 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -909,7 +909,7 @@ class BaseDataset(torch.utils.data.Dataset):
         )
 
     def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
-        # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
+        # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
         logger.info("caching latents.")
 
         image_infos = list(self.image_data.values())
@@ -1325,7 +1325,7 @@ class BaseDataset(torch.utils.data.Dataset):
 
         if self.caching_mode == "text":
             input_ids1 = self.get_input_ids(caption, self.tokenizers[0])
-            input_ids2 = 
self.get_input_ids(caption, self.tokenizers[1]) + input_ids2 = self.get_input_ids(caption, self.tokenizers[1]) if len(self.tokenizers) > 1 else None else: input_ids1 = None input_ids2 = None @@ -2328,7 +2328,7 @@ def cache_batch_text_encoder_outputs( def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2): np.savez( npz_path, - hidden_state1=hidden_state1.cpu().float().numpy(), + hidden_state1=hidden_state1.cpu().float().numpy() if hidden_state1 is not None else None, hidden_state2=hidden_state2.cpu().float().numpy(), pool2=pool2.cpu().float().numpy(), ) @@ -2684,6 +2684,14 @@ def get_sai_model_spec( return metadata +def add_tokenizer_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", + ) + def add_sd_models_arguments(parser: argparse.ArgumentParser): # for pretrained models parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む") @@ -2696,12 +2704,7 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): default=None, help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル", ) - parser.add_argument( - "--tokenizer_cache_dir", - type=str, - default=None, - help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", - ) + add_tokenizer_arguments(parser) def add_optimizer_arguments(parser: argparse.ArgumentParser): @@ -3150,18 +3153,22 @@ def verify_training_args(args: argparse.Namespace): print("highvram is enabled / highvramが有効です") global HIGH_VRAM HIGH_VRAM = True - - if args.v_parameterization and not args.v2: - logger.warning("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません") - if args.v2 and args.clip_skip is not None: - logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") - + if args.cache_latents_to_disk and not args.cache_latents: args.cache_latents = True logger.warning( "cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします" ) + if not hasattr(args, "v_parameterization"): + # Stable Cascade: skip following checks + return + + if args.v_parameterization and not args.v2: + logger.warning("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません") + if args.v2 and args.clip_skip is not None: + logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + # noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time # # Listを使って数えてもいいけど並べてしまえ # if args.noise_offset is not None and args.multires_noise_iterations is not None: diff --git a/stable_cascade_gen_img.py b/stable_cascade_gen_img.py index 8ad31107..1ffdb03c 100644 --- a/stable_cascade_gen_img.py +++ b/stable_cascade_gen_img.py @@ -11,11 +11,11 @@ from PIL import Image from accelerate import init_empty_weights import library.stable_cascade as sc +import library.stable_cascade_utils as sc_utils import library.device_utils as device_utils +from library import train_util from library.sdxl_model_util import _load_state_dict_on_device -clip_text_model_name: str = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - def calculate_latent_sizes(height=1024, width=1024, batch_size=4, 
compression_factor_b=42.67, compression_factor_a=4.0): resolution_multiple = 42.67 @@ -45,94 +45,31 @@ def main(args): text_model_dtype = torch.float32 # EfficientNet encoder - print(f"Loading EfficientNet encoder from {args.effnet_checkpoint_path}") - effnet = sc.EfficientNetEncoder() - effnet_checkpoint = load_file(args.effnet_checkpoint_path) - info = effnet.load_state_dict(effnet_checkpoint if "state_dict" not in effnet_checkpoint else effnet_checkpoint["state_dict"]) - print(info) + effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device) effnet.eval().requires_grad_(False).to(loading_device) - del effnet_checkpoint - # Generator - print(f"Instantiating Stage C generator") - with init_empty_weights(): - generator_c = sc.StageC() - print(f"Loading Stage C generator from {args.stage_c_checkpoint_path}") - stage_c_checkpoint = load_file(args.stage_c_checkpoint_path) - print(f"Loading state dict") - info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, loading_device, dtype=dtype) - print(info) + generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device) generator_c.eval().requires_grad_(False).to(loading_device) - print(f"Instantiating Stage B generator") - with init_empty_weights(): - generator_b = sc.StageB() - print(f"Loading Stage B generator from {args.stage_b_checkpoint_path}") - stage_b_checkpoint = load_file(args.stage_b_checkpoint_path) - print(f"Loading state dict") - info = _load_state_dict_on_device(generator_b, stage_b_checkpoint, loading_device, dtype=dtype) - print(info) + generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device) generator_b.eval().requires_grad_(False).to(loading_device) # CLIP encoders print(f"Loading CLIP text model") - # TODO 完全にオフラインで動かすには tokenizer もローカルに保存できるようにする必要がある - tokenizer = AutoTokenizer.from_pretrained(clip_text_model_name) - - if args.save_text_model or args.text_model_checkpoint_path is None: - print(f"Loading CLIP text model from {clip_text_model_name}") - text_model = CLIPTextModelWithProjection.from_pretrained(clip_text_model_name) - - if args.save_text_model: - sd = text_model.state_dict() - print(f"Saving CLIP text model to {args.text_model_checkpoint_path}") - save_file(sd, args.text_model_checkpoint_path) - else: - print(f"Loading CLIP text model from {args.text_model_checkpoint_path}") - - # copy from sdxl_model_util.py - text_model2_cfg = CLIPTextConfig( - vocab_size=49408, - hidden_size=1280, - intermediate_size=5120, - num_hidden_layers=32, - num_attention_heads=20, - max_position_embeddings=77, - hidden_act="gelu", - layer_norm_eps=1e-05, - dropout=0.0, - attention_dropout=0.0, - initializer_range=0.02, - initializer_factor=1.0, - pad_token_id=1, - bos_token_id=0, - eos_token_id=2, - model_type="clip_text_model", - projection_dim=1280, - # torch_dtype="float32", - # transformers_version="4.25.0.dev0", - ) - with init_empty_weights(): - text_model = CLIPTextModelWithProjection(text_model2_cfg) - - text_model_checkpoint = load_file(args.text_model_checkpoint_path) - info = _load_state_dict_on_device(text_model, text_model_checkpoint, text_model_device, dtype=text_model_dtype) - print(info) + tokenizer = sc_utils.load_tokenizer(args) + text_model = sc_utils.load_clip_text_model( + args.text_model_checkpoint_path, text_model_dtype, text_model_device, args.save_text_model + ) text_model = text_model.requires_grad_(False).to(text_model_dtype).to(text_model_device) + # image_model = ( # 
CLIPVisionModelWithProjection.from_pretrained(clip_image_model_name).requires_grad_(False).to(dtype).to(device) # ) # vqGAN - print(f"Loading Stage A vqGAN from {args.stage_a_checkpoint_path}") - stage_a = sc.StageA().to(loading_device) - stage_a_checkpoint = load_file(args.stage_a_checkpoint_path) - info = stage_a.load_state_dict( - stage_a_checkpoint if "state_dict" not in stage_a_checkpoint else stage_a_checkpoint["state_dict"] - ) - print(info) + stage_a = sc_utils.load_stage_a_model(args.stage_a_checkpoint_path, dtype=dtype, device=loading_device) stage_a.eval().requires_grad_(False) caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee" @@ -169,19 +106,19 @@ def main(args): # extras_b.sampling_configs["t_start"] = 1.0 # PREPARE CONDITIONS - cond_text, cond_pooled = sc.get_clip_conditions([caption], tokenizer, text_model) + cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model) cond_text = cond_text.to(device, dtype=dtype) cond_pooled = cond_pooled.to(device, dtype=dtype) - uncond_text, uncond_pooled = sc.get_clip_conditions([""], tokenizer, text_model) + uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model) uncond_text = uncond_text.to(device, dtype=dtype) uncond_pooled = uncond_pooled.to(device, dtype=dtype) - img_emb = torch.zeros(1, 768, device=device) + zero_img_emb = torch.zeros(1, 768, device=device) # 辞書にしたくないけど GDF から先の変更が面倒だからとりあえず辞書にしておく - conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": img_emb} - unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": img_emb} + conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb} + unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": zero_img_emb} conditions_b = {} conditions_b.update(conditions) unconditions_b = {} @@ -249,14 +186,13 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--effnet_checkpoint_path", type=str, required=True) - parser.add_argument("--stage_a_checkpoint_path", type=str, required=True) - parser.add_argument("--stage_b_checkpoint_path", type=str, required=True) - parser.add_argument("--stage_c_checkpoint_path", type=str, required=True) - parser.add_argument( - "--text_model_checkpoint_path", type=str, required=False, default=None, help="if omitted, download from HuggingFace" - ) - parser.add_argument("--save_text_model", action="store_true", help="if specified, save text model to corresponding path") + + sc_utils.add_effnet_arguments(parser) + train_util.add_tokenizer_arguments(parser) + sc_utils.add_stage_a_arguments(parser) + sc_utils.add_stage_b_arguments(parser) + sc_utils.add_stage_c_arguments(parser) + sc_utils.add_text_model_arguments(parser) parser.add_argument("--bf16", action="store_true") parser.add_argument("--fp16", action="store_true") parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先") diff --git a/stable_cascade_train_stage_c.py b/stable_cascade_train_stage_c.py new file mode 100644 index 00000000..d5b4d5c0 --- /dev/null +++ b/stable_cascade_train_stage_c.py @@ -0,0 +1,526 @@ +# training with captions + +import argparse +import math +import os +from multiprocessing import Value +from typing import List +import toml + +from tqdm import tqdm + 
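+# initialize IPEX (Intel XPU support) right after importing torch, before the other torch-dependent imports below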
+import torch +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from accelerate.utils import set_seed +from diffusers import DDPMScheduler + +import library.train_util as train_util +from library.sdxl_train_util import add_sdxl_training_arguments +import library.stable_cascade_utils as sc_utils +import library.stable_cascade as sc + +from library.utils import setup_logging, add_logging_arguments + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + setup_logging(args, reset=True) + + # assert ( + # not args.weighted_captions + # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません" + + # TODO add assertions for other unsupported options + + cache_latents = args.cache_latents + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + tokenizer = sc_utils.load_tokenizer(args) + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer]) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer]) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + train_dataset_group.verify_bucket_reso_steps(32) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group, True) + return + if len(train_dataset_group) == 0: + logger.error( + "No data found. Please verify the metadata file and train_data_dir option. 
/ 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + loading_device = accelerator.device if args.lowram else "cpu" + effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device) + stage_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=weight_dtype, device=loading_device) + text_encoder1 = sc_utils.load_clip_text_model(args.text_model_checkpoint_path, dtype=weight_dtype, device=loading_device) + + # 学習を準備する + if cache_latents: + raise NotImplementedError("Caching latents is not supported in this version / latentのキャッシュはサポートされていません") + logger.info( + "Please make sure that the latents are cached before training with `stable_cascade_cache_latents.py`." + + " / 学習前に`stable_cascade_cache_latents.py`でlatentをキャッシュしてください。" + ) + # effnet.to(accelerator.device, dtype=effnet_dtype) + effnet.requires_grad_(False) + effnet.eval() + with torch.no_grad(): + train_dataset_group.cache_latents( + effnet, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + cache_func=sc_utils.cache_batch_latents, + ) + effnet.to("cpu") + clean_memory_on_device(accelerator.device) + + accelerator.wait_for_everyone() + + # 学習を準備する:モデルを適切な状態にする + if args.gradient_checkpointing: + logger.warn("Gradient checkpointing is not supported for stage_c. Ignoring the option.") + # stage_c.enable_gradient_checkpointing() + + text_encoder1.to(weight_dtype) + text_encoder1.requires_grad_(False) + text_encoder1.eval() + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + raise NotImplementedError( + "Caching text encoder outputs is not supported in this version / text encoderの出力のキャッシュはサポートされていません" + ) + print( + f"Please make sure that the text encoder outputs are cached before training with `stable_cascade_cache_text_encoder_outputs.py`." 
+ + " / 学習前に`stable_cascade_cache_text_encoder_outputs.py`でtext encoderの出力をキャッシュしてください。" + ) + # Text Encodes are eval and no grad + with torch.no_grad(), accelerator.autocast(): + train_dataset_group.cache_text_encoder_outputs( + (tokenizer), + (text_encoder1), + accelerator.device, + None, + args.cache_text_encoder_outputs_to_disk, + accelerator.is_main_process, + ) + accelerator.wait_for_everyone() + + if not cache_latents: + effnet.requires_grad_(False) + effnet.eval() + effnet.to(accelerator.device, dtype=effnet_dtype) + + stage_c.requires_grad_(True) + + training_models = [] + params_to_optimize = [] + training_models.append(stage_c) + params_to_optimize.append({"params": list(stage_c.parameters()), "lr": args.learning_rate}) + + # calculate number of trainable parameters + n_params = 0 + for params in params_to_optimize: + for p in params["params"]: + n_params += p.numel() + + accelerator.print(f"number of models: {len(training_models)}") + accelerator.print(f"number of trainable parameters: {n_params}") + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print( + f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + ) + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + stage_c.to(weight_dtype) + text_encoder1.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + stage_c.to(weight_dtype) + text_encoder1.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + stage_c = accelerator.prepare(stage_c) + + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. 
Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + accelerator.print("running training / 学習開始") + accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print( + f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" + ) + # accelerator.print( + # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + # ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + # 謎のクラス GDF + gdf = sc.GDF( + schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]), + input_scaler=sc.VPScaler(), + target=sc.EpsilonTarget(), + noise_cond=sc.CosineTNoiseCond(), + loss_weight=sc.AdaptiveLossWeight(), + ) + + # 以下2つの変数は、どうもデフォルトのままっぽい + # gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges']) + # gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses']) + + # noise_scheduler = DDPMScheduler( + # beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + # ) + # prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + # if args.zero_terminal_snr: + # custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.wandb_run_name: + init_kwargs["wandb"] = {"name": args.wandb_run_name} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + + # # For --sample_at_first + # sdxl_train_util.sample_images( + # accelerator, + # args, + # 0, + # global_step, + # accelerator.device, + # effnet, + # [tokenizer1, tokenizer2], + # [text_encoder1, text_encoder2], + # stage_c, + # ) + + loss_recorder = train_util.LossRecorder() + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for m in training_models: + m.train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with 
accelerator.accumulate(*training_models):
+                if "latents" in batch and batch["latents"] is not None:
+                    latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
+                else:
+                    with torch.no_grad():
+                        # latentに変換
+                        latents = effnet(batch["images"].to(effnet_dtype)).to(weight_dtype)
+
+                        # NaNが含まれていれば警告を表示し0に置き換える
+                        if torch.any(torch.isnan(latents)):
+                            accelerator.print("NaN found in latents, replacing with zeros")
+                            latents = torch.nan_to_num(latents, 0, out=latents)
+
+                if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
+                    input_ids1 = batch["input_ids"]
+                    with torch.no_grad():
+                        # Get the text embedding for conditioning
+                        # TODO support weighted captions
+                        input_ids1 = input_ids1.to(accelerator.device)
+                        # unwrap_model is fine for models not wrapped by accelerator
+                        encoder_hidden_states, pool = sc.get_clip_conditions(None, input_ids1, tokenizer, text_encoder1)
+                else:
+                    encoder_hidden_states = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
+                    pool = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
+
+                # FORWARD PASS
+                with torch.no_grad():
+                    noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(latents, shift=1, loss_shift=1)
+
+                zero_img_emb = torch.zeros(noised.shape[0], 768, device=accelerator.device)
+                with accelerator.autocast():
+                    pred = stage_c(
+                        noised, noise_cond, clip_text=encoder_hidden_states, clip_text_pooled=pool, clip_img=zero_img_emb
+                    )
+                    loss = torch.nn.functional.mse_loss(pred, target, reduction="none").mean(dim=[1, 2, 3])
+                    loss_adjusted = (loss * loss_weight).mean()
+
+                gdf.loss_weight.update_buckets(logSNR, loss)
+
+                accelerator.backward(loss_adjusted)
+                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                    params_to_clip = []
+                    for m in training_models:
+                        params_to_clip.extend(m.parameters())
+                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad(set_to_none=True)
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
+                # sdxl_train_util.sample_images(
+                #     accelerator,
+                #     args,
+                #     None,
+                #     global_step,
+                #     accelerator.device,
+                #     effnet,
+                #     [tokenizer1, tokenizer2],
+                #     [text_encoder1, text_encoder2],
+                #     stage_c,
+                # )
+
+                # 指定ステップごとにモデルを保存
+                if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
+                    accelerator.wait_for_everyone()
+                    if accelerator.is_main_process:
+                        sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
+                            args,
+                            False,
+                            accelerator,
+                            save_dtype,
+                            epoch,
+                            num_train_epochs,
+                            global_step,
+                            accelerator.unwrap_model(stage_c),
+                        )
+
+            current_loss = loss.detach().mean().item()  # 平均なのでbatch sizeは関係ないはず
+            if args.logging_dir is not None:
+                logs = {"loss": current_loss}
+                train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
+
+                accelerator.log(logs, step=global_step)
+
+            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
+            avr_loss: float = loss_recorder.moving_average
+            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+
+            if global_step >= args.max_train_steps:
+                break
+
+        if args.logging_dir is not None:
+            logs = {"loss/epoch": loss_recorder.moving_average}
+            accelerator.log(logs, step=epoch + 1)
+
+        accelerator.wait_for_everyone()
+
+        if args.save_every_n_epochs is not None:
+            if 
accelerator.is_main_process: + sc_utils.save_stage_c_model_on_epoch_end_or_stepwise( + args, True, accelerator, save_dtype, epoch, num_train_epochs, global_step, accelerator.unwrap_model(stage_c) + ) + + # sdxl_train_util.sample_images( + # accelerator, + # args, + # epoch + 1, + # global_step, + # accelerator.device, + # effnet, + # [tokenizer1, tokenizer2], + # [text_encoder1, text_encoder2], + # stage_c, + # ) + + is_main_process = accelerator.is_main_process + # if is_main_process: + stage_c = accelerator.unwrap_model(stage_c) + + accelerator.end_training() + + if args.save_state: # and is_main_process: + train_util.save_state_on_train_end(args, accelerator) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + sc_utils.save_stage_c_model_on_end(args, save_dtype, epoch, global_step, stage_c) + logger.info("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + add_logging_arguments(parser) + sc_utils.add_effnet_arguments(parser) + sc_utils.add_stage_c_arguments(parser) + sc_utils.add_text_model_arguments(parser) + train_util.add_tokenizer_arguments(parser) + train_util.add_dataset_arguments(parser, True, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + add_sdxl_training_arguments(parser) # cache text encoder outputs + + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う", + ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) diff --git a/tools/stable_cascade_cache_latents.py b/tools/stable_cascade_cache_latents.py new file mode 100644 index 00000000..fbb4b480 --- /dev/null +++ b/tools/stable_cascade_cache_latents.py @@ -0,0 +1,191 @@ +# Stable Cascadeのlatentsをdiskにキャッシュする +# cache latents of Stable Cascade to disk + +import argparse +import math +from multiprocessing import Value +import os + +from accelerate.utils import set_seed +import torch +from tqdm import tqdm + +from library import stable_cascade_utils as sc_utils +from library import config_util +from library import train_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def cache_to_disk(args: argparse.Namespace) -> None: + train_util.prepare_dataset_args(args, True) + + # check cache latents arg + assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります" + + use_dreambooth_method = args.in_json is None + + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + + # tokenizerを準備する:datasetを動かすために必要 + tokenizer = sc_utils.load_tokenizer(args) + tokenizers = [tokenizer] + + # データセットを準備する + if args.dataset_class is None: + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) + if args.dataset_config is not None: + logger.info(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "in_json"] + if any(getattr(args, attr) is not None for attr in ignored): + logger.warning( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: 
{0}".format( + ", ".join(ignored) + ) + ) + else: + if use_dreambooth_method: + logger.info("Using DreamBooth method.") + user_config = { + "datasets": [ + { + "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( + args.train_data_dir, args.reg_data_dir + ) + } + ] + } + else: + logger.info("Training with captions.") + user_config = { + "datasets": [ + { + "subsets": [ + { + "image_dir": args.train_data_dir, + "metadata_file": args.in_json, + } + ] + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + else: + train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers) + + # datasetのcache_latentsを呼ばなければ、生の画像が返る + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, _ = train_util.prepare_dtype(args) + effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + logger.info("load model") + effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, accelerator.device) + effnet.to(accelerator.device, dtype=effnet_dtype) + effnet.requires_grad_(False) + effnet.eval() + + # dataloaderを準備する + train_dataset_group.set_caching_mode("latents") + + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず + train_dataloader = accelerator.prepare(train_dataloader) + + # データ取得のためのループ + for batch in tqdm(train_dataloader): + b_size = len(batch["images"]) + vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size + flip_aug = batch["flip_aug"] + random_crop = batch["random_crop"] + bucket_reso = batch["bucket_reso"] + + # バッチを分割して処理する + for i in range(0, b_size, vae_batch_size): + images = batch["images"][i : i + vae_batch_size] + absolute_paths = batch["absolute_paths"][i : i + vae_batch_size] + resized_sizes = batch["resized_sizes"][i : i + vae_batch_size] + + image_infos = [] + for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)): + image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) + image_info.image = image + image_info.bucket_reso = bucket_reso + image_info.resized_size = resized_size + image_info.latents_npz = os.path.splitext(absolute_path)[0] + sc_utils.LATENTS_CACHE_SUFFIX + + if args.skip_existing: + if sc_utils.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug): + logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") + continue + + image_infos.append(image_info) + + if len(image_infos) > 0: + sc_utils.cache_batch_latents(effnet, True, image_infos, flip_aug, random_crop, accelerator.device, effnet_dtype) + + accelerator.wait_for_everyone() + accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + + +def setup_parser() -> 
+def setup_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser()
+
+    train_util.add_tokenizer_arguments(parser)
+    sc_utils.add_effnet_arguments(parser)
+    train_util.add_training_arguments(parser, True)
+    train_util.add_dataset_arguments(parser, True, True, True)
+    config_util.add_config_arguments(parser)
+    parser.add_argument(
+        "--no_half_vae",
+        action="store_true",
+        help="do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う",
+    )
+    parser.add_argument(
+        "--skip_existing",
+        action="store_true",
+        help="skip images if npz already exists (both normal and flipped must exist if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
+    )
+    return parser
+
+
+if __name__ == "__main__":
+    parser = setup_parser()
+
+    args = parser.parse_args()
+    args = train_util.read_config_from_file(args, parser)
+
+    cache_to_disk(args)
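
The script above derives each cache path by stripping the image extension and appending sc_utils.LATENTS_CACHE_SUFFIX, then asks sc_utils.is_disk_cached_latents_is_expected whether an existing npz can be reused. Below is a minimal sketch of that check, assuming a suffix value and npz key names ("latents", "latents_flipped") for illustration only; the real format is defined in library/stable_cascade_utils.py.

import os
import numpy as np

LATENTS_CACHE_SUFFIX = "_sc_latents.npz"  # assumed; the real value is sc_utils.LATENTS_CACHE_SUFFIX

def latents_npz_path(image_path: str) -> str:
    # same derivation as in the script: strip the extension, append the suffix
    return os.path.splitext(image_path)[0] + LATENTS_CACHE_SUFFIX

def is_cache_usable(npz_path: str, flip_aug: bool) -> bool:
    # hypothetical stand-in for sc_utils.is_disk_cached_latents_is_expected:
    # reuse the cache only if the file exists and, when flip_aug is on, it
    # also holds a flipped copy ("latents_flipped" is an assumed key name)
    if not os.path.exists(npz_path):
        return False
    with np.load(npz_path) as npz:
        if "latents" not in npz:
            return False
        if flip_aug and "latents_flipped" not in npz:
            return False
    return True

# example: is_cache_usable(latents_npz_path("data/0001.png"), flip_aug=True)

This mirrors why --skip_existing must check for the flipped copy as well: a cache written without flip_aug is incomplete once flip_aug is turned on.
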
current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + + # acceleratorを準備する + logger.info("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, _ = train_util.prepare_dtype(args) + + # モデルを読み込む + logger.info("load model") + text_encoder = sc_utils.load_clip_text_model( + args.text_model_checkpoint_path, weight_dtype, accelerator.device, args.save_text_model + ) + text_encoders = [text_encoder] + for text_encoder in text_encoders: + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + + # dataloaderを準備する + train_dataset_group.set_caching_mode("text") + + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collator, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず + train_dataloader = accelerator.prepare(train_dataloader) + + # データ取得のためのループ + for batch in tqdm(train_dataloader): + absolute_paths = batch["absolute_paths"] + input_ids1_list = batch["input_ids1_list"] + + image_infos = [] + for absolute_path, input_ids1 in zip(absolute_paths, input_ids1_list): + image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path) + image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX + image_info + + if args.skip_existing: + if os.path.exists(image_info.text_encoder_outputs_npz): + logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.") + continue + + image_info.input_ids1 = input_ids1 + image_infos.append(image_info) + + if len(image_infos) > 0: + b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos]) + sc_utils.cache_batch_text_encoder_outputs( + image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, weight_dtype + ) + + accelerator.wait_for_everyone() + accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_tokenizer_arguments(parser) + sc_utils.add_text_model_arguments(parser) + train_util.add_training_arguments(parser, True) + train_util.add_dataset_arguments(parser, True, True, True) + config_util.add_config_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + parser.add_argument( + "--skip_existing", + action="store_true", + help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)", + ) + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + cache_to_disk(args)