update stable cascade stage C training #1119

This commit is contained in:
Kohya S
2024-02-18 17:54:21 +09:00
parent 856df07f49
commit 4b5784eb44
9 changed files with 869 additions and 454 deletions


@@ -1,3 +1,88 @@
# Training Stable Cascade Stage C
This is an experimental feature. There may be bugs.
## Usage
Training is run with `stable_cascade_train_stage_c.py`.
The main options are the same as `sdxl_train.py`. The following options have been added.
- `--effnet_checkpoint_path`: Specifies the path to the EfficientNetEncoder weights.
- `--stage_c_checkpoint_path`: Specifies the path to the Stage C weights.
- `--text_model_checkpoint_path`: Specifies the path to the Text Encoder weights. If omitted, the model from Hugging Face will be used.
- `--save_text_model`: Saves the model downloaded from Hugging Face to `--text_model_checkpoint_path`.
- `--previewer_checkpoint_path`: Specifies the path to the Previewer weights. Used to generate sample images during training.
- `--adaptive_loss_weight`: Uses [Adaptive Loss Weight](https://github.com/Stability-AI/StableCascade/blob/master/gdf/loss_weights.py). If omitted, P2LossWeight is used. The official settings use Adaptive Loss Weight.
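For context, the two weighting schemes can be sketched roughly as follows. This is a simplified sketch, not the implementation in the linked `gdf/loss_weights.py`; the bucket count, EMA decay, and logSNR range here are illustrative assumptions:

```python
def p2_loss_weight(snr: float, k: float = 1.0, gamma: float = 1.0) -> float:
    """P2 ("perception prioritized") style weight: down-weights easy, high-SNR steps."""
    return (k + snr) ** -gamma


class AdaptiveLossWeight:
    """Adaptive weighting sketch: track a bucketed EMA of recent losses over
    logSNR and weight each sample by the inverse of its bucket's EMA."""

    def __init__(self, n_buckets: int = 300, logsnr_range=(-10.0, 10.0), decay: float = 0.999):
        self.lo, self.hi = logsnr_range
        self.decay = decay
        self.bucket_ema = [1.0] * n_buckets

    def _bucket(self, logsnr: float) -> int:
        t = (logsnr - self.lo) / (self.hi - self.lo)
        return min(len(self.bucket_ema) - 1, max(0, int(t * len(self.bucket_ema))))

    def update(self, logsnr: float, loss: float) -> None:
        b = self._bucket(logsnr)
        self.bucket_ema[b] = self.decay * self.bucket_ema[b] + (1.0 - self.decay) * loss

    def weight(self, logsnr: float) -> float:
        # normalize the loss by its recent magnitude at this noise level
        return 1.0 / max(self.bucket_ema[self._bucket(logsnr)], 1e-8)
```

P2 weighting is a fixed function of SNR, while the adaptive variant normalizes the loss across noise levels by tracking its recent magnitude per logSNR bucket.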
The learning rate is set to 1e-4 in the official settings.
On the first run, specify `--text_model_checkpoint_path` and `--save_text_model` to save the Text Encoder weights. On subsequent runs, specify `--text_model_checkpoint_path` to load the saved weights.
Sample image generation during training is done with the Previewer. The Previewer is a simple decoder that converts EfficientNetEncoder latents to images.
Some SDXL options are simply ignored or cause an error (especially noise-related options such as `--noise_offset`). `--vae_batch_size` and `--no_half_vae` are applied directly to the EfficientNetEncoder (when `bf16` is specified for mixed precision, `--no_half_vae` is not necessary).
Options for latents and Text Encoder output caches can be used as is, but since the EfficientNetEncoder is much lighter than the VAE, you may not need to use the cache unless memory is particularly tight.
The memory-saving options `--gradient_checkpointing`, `--full_bf16`, and `--full_fp16` (untested) can be used as is.
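`--gradient_checkpointing` works through the per-block pattern this commit adds to the Stage C model code: each block gets a `gradient_checkpointing` flag and a `forward_body` that is re-run under `torch.utils.checkpoint` during training. A minimal sketch of that pattern with a toy block; the `nn.Linear` body and the `use_reentrant=False` choice are illustrative, not taken from this PR:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class BlockWithCheckpointing(nn.Module):
    """Toy block using the flag + forward_body checkpointing pattern."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.net = nn.Linear(dim, dim)
        self.gradient_checkpointing = False

    def set_gradient_checkpointing(self, value: bool):
        self.gradient_checkpointing = value

    def forward_body(self, x):
        return x + self.net(x)

    def forward(self, x):
        if self.training and self.gradient_checkpointing:
            # activations are recomputed in backward instead of stored
            return torch.utils.checkpoint.checkpoint(self.forward_body, x, use_reentrant=False)
        return self.forward_body(x)
```

The model-level `set_gradient_checkpointing` added in this commit simply walks the blocks and calls this setter on each one that has it.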
A scale of about 4 is suitable for sample image generation.
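The `Scale` here is the classifier-free guidance scale used when sampling. As a reminder of what that knob does, a minimal sketch; the function name and list-of-floats representation are illustrative, not from this code:

```python
def cfg_combine(uncond, cond, scale: float = 4.0):
    # classifier-free guidance: push the prediction away from the
    # unconditional branch by `scale` times the conditional offset
    return [u + scale * (c - u) for u, c in zip(uncond, cond)]
```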
Since the official settings use `bf16` for training, training with `fp16` may be unstable.
The code for training the Text Encoder is also written, but it is untested.
### About the dataset for fine tuning
If latents cache files for SD/SDXL (extension `*.npz`) exist, they will be read and cause an error during training. Please move them to another location in advance.
After that, run `finetune/prepare_buckets_latents.py` with the `--stable_cascade` option to create latents cache files for Stable Cascade (suffix `_sc_latents.npz` is added).
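The resulting cache naming can be sketched as follows. The suffix constant mirrors `train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX` added in this commit; the paths and function name are illustrative:

```python
import os

STABLE_CASCADE_LATENTS_CACHE_SUFFIX = "_sc_latents.npz"

def latents_cache_path(image_path: str, stable_cascade: bool) -> str:
    # SD/SDXL caches keep the plain ".npz" extension; Stable Cascade caches
    # get the "_sc_latents.npz" suffix so the two never collide
    base = os.path.splitext(image_path)[0]
    return base + (STABLE_CASCADE_LATENTS_CACHE_SUFFIX if stable_cascade else ".npz")

print(latents_cache_path("train/img001.png", True))  # train/img001_sc_latents.npz
```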
---
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
This repository contains training, generation and utility scripts for Stable Diffusion.


@@ -11,15 +11,19 @@ import cv2
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from torchvision import transforms
import library.model_util as model_util
import library.stable_cascade_utils as sc_utils
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
DEVICE = get_preferred_device()
@@ -42,7 +46,7 @@ def collate_fn_remove_corrupted(batch):
return batch
def get_npz_filename(data_dir, image_key, is_full_path, recursive):
def get_npz_filename(data_dir, image_key, is_full_path, recursive, stable_cascade):
if is_full_path:
base_name = os.path.splitext(os.path.basename(image_key))[0]
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
@@ -50,10 +54,11 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive):
base_name = image_key
relative_path = ""
ext = ".npz" if not stable_cascade else train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX
if recursive and relative_path:
return os.path.join(data_dir, relative_path, base_name) + ".npz"
return os.path.join(data_dir, relative_path, base_name) + ext
else:
return os.path.join(data_dir, base_name) + ".npz"
return os.path.join(data_dir, base_name) + ext
def main(args):
@@ -83,13 +88,20 @@ def main(args):
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
if not args.stable_cascade:
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
divisor = 8
else:
vae = sc_utils.load_effnet(args.model_name_or_path, DEVICE)
divisor = 32
vae.eval()
vae.to(DEVICE, dtype=weight_dtype)
# bucketのサイズを計算する
max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
assert (
len(max_reso) == 2
), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
bucket_manager = train_util.BucketManager(
args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
@@ -154,6 +166,10 @@ def main(args):
# メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
metadata[image_key]["train_resolution"] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
# 追加情報を記録
metadata[image_key]["original_size"] = (image.width, image.height)
metadata[image_key]["train_resized_size"] = resized_size
if not args.bucket_no_upscale:
# upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
assert (
@@ -168,9 +184,9 @@ def main(args):
), f"internal error resized size is small: {resized_size}, {reso}"
# 既に存在するファイルがあればshape等を確認して同じならskipする
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive)
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive, args.stable_cascade)
if args.skip_existing:
if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug):
if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug, divisor):
continue
# バッチへ追加
@@ -208,7 +224,14 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
parser.add_argument(
"--stable_cascade",
action="store_true",
help="prepare EffNet latents for stable cascade / stable cascade用のEffNetのlatentsを準備する",
)
parser.add_argument(
"--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
@@ -231,10 +254,16 @@ def setup_parser() -> argparse.ArgumentParser:
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
)
parser.add_argument(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
"--bucket_no_upscale",
action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
)
parser.add_argument(
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help="use mixed precision / 混合精度を使う場合、その精度",
)
parser.add_argument(
"--full_path",
@@ -242,7 +271,9 @@ def setup_parser() -> argparse.ArgumentParser:
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
"--flip_aug",
action="store_true",
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する",
)
parser.add_argument(
"--skip_existing",


@@ -3,10 +3,12 @@
# https://github.com/Stability-AI/StableCascade
import math
from types import SimpleNamespace
from typing import List, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
import torchvision
@@ -125,6 +127,23 @@ class EfficientNetEncoder(nn.Module):
def forward(self, x):
return self.mapper(self.backbone(x))
@property
def dtype(self) -> torch.dtype:
return next(self.parameters()).dtype
@property
def device(self) -> torch.device:
return next(self.parameters()).device
def encode(self, x):
"""
VAE と同じように使えるようにするためのメソッド。正しくはちゃんと呼び出し側で分けるべきだが、暫定的な対応。
The method to make it usable like VAE. It should be separated properly, but it is a temporary response.
"""
# latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
x = self(x)
return SimpleNamespace(latent_dist=SimpleNamespace(sample=lambda: x))
# なんかわりと乱暴な実装(;'∀')
# 一から学習することもないだろうから、無効化しておく
@@ -136,6 +155,7 @@ class EfficientNetEncoder(nn.Module):
# class Conv2d(torch.nn.Conv2d):
# def reset_parameters(self):
# return None
from torch.nn import Conv2d
from torch.nn import Linear
@@ -187,7 +207,12 @@ class ResBlock(nn.Module):
Linear(c + c_skip, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), Linear(c * 4, c)
)
def forward(self, x, x_skip=None):
self.gradient_checkpointing = False
def set_gradient_checkpointing(self, value):
self.gradient_checkpointing = value
def forward_body(self, x, x_skip=None):
x_res = x
x = self.norm(self.depthwise(x))
if x_skip is not None:
@@ -195,6 +220,22 @@ class ResBlock(nn.Module):
x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x + x_res
def forward(self, x, x_skip=None):
if self.training and self.gradient_checkpointing:
# logger.info("ResnetBlock2D: gradient_checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
return func(*inputs)
return custom_forward
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, x_skip)
else:
x = self.forward_body(x, x_skip)
return x
class AttnBlock(nn.Module):
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
@@ -204,11 +245,32 @@ class AttnBlock(nn.Module):
self.attention = Attention2D(c, nhead, dropout)
self.kv_mapper = nn.Sequential(nn.SiLU(), Linear(c_cond, c))
def forward(self, x, kv):
self.gradient_checkpointing = False
def set_gradient_checkpointing(self, value):
self.gradient_checkpointing = value
def forward_body(self, x, kv):
kv = self.kv_mapper(kv)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
return x
def forward(self, x, kv):
if self.training and self.gradient_checkpointing:
# logger.info("AttnBlock: gradient_checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
return func(*inputs)
return custom_forward
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, kv)
else:
x = self.forward_body(x, kv)
return x
class FeedForwardBlock(nn.Module):
def __init__(self, c, dropout=0.0):
@@ -218,10 +280,31 @@ class FeedForwardBlock(nn.Module):
Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), Linear(c * 4, c)
)
def forward(self, x):
self.gradient_checkpointing = False
def set_gradient_checkpointing(self, value):
self.gradient_checkpointing = value
def forward_body(self, x):
x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
def forward(self, x):
if self.training and self.gradient_checkpointing:
# logger.info("FeedForwardBlock: gradient_checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
return func(*inputs)
return custom_forward
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x)
else:
x = self.forward_body(x)
return x
class TimestepBlock(nn.Module):
def __init__(self, c, c_timestep, conds=["sca"]):
@@ -250,9 +333,38 @@ class UpDownBlock2d(nn.Module):
mapping = nn.Conv2d(c_in, c_out, kernel_size=1)
self.blocks = nn.ModuleList([interpolation, mapping] if mode == "up" else [mapping, interpolation])
self.mode = mode
self.gradient_checkpointing = False
def set_gradient_checkpointing(self, value):
self.gradient_checkpointing = value
def forward_body(self, x):
org_dtype = x.dtype
for i, block in enumerate(self.blocks):
# 公式の実装では、常に float で計算しているが、すこしでもメモリを節約するために bfloat16 + Upsample のみ float に変換する
# In the official implementation, it always calculates in float, but for the sake of saving memory, it converts to float only for bfloat16 + Upsample
if x.dtype == torch.bfloat16 and (self.mode == "up" and i == 0 or self.mode != "up" and i == 1):
x = x.float()
x = block(x)
x = x.to(org_dtype)
return x
def forward(self, x):
for block in self.blocks:
x = block(x.float())
if self.training and self.gradient_checkpointing:
# logger.info("UpDownBlock2d: gradient_checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
return func(*inputs)
return custom_forward
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x)
else:
x = self.forward_body(x)
return x
@@ -790,6 +902,12 @@ class StageC(nn.Module):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def set_gradient_checkpointing(self, value):
for block in self.down_blocks + self.up_blocks:
for layer in block:
if hasattr(layer, "set_gradient_checkpointing"):
layer.set_gradient_checkpointing(value)
def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2
@@ -900,8 +1018,62 @@ class StageC(nn.Module):
for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
class Previewer(nn.Module):
def __init__(self, c_in=16, c_hidden=512, c_out=3):
super().__init__()
self.blocks = nn.Sequential(
nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
nn.GELU(),
nn.BatchNorm2d(c_hidden),
nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden),
nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
nn.GELU(),
nn.BatchNorm2d(c_hidden // 2),
nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 2),
nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
)
def forward(self, x):
return self.blocks(x)
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
def get_clip_conditions(captions: Optional[List[str]], input_ids, tokenizer, text_model):
# deprecated
# self, batch: dict, tokenizer, text_model, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None
# is_eval の処理をここでやるのは微妙なので別のところでやる
# is_unconditional もここでやるのは微妙なので別のところでやる
@@ -921,39 +1093,6 @@ def get_clip_conditions(captions: Optional[List[str]], input_ids, tokenizer, tex
# return {"clip_text": text_embeddings, "clip_text_pooled": text_pooled_embeddings} # , "clip_img": image_embeddings}
def get_stage_c_conditions(
batch: dict, effnet, effnet_preprocess, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None
):
images = batch.get("images", None)
if images is not None:
images = images.to(self.device)
if is_eval and not is_unconditional:
effnet_embeddings = effnet(effnet_preprocess(images))
else:
if is_eval:
effnet_factor = 1
else:
effnet_factor = np.random.uniform(0.5, 1) # f64 to f32
effnet_height, effnet_width = int(((images.size(-2) * effnet_factor) // 32) * 32), int(
((images.size(-1) * effnet_factor) // 32) * 32
)
effnet_embeddings = torch.zeros(images.size(0), 16, effnet_height // 32, effnet_width // 32, device=self.device)
if not is_eval:
effnet_images = torchvision.transforms.functional.resize(
images, (effnet_height, effnet_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
rand_idx = np.random.rand(len(images)) <= 0.9
if any(rand_idx):
effnet_embeddings[rand_idx] = effnet(effnet_preprocess(effnet_images[rand_idx]))
else:
effnet_embeddings = None
return effnet_embeddings
# {"effnet": effnet_embeddings, "clip": conditions["clip_text_pooled"]}
# region gdf


@@ -1,28 +1,30 @@
import argparse
import json
import math
import os
import time
from typing import List
import numpy as np
import toml
import torch
import torchvision
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
from accelerate import init_empty_weights
from accelerate import init_empty_weights, Accelerator, PartialState
from PIL import Image
from library import stable_cascade as sc
from library.train_util import (
ImageInfo,
load_image,
trim_and_resize_if_required,
save_latents_to_disk,
HIGH_VRAM,
save_text_encoder_outputs_to_disk,
)
from library.sdxl_model_util import _load_state_dict_on_device
from library.device_utils import clean_memory_on_device
from library.train_util import save_sd_model_on_epoch_end_or_stepwise_common, save_sd_model_on_train_end_common
from library.train_util import (
save_sd_model_on_epoch_end_or_stepwise_common,
save_sd_model_on_train_end_common,
line_to_prompt_dict,
get_hidden_states_stable_cascade,
)
from library import sai_model_spec
@@ -41,7 +43,22 @@ EFFNET_PREPROCESS = torchvision.transforms.Compose(
)
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_sc_te_outputs.npz"
LATENTS_CACHE_SUFFIX = "_sc_latents.npz"
def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
resolution_multiple = 42.67
latent_height = math.ceil(height / compression_factor_b)
latent_width = math.ceil(width / compression_factor_b)
stage_c_latent_shape = (batch_size, 16, latent_height, latent_width)
latent_height = math.ceil(height / compression_factor_a)
latent_width = math.ceil(width / compression_factor_a)
stage_b_latent_shape = (batch_size, 4, latent_height, latent_width)
return stage_c_latent_shape, stage_b_latent_shape
# region load and save
def load_effnet(effnet_checkpoint_path, loading_device="cpu") -> sc.EfficientNetEncoder:
@@ -165,153 +182,15 @@ def load_stage_a_model(stage_a_checkpoint_path, dtype=None, device="cpu") -> sc.
return stage_a
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
expected_latents_size = (reso[1] // 32, reso[0] // 32) # bucket_resoはWxHなので注意
if not os.path.exists(npz_path):
return False
npz = np.load(npz_path)
if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver?
return False
if npz["latents"].shape[1:3] != expected_latents_size:
return False
if flip_aug:
if "latents_flipped" not in npz:
return False
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
return False
return True
def cache_batch_latents(
effnet: sc.EfficientNetEncoder,
cache_to_disk: bool,
image_infos: List[ImageInfo],
flip_aug: bool,
random_crop: bool,
device,
dtype,
) -> None:
r"""
requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
optionally requires image_infos to have: image
if cache_to_disk is True, set info.latents_npz
flipped latents is also saved if flip_aug is True
if cache_to_disk is False, set info.latents
latents_flipped is also set if flip_aug is True
latents_original_size and latents_crop_ltrb are also set
"""
images = []
for info in image_infos:
image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
image = EFFNET_PREPROCESS(image)
images.append(image)
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=device, dtype=dtype)
with torch.no_grad():
latents = effnet(img_tensors).to("cpu")
print(latents.shape)
if flip_aug:
img_tensors = torch.flip(img_tensors, dims=[3])
with torch.no_grad():
flipped_latents = effnet(img_tensors).to("cpu")
else:
flipped_latents = [None] * len(latents)
for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
# check NaN
if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
if cache_to_disk:
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
else:
info.latents = latent
if flip_aug:
info.latents_flipped = flipped_latent
if not HIGH_VRAM:
clean_memory_on_device(device)
def cache_batch_text_encoder_outputs(image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids, dtype):
# 75 トークン越えは未対応
input_ids = input_ids.to(text_encoders[0].device)
with torch.no_grad():
b_hidden_state, b_pool = sc.get_clip_conditions(None, input_ids, tokenizers[0], text_encoders[0])
b_hidden_state = b_hidden_state.detach().to("cpu") # b,n*75+2,768
b_pool = b_pool.detach().to("cpu") # b,1280
for info, hidden_state, pool in zip(image_infos, b_hidden_state, b_pool):
if cache_to_disk:
save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, None, hidden_state, pool)
else:
info.text_encoder_outputs1 = hidden_state
info.text_encoder_pool2 = pool
def add_effnet_arguments(parser):
parser.add_argument(
"--effnet_checkpoint_path",
type=str,
required=True,
help="path to EfficientNet checkpoint / EfficientNetのチェックポイントのパス",
def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") -> sc.Previewer:
logger.info(f"Loading Previewer from {previewer_checkpoint_path}")
previewer = sc.Previewer().to(device)
previewer_checkpoint = load_file(previewer_checkpoint_path)
info = previewer.load_state_dict(
previewer_checkpoint if "state_dict" not in previewer_checkpoint else previewer_checkpoint["state_dict"]
)
return parser
def add_text_model_arguments(parser):
parser.add_argument(
"--text_model_checkpoint_path",
type=str,
required=True,
help="path to CLIP text model checkpoint / CLIPテキストモデルのチェックポイントのパス",
)
parser.add_argument("--save_text_model", action="store_true", help="if specified, save text model to corresponding path")
return parser
def add_stage_a_arguments(parser):
parser.add_argument(
"--stage_a_checkpoint_path",
type=str,
required=True,
help="path to Stage A checkpoint / Stage Aのチェックポイントのパス",
)
return parser
def add_stage_b_arguments(parser):
parser.add_argument(
"--stage_b_checkpoint_path",
type=str,
required=True,
help="path to Stage B checkpoint / Stage Bのチェックポイントのパス",
)
return parser
def add_stage_c_arguments(parser):
parser.add_argument(
"--stage_c_checkpoint_path",
type=str,
required=True,
help="path to Stage C checkpoint / Stage Cのチェックポイントのパス",
)
return parser
logger.info(info)
return previewer
def get_sai_model_spec(args):
@@ -353,7 +232,7 @@ def get_sai_model_spec(args):
def stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata):
state_dict = stage_c.state_dict()
if save_dtype is not None:
state_dict = {k: v.to(save_dtype) for k, v in state_dict.items}
state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
save_file(state_dict, ckpt_file, metadata=sai_metadata)
@@ -403,112 +282,334 @@ def save_stage_c_model_on_end(
save_sd_model_on_train_end_common(args, True, True, epoch, global_step, stage_c_saver, None)
def cache_latents(self, effnet, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
logger.info("caching latents.")
# endregion
image_infos = list(self.image_data.values())
# region sample generation
# sort by resolution
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
# split by resolution
batches = []
batch = []
logger.info("checking cache validity...")
for info in tqdm(image_infos):
subset = self.image_to_subset[info.image_key]
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
previewer,
tokenizer,
text_encoder,
stage_c,
gdf,
prompt_replacement=None,
):
if steps == 0:
if not args.sample_at_first:
return
else:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps は無視する
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
if info.latents_npz is not None: # fine tuning dataset
continue
# check disk cache exists and size of latents
if cache_to_disk:
info.latents_npz = os.path.splitext(info.absolute_path)[0] + LATENTS_CACHE_SUFFIX
if not is_main_process: # store to info only
continue
cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug)
if cache_available: # do not add to batch
continue
# if last member of batch has different resolution, flush the batch
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
batches.append(batch)
batch = []
batch.append(info)
# if number of data in batch is enough, flush the batch
if len(batch) >= vae_batch_size:
batches.append(batch)
batch = []
if len(batch) > 0:
batches.append(batch)
if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
# iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
logger.info("caching latents...")
for batch in tqdm(batches, smoothing=1, total=len(batches)):
cache_batch_latents(effnet, cache_to_disk, batch, subset.flip_aug, subset.random_crop)
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
# unwrap unet and text_encoder(s)
stage_c = accelerator.unwrap_model(stage_c)
text_encoder = accelerator.unwrap_model(text_encoder)
# read prompts
if args.sample_prompts.endswith(".txt"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif args.sample_prompts.endswith(".toml"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
data = toml.load(f)
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
elif args.sample_prompts.endswith(".json"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
# preprocess prompts
for i in range(len(prompts)):
prompt_dict = prompts[i]
if isinstance(prompt_dict, str):
prompt_dict = line_to_prompt_dict(prompt_dict)
prompts[i] = prompt_dict
assert isinstance(prompt_dict, dict)
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
prompt_dict["enum"] = i
prompt_dict.pop("subset", None)
# save random state to restore later
rng_state = torch.get_rng_state()
cuda_rng_state = None
try:
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
args,
tokenizer,
text_encoder,
stage_c,
previewer,
gdf,
save_dir,
prompt_dict,
epoch,
steps,
prompt_replacement,
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [] # list of lists
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator,
args,
tokenizer,
text_encoder,
stage_c,
previewer,
gdf,
save_dir,
prompt_dict,
epoch,
steps,
prompt_replacement,
)
# I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here.
# with torch.cuda.device(torch.cuda.current_device()):
# torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
# If weight_dtype is specified, the Text Encoder itself and its outputs are cast to weight_dtype
def cache_text_encoder_outputs(self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True):
    # As with the latents cache, caching to disk is supported
    # Multi-GPU is not supported; use tools/cache_latents.py for that
    logger.info("caching text encoder outputs.")
    image_infos = list(self.image_data.values())

    logger.info("checking cache existence...")
    image_infos_to_cache = []
    for info in tqdm(image_infos):
        # subset = self.image_to_subset[info.image_key]
        if cache_to_disk:
            te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
            info.text_encoder_outputs_npz = te_out_npz

            if not is_main_process:  # store to info only
                continue

            if os.path.exists(te_out_npz):
                continue

        image_infos_to_cache.append(info)

    if cache_to_disk and not is_main_process:  # when caching to disk, non-main processes only store the path to info
        return

    # prepare tokenizers and text encoders
    for text_encoder in text_encoders:
        text_encoder.to(device)
        if weight_dtype is not None:
            text_encoder.to(dtype=weight_dtype)

    # create batches
    batch = []
    batches = []
    for info in image_infos_to_cache:
        input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
        batch.append((info, input_ids1, None))

        if len(batch) >= self.batch_size:
            batches.append(batch)
            batch = []

    if len(batch) > 0:
        batches.append(batch)

    # iterate batches: call text encoder and cache outputs to memory or disk
    logger.info("caching text encoder outputs...")
    for batch in tqdm(batches):
        infos, input_ids1, input_ids2 = zip(*batch)
        input_ids1 = torch.stack(input_ids1, dim=0)
        input_ids2 = torch.stack(input_ids2, dim=0) if input_ids2[0] is not None else None
        cache_batch_text_encoder_outputs(
            infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype
        )


def sample_image_inference(
    accelerator: Accelerator,
    args: argparse.Namespace,
    tokenizer,
    text_model,
    stage_c,
    previewer,
    gdf,
    save_dir,
    prompt_dict,
    epoch,
    steps,
    prompt_replacement,
):
    assert isinstance(prompt_dict, dict)
    negative_prompt = prompt_dict.get("negative_prompt")
    sample_steps = prompt_dict.get("sample_steps", 20)
    width = prompt_dict.get("width", 1024)
    height = prompt_dict.get("height", 1024)
    scale = prompt_dict.get("scale", 4)
    seed = prompt_dict.get("seed")
    # controlnet_image = prompt_dict.get("controlnet_image")
    prompt: str = prompt_dict.get("prompt", "")
    # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)

    if prompt_replacement is not None:
        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
        if negative_prompt is not None:
            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
    else:
        # true random sample image generation
        torch.seed()
        torch.cuda.seed()

    height = max(64, height - height % 8)  # round to divisible by 8
    width = max(64, width - width % 8)  # round to divisible by 8
    logger.info(f"prompt: {prompt}")
    logger.info(f"negative_prompt: {negative_prompt}")
    logger.info(f"height: {height}")
    logger.info(f"width: {width}")
    logger.info(f"sample_steps: {sample_steps}")
    logger.info(f"scale: {scale}")
    # logger.info(f"sample_sampler: {sampler_name}")
    if seed is not None:
        logger.info(f"seed: {seed}")

    negative_prompt = "" if negative_prompt is None else negative_prompt

    cfg = scale
    timesteps = sample_steps
    shift = 2
    t_start = 1.0

    stage_c_latent_shape, _ = calculate_latent_sizes(height, width, batch_size=1)

    # PREPARE CONDITIONS
    input_ids = tokenizer(
        [prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
    )["input_ids"].to(text_model.device)
    cond_text, cond_pooled = get_hidden_states_stable_cascade(tokenizer.model_max_length, input_ids, tokenizer, text_model)

    input_ids = tokenizer(
        [negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
    )["input_ids"].to(text_model.device)
    uncond_text, uncond_pooled = get_hidden_states_stable_cascade(tokenizer.model_max_length, input_ids, tokenizer, text_model)

    device = accelerator.device
    dtype = stage_c.dtype
    cond_text = cond_text.to(device, dtype=dtype)
    cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)
    uncond_text = uncond_text.to(device, dtype=dtype)
    uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype)

    zero_img_emb = torch.zeros(1, 768, device=device)

    # A dict is not ideal, but changing GDF and everything downstream is a hassle, so use a dict for now
    conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
    unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": zero_img_emb}

    with torch.no_grad():  # , torch.cuda.amp.autocast(dtype=dtype):
        sampling_c = gdf.sample(
            stage_c,
            conditions,
            stage_c_latent_shape,
            unconditions,
            device=device,
            cfg=cfg,
            shift=shift,
            timesteps=timesteps,
            t_start=t_start,
        )
        # iterate the sampler to completion; keep the final latents
        for sampled_c, _, _ in tqdm(sampling_c, total=timesteps):
            pass

        sampled_c = sampled_c.to(previewer.device, dtype=previewer.dtype)
        image = previewer(sampled_c)[0]

    image = torch.clamp(image, 0, 1)
    image = image.cpu().numpy().transpose(1, 2, 0)
    image = image * 255
    image = image.astype(np.uint8)
    image = Image.fromarray(image)

    # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
    # but adding 'enum' to the filename should be enough
    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
    seed_suffix = "" if seed is None else f"_{seed}"
    i: int = prompt_dict["enum"]
    img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
    image.save(os.path.join(save_dir, img_filename))

    # send logs only when wandb is enabled
    try:
        wandb_tracker = accelerator.get_tracker("wandb")

        try:
            import wandb
        except ImportError:  # checked once beforehand, so this should not raise here
            raise ImportError("No wandb / wandb がインストールされていないようです")

        wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
    except:  # wandb disabled
        pass
# endregion
def add_effnet_arguments(parser):
parser.add_argument(
"--effnet_checkpoint_path",
type=str,
required=True,
help="path to EfficientNet checkpoint / EfficientNetのチェックポイントのパス",
)
return parser
def add_text_model_arguments(parser):
parser.add_argument(
"--text_model_checkpoint_path",
type=str,
help="path to CLIP text model checkpoint / CLIPテキストモデルのチェックポイントのパス",
)
parser.add_argument("--save_text_model", action="store_true", help="if specified, save text model to corresponding path")
return parser
def add_stage_a_arguments(parser):
parser.add_argument(
"--stage_a_checkpoint_path",
type=str,
required=True,
help="path to Stage A checkpoint / Stage Aのチェックポイントのパス",
)
return parser
def add_stage_b_arguments(parser):
parser.add_argument(
"--stage_b_checkpoint_path",
type=str,
required=True,
help="path to Stage B checkpoint / Stage Bのチェックポイントのパス",
)
return parser
def add_stage_c_arguments(parser):
parser.add_argument(
"--stage_c_checkpoint_path",
type=str,
required=True,
help="path to Stage C checkpoint / Stage Cのチェックポイントのパス",
)
return parser
def add_previewer_arguments(parser):
parser.add_argument(
"--previewer_checkpoint_path",
type=str,
required=False,
help="path to previewer checkpoint / previewerのチェックポイントのパス",
)
return parser
def add_training_arguments(parser):
parser.add_argument(
"--adaptive_loss_weight",
action="store_true",
help="if specified, use adaptive loss weight. if not, use P2 loss weight"
+ " / Adaptive Loss Weightを使用する。指定しない場合はP2 Loss Weightを使用する",
)
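For reference, a minimal sketch of how these `add_*_arguments` helpers compose into a single parser. The helpers below are simplified redefinitions (help strings omitted), and the particular combination shown is illustrative:

```python
import argparse

# Simplified redefinitions of two of the helpers above; each mutates the
# parser in place and returns it, so calls can be chained or sequenced.
def add_effnet_arguments(parser):
    parser.add_argument("--effnet_checkpoint_path", type=str, required=True)
    return parser

def add_training_arguments(parser):
    parser.add_argument("--adaptive_loss_weight", action="store_true")
    return parser

parser = argparse.ArgumentParser()
add_effnet_arguments(parser)
add_training_arguments(parser)

# parse a hypothetical command line
args = parser.parse_args(["--effnet_checkpoint_path", "effnet.safetensors", "--adaptive_loss_weight"])
print(args.effnet_checkpoint_path)  # effnet.safetensors
print(args.adaptive_loss_weight)    # True
```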


@@ -133,6 +133,7 @@ IMAGE_TRANSFORMS = transforms.Compose(
)
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
STABLE_CASCADE_LATENTS_CACHE_SUFFIX = "_sc_latents.npz"
class ImageInfo:
@@ -856,7 +857,7 @@ class BaseDataset(torch.utils.data.Dataset):
logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
self.buckets_indices: List(BucketBatchIndex) = []
self.buckets_indices: List[BucketBatchIndex] = []
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
batch_count = int(math.ceil(len(bucket) / self.batch_size))
for batch_index in range(batch_count):
@@ -910,8 +911,8 @@ class BaseDataset(torch.utils.data.Dataset):
]
)
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
def cache_latents(self, vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor):
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
logger.info("caching latents.")
image_infos = list(self.image_data.values())
@@ -931,11 +932,11 @@ class BaseDataset(torch.utils.data.Dataset):
# check disk cache exists and size of latents
if cache_to_disk:
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
info.latents_npz = os.path.splitext(info.absolute_path)[0] + cache_file_suffix
if not is_main_process: # store to info only
continue
cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug)
cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug, divisor)
if cache_available: # do not add to batch
continue
@@ -967,9 +968,13 @@ class BaseDataset(torch.utils.data.Dataset):
# SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する
# SD1/2に対応するにはv2のフラグを持つ必要があるので後回し
def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process, cache_file_suffix
):
assert len(tokenizers) == 2, "only support SDXL"
"""
最後の Text Encoder の pool がキャッシュされる。
The last Text Encoder's pool is cached.
"""
# assert len(tokenizers) == 2, "only support SDXL"
# latentsのキャッシュと同様に、ディスクへのキャッシュに対応する
# またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
@@ -981,7 +986,7 @@ class BaseDataset(torch.utils.data.Dataset):
for info in tqdm(image_infos):
# subset = self.image_to_subset[info.image_key]
if cache_to_disk:
te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
te_out_npz = os.path.splitext(info.absolute_path)[0] + cache_file_suffix
info.text_encoder_outputs_npz = te_out_npz
if not is_main_process: # store to info only
@@ -1006,7 +1011,7 @@ class BaseDataset(torch.utils.data.Dataset):
batches = []
for info in image_infos_to_cache:
input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) if len(tokenizers) > 1 else None
batch.append((info, input_ids1, input_ids2))
if len(batch) >= self.batch_size:
@@ -1021,7 +1026,7 @@ class BaseDataset(torch.utils.data.Dataset):
for batch in tqdm(batches):
infos, input_ids1, input_ids2 = zip(*batch)
input_ids1 = torch.stack(input_ids1, dim=0)
input_ids2 = torch.stack(input_ids2, dim=0)
input_ids2 = torch.stack(input_ids2, dim=0) if input_ids2[0] is not None else None
cache_batch_text_encoder_outputs(
infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype
)
@@ -1270,7 +1275,9 @@ class BaseDataset(torch.utils.data.Dataset):
# example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions])
# example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions])
example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
example["text_encoder_outputs2_list"] = (
torch.stack(text_encoder_outputs2_list) if text_encoder_outputs2_list[0] is not None else None
)
example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)
if images[0] is not None:
@@ -1599,12 +1606,15 @@ class FineTuningDataset(BaseDataset):
# なければnpzを探す
if abs_path is None:
if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if os.path.exists(npz_path):
abs_path = npz_path
abs_path = os.path.splitext(image_key)[0] + ".npz"
if not os.path.exists(abs_path):
abs_path = os.path.splitext(image_key)[0] + STABLE_CASCADE_LATENTS_CACHE_SUFFIX
if not os.path.exists(abs_path):
abs_path = os.path.join(subset.image_dir, image_key + ".npz")
if not os.path.exists(abs_path):
abs_path = os.path.join(subset.image_dir, image_key + STABLE_CASCADE_LATENTS_CACHE_SUFFIX)
if not os.path.exists(abs_path):
abs_path = None
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
@@ -1624,7 +1634,7 @@ class FineTuningDataset(BaseDataset):
if not subset.color_aug and not subset.random_crop:
# if npz exists, use them
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
image_info.latents_npz = self.image_key_to_npz_file(subset, image_key)
self.register_image(image_info, subset)
@@ -1638,7 +1648,7 @@ class FineTuningDataset(BaseDataset):
# check existence of all npz files
use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
if use_npz_latents:
flip_aug_in_subset = False
# flip_aug_in_subset = False
npz_any = False
npz_all = True
@@ -1648,9 +1658,12 @@ class FineTuningDataset(BaseDataset):
has_npz = image_info.latents_npz is not None
npz_any = npz_any or has_npz
if subset.flip_aug:
has_npz = has_npz and image_info.latents_npz_flipped is not None
flip_aug_in_subset = True
# flipped latents are now stored in the same .npz file:
# because of that, a check may be missing here and could fail at runtime, so this needs review
# if subset.flip_aug:
# has_npz = has_npz and image_info.latents_npz_flipped is not None
# flip_aug_in_subset = True
npz_all = npz_all and has_npz
if npz_any and not npz_all:
@@ -1664,8 +1677,8 @@ class FineTuningDataset(BaseDataset):
logger.warning(
f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します"
)
if flip_aug_in_subset:
logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
# if flip_aug_in_subset:
# logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
# else:
# logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
@@ -1714,34 +1727,29 @@ class FineTuningDataset(BaseDataset):
# npz情報をきれいにしておく
if not use_npz_latents:
for image_info in self.image_data.values():
image_info.latents_npz = image_info.latents_npz_flipped = None
image_info.latents_npz = None # image_info.latents_npz_flipped =
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
base_name = os.path.splitext(image_key)[0]
npz_file_norm = base_name + ".npz"
npz_file_norm = base_name + ".npz"
if not os.path.exists(npz_file_norm):
npz_file_norm = base_name + STABLE_CASCADE_LATENTS_CACHE_SUFFIX
if os.path.exists(npz_file_norm):
# image_key is full path
npz_file_flip = base_name + "_flip.npz"
if not os.path.exists(npz_file_flip):
npz_file_flip = None
return npz_file_norm, npz_file_flip
return npz_file_norm
# if not full path, check image_dir. if image_dir is None, return None
if subset.image_dir is None:
return None, None
return None
# image_key is relative path
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
if not os.path.exists(npz_file_norm):
npz_file_norm = None
npz_file_flip = None
elif not os.path.exists(npz_file_flip):
npz_file_flip = None
npz_file_norm = os.path.join(subset.image_dir, base_name + STABLE_CASCADE_LATENTS_CACHE_SUFFIX)
if os.path.exists(npz_file_norm):
return npz_file_norm
return npz_file_norm, npz_file_flip
return None
class ControlNetDataset(BaseDataset):
@@ -1943,17 +1951,26 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.enable_XTI(*args, **kwargs)
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, cache_file_suffix=".npz", divisor=8):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor)
def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
self,
tokenizers,
text_encoders,
device,
weight_dtype,
cache_to_disk=False,
is_main_process=True,
cache_file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)
dataset.cache_text_encoder_outputs(
tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process, cache_file_suffix
)
def set_caching_mode(self, caching_mode):
for dataset in self.datasets:
@@ -1986,8 +2003,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
dataset.disable_token_padding()
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, divisor: int = 8) -> bool:
expected_latents_size = (reso[1] // divisor, reso[0] // divisor) # bucket_resoはWxHなので注意
if not os.path.exists(npz_path):
return False
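The effect of the new `divisor` parameter can be checked in isolation. The helper below is a standalone reimplementation of the size computation's first line; 32 is the divisor the Stable Cascade path passes for the EfficientNetEncoder, 8 the SD VAE default:

```python
# bucket_reso is (W, H), so the expected latent size is (H // divisor, W // divisor).
def expected_latents_size(reso, divisor=8):
    return (reso[1] // divisor, reso[0] // divisor)

print(expected_latents_size((1024, 768), 8))   # (96, 128) -- SD VAE
print(expected_latents_size((1024, 768), 32))  # (24, 32)  -- EfficientNetEncoder
```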
@@ -2079,7 +2096,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
if show_input_ids:
logger.info(f"input ids: {iid}")
if "input_ids2" in example:
if "input_ids2" in example and example["input_ids2"] is not None:
logger.info(f"input ids2: {example['input_ids2'][j]}")
if example["images"] is not None:
im = example["images"][j]
@@ -2256,7 +2273,7 @@ def trim_and_resize_if_required(
def cache_batch_latents(
vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
vae: Union[AutoencoderKL, torch.nn.Module], cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
) -> None:
r"""
requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
@@ -2311,23 +2328,36 @@ def cache_batch_text_encoder_outputs(
image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
):
input_ids1 = input_ids1.to(text_encoders[0].device)
input_ids2 = input_ids2.to(text_encoders[1].device)
input_ids2 = input_ids2.to(text_encoders[1].device) if input_ids2 is not None else None
with torch.no_grad():
b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
max_token_length,
input_ids1,
input_ids2,
tokenizers[0],
tokenizers[1],
text_encoders[0],
text_encoders[1],
dtype,
)
# TODO unify the SDXL and Stable Cascade code paths
if len(tokenizers) == 1:
# Stable Cascade
b_hidden_state1, b_pool2 = get_hidden_states_stable_cascade(
max_token_length, input_ids1, tokenizers[0], text_encoders[0], dtype
)
b_hidden_state1 = b_hidden_state1.detach().to("cpu") # b,n*75+2,768
b_pool2 = b_pool2.detach().to("cpu") # b,1280
b_hidden_state2 = [None] * input_ids1.shape[0]
else:
# SDXL
b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
max_token_length,
input_ids1,
input_ids2,
tokenizers[0],
tokenizers[1],
text_encoders[0],
text_encoders[1],
dtype,
)
# must move to CPU here, otherwise the tensors get overwritten
b_hidden_state1 = b_hidden_state1.detach().to("cpu") # b,n*75+2,768
b_hidden_state2 = b_hidden_state2.detach().to("cpu") # b,n*75+2,1280
b_hidden_state2 = b_hidden_state2.detach().to("cpu") if b_hidden_state2[0] is not None else b_hidden_state2 # b,n*75+2,1280
b_pool2 = b_pool2.detach().to("cpu") # b,1280
for info, hidden_state1, hidden_state2, pool2 in zip(image_infos, b_hidden_state1, b_hidden_state2, b_pool2):
@@ -2340,18 +2370,25 @@ def cache_batch_text_encoder_outputs(
def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
np.savez(
npz_path,
hidden_state1=hidden_state1.cpu().float().numpy() if hidden_state1 is not None else None,
hidden_state2=hidden_state2.cpu().float().numpy(),
pool2=pool2.cpu().float().numpy(),
)
save_kwargs = {
"hidden_state1": hidden_state1.cpu().float().numpy(),
"pool2": pool2.cpu().float().numpy(),
}
if hidden_state2 is not None:
save_kwargs["hidden_state2"] = hidden_state2.cpu().float().numpy()
np.savez(npz_path, **save_kwargs)
# np.savez(
# npz_path,
# hidden_state1=hidden_state1.cpu().float().numpy(),
# hidden_state2=hidden_state2.cpu().float().numpy() if hidden_state2 is not None else None,
# pool2=pool2.cpu().float().numpy(),
# )
def load_text_encoder_outputs_from_disk(npz_path):
with np.load(npz_path) as f:
hidden_state1 = torch.from_numpy(f["hidden_state1"])
hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f else None
hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f and f["hidden_state2"] is not None else None
pool2 = torch.from_numpy(f["pool2"]) if "pool2" in f else None
return hidden_state1, hidden_state2, pool2
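A standalone round-trip of the optional-key `.npz` layout used by `save_text_encoder_outputs_to_disk` and `load_text_encoder_outputs_from_disk`. Shapes are illustrative; `hidden_state2` is omitted, as in the single-text-encoder Stable Cascade branch:

```python
import os
import tempfile

import numpy as np

# hidden_state2 is saved only for SDXL (two text encoders); for Stable Cascade
# the key is simply absent and the loader falls back to None.
hidden_state1 = np.zeros((77, 768), dtype=np.float32)
pool2 = np.zeros((1280,), dtype=np.float32)

path = os.path.join(tempfile.mkdtemp(), "caption_te_outputs.npz")
save_kwargs = {"hidden_state1": hidden_state1, "pool2": pool2}
# hidden_state2 intentionally omitted, as in the Stable Cascade branch
np.savez(path, **save_kwargs)

with np.load(path) as f:
    hs1_shape = f["hidden_state1"].shape
    hs2 = f["hidden_state2"] if "hidden_state2" in f else None

print(hs1_shape)    # (77, 768)
print(hs2 is None)  # True
```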
@@ -2706,6 +2743,7 @@ def add_tokenizer_arguments(parser: argparse.ArgumentParser):
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリネット接続なしでの学習のため",
)
def add_sd_models_arguments(parser: argparse.ArgumentParser):
# for pretrained models
parser.add_argument(
@@ -3207,7 +3245,7 @@ def verify_training_args(args: argparse.Namespace):
print("highvram is enabled / highvramが有効です")
global HIGH_VRAM
HIGH_VRAM = True
if args.cache_latents_to_disk and not args.cache_latents:
args.cache_latents = True
logger.warning(
@@ -3219,7 +3257,9 @@ def verify_training_args(args: argparse.Namespace):
return
if args.v_parameterization and not args.v2:
logger.warning("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません")
logger.warning(
"v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません"
)
if args.v2 and args.clip_skip is not None:
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
@@ -4302,6 +4342,54 @@ def get_hidden_states_sdxl(
return hidden_states1, hidden_states2, pool2
def get_hidden_states_stable_cascade(
    max_token_length: int,
    input_ids2: torch.Tensor,
    tokenizer2: CLIPTokenizer,
    text_encoder2: CLIPTextModel,
    weight_dtype: Optional[str] = None,
    accelerator: Optional[Accelerator] = None,
):
    # ここに Stable Cascade 用のコードがあるのはとても気持ち悪いが、変に整理するよりわかりやすいので、とりあえずこのまま
    # It's very awkward to have Stable Cascade code here, but it's easier to understand than to organize it in a strange way, so for now it's as it is.

    # input_ids: b,n,77 -> b*n, 77
    b_size = input_ids2.size()[0]
    input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))  # batch_size*n, 77

    # text_encoder2
    enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
    hidden_states2 = enc_out["hidden_states"][-1]  # ** last layer **

    # pool2 = enc_out["text_embeds"]
    unwrapped_text_encoder2 = text_encoder2 if accelerator is None else accelerator.unwrap_model(text_encoder2)
    pool2 = pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)

    # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
    n_size = 1 if max_token_length is None else max_token_length // 75
    hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))

    if max_token_length is not None:
        # bs*3, 77, 768 or 1024
        # v2: rejoin the n <BOS>...<EOS> <PAD> ... chunks into a single <BOS>...<EOS> <PAD> ... sequence
        # (honestly not sure whether this implementation is right)
        states_list = [hidden_states2[:, 0].unsqueeze(1)]  # <BOS>
        for i in range(1, max_token_length, tokenizer2.model_max_length):
            chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2]  # from after <BOS> to before the last token
            states_list.append(chunk)
        states_list.append(hidden_states2[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
        hidden_states2 = torch.cat(states_list, dim=1)

        # the pool uses the first of the n chunks
        pool2 = pool2[::n_size]

    if weight_dtype is not None:
        # this is required for additional network training
        hidden_states2 = hidden_states2.to(weight_dtype)

    return hidden_states2, pool2
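The chunk-reassembly indexing above can be sanity-checked with plain lists, assuming `model_max_length=77` and `max_token_length=225` (i.e. n=3):

```python
# Each 77-token chunk contributes its 75 content positions; one <BOS> and one
# final token are kept, giving n*75 + 2 positions (227 for n=3).
model_max_length = 77
max_token_length = 225
n = max_token_length // 75  # 3

positions = list(range(n * model_max_length))  # stand-in for the reshaped token axis

states = [positions[0:1]]  # <BOS>
for i in range(1, max_token_length, model_max_length):
    states.append(positions[i : i + model_max_length - 2])  # after <BOS> up to before the last token
states.append(positions[-1:])  # either <EOS> or <PAD>

flat = sum(states, [])
print(len(flat))  # 227
```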
def default_if_none(value, default):
return default if value is None else value


@@ -17,19 +17,6 @@ from library import train_util
from library.sdxl_model_util import _load_state_dict_on_device
def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
resolution_multiple = 42.67
latent_height = math.ceil(height / compression_factor_b)
latent_width = math.ceil(width / compression_factor_b)
stage_c_latent_shape = (batch_size, 16, latent_height, latent_width)
latent_height = math.ceil(height / compression_factor_a)
latent_width = math.ceil(width / compression_factor_a)
stage_b_latent_shape = (batch_size, 4, latent_height, latent_width)
return stage_c_latent_shape, stage_b_latent_shape
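A quick check of the latent-shape arithmetic: the sketch below duplicates the function above for a 1024x1024, batch-1 case with the default compression factors:

```python
import math

# Stage C latents are compressed ~42.67x spatially, Stage B latents 4x.
def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
    latent_height = math.ceil(height / compression_factor_b)
    latent_width = math.ceil(width / compression_factor_b)
    stage_c_latent_shape = (batch_size, 16, latent_height, latent_width)

    latent_height = math.ceil(height / compression_factor_a)
    latent_width = math.ceil(width / compression_factor_a)
    stage_b_latent_shape = (batch_size, 4, latent_height, latent_width)
    return stage_c_latent_shape, stage_b_latent_shape

c_shape, b_shape = calculate_latent_sizes(1024, 1024, batch_size=1)
print(c_shape)  # (1, 16, 24, 24)
print(b_shape)  # (1, 4, 256, 256)
```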
def main(args):
device = device_utils.get_preferred_device()
@@ -55,8 +42,6 @@ def main(args):
generator_b.eval().requires_grad_(False).to(loading_device)
# CLIP encoders
print(f"Loading CLIP text model")
tokenizer = sc_utils.load_tokenizer(args)
text_model = sc_utils.load_clip_text_model(
@@ -74,7 +59,7 @@ def main(args):
caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee"
height, width = 1024, 1024
stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=1)
stage_c_latent_shape, stage_b_latent_shape = sc_utils.calculate_latent_sizes(height, width, batch_size=1)
# gdf: a mysterious class
gdf_c = sc.GDF(
@@ -106,13 +91,25 @@ def main(args):
# extras_b.sampling_configs["t_start"] = 1.0
# PREPARE CONDITIONS
cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model)
# cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model)
input_ids = tokenizer(
[caption], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
)["input_ids"].to(text_model.device)
cond_text, cond_pooled = train_util.get_hidden_states_stable_cascade(
tokenizer.model_max_length, input_ids, tokenizer, text_model
)
cond_text = cond_text.to(device, dtype=dtype)
cond_pooled = cond_pooled.to(device, dtype=dtype)
cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)
uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model)
# uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model)
input_ids = tokenizer([""], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt")[
"input_ids"
].to(text_model.device)
uncond_text, uncond_pooled = train_util.get_hidden_states_stable_cascade(
tokenizer.model_max_length, input_ids, tokenizer, text_model
)
uncond_text = uncond_text.to(device, dtype=dtype)
uncond_pooled = uncond_pooled.to(device, dtype=dtype)
uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype)
zero_img_emb = torch.zeros(1, 768, device=device)


@@ -140,13 +140,16 @@ def train(args):
stage_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=weight_dtype, device=loading_device)
text_encoder1 = sc_utils.load_clip_text_model(args.text_model_checkpoint_path, dtype=weight_dtype, device=loading_device)
if args.sample_at_first or args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
# Previewer is small enough to be loaded on CPU
previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=torch.float32, device="cpu")
previewer.eval()
else:
previewer = None
# 学習を準備する
if cache_latents:
logger.info(
"Please make sure that the latents are cached before training with `stable_cascade_cache_latents.py`."
+ " / 学習前に`stable_cascade_cache_latents.py`でlatentをキャッシュしてください。"
)
# effnet.to(accelerator.device, dtype=effnet_dtype)
effnet.to(accelerator.device, dtype=effnet_dtype)
effnet.requires_grad_(False)
effnet.eval()
with torch.no_grad():
@@ -155,7 +158,8 @@ def train(args):
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
cache_func=sc_utils.cache_batch_latents,
train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX,
32,
)
effnet.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -164,8 +168,8 @@ def train(args):
# 学習を準備する:モデルを適切な状態にする
if args.gradient_checkpointing:
logger.warn("Gradient checkpointing is not supported for stage_c. Ignoring the option.")
# stage_c.enable_gradient_checkpointing()
accelerator.print("enable gradient checkpointing")
stage_c.set_gradient_checkpointing(True)
train_stage_c = args.learning_rate > 0
train_text_encoder1 = False
@@ -176,9 +180,10 @@ def train(args):
text_encoder1.gradient_checkpointing_enable()
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
train_text_encoder1 = lr_te1 > 0
assert train_text_encoder1, "text_encoder1 learning rate is 0. Please set a positive value / text_encoder1の学習率が0です。正の値を設定してください。"
assert (
train_text_encoder1
), "text_encoder1 learning rate is 0. Please set a positive value / text_encoder1の学習率が0です。正の値を設定してください。"
# caching one text encoder output is not supported
if not train_text_encoder1:
text_encoder1.to(weight_dtype)
text_encoder1.requires_grad_(train_text_encoder1)
@@ -190,22 +195,16 @@ def train(args):
# TextEncoderの出力をキャッシュする
if args.cache_text_encoder_outputs:
raise NotImplementedError(
"Caching text encoder outputs is not supported in this version / text encoderの出力のキャッシュはサポートされていません"
)
print(
f"Please make sure that the text encoder outputs are cached before training with `stable_cascade_cache_text_encoder_outputs.py`."
+ " / 学習前に`stable_cascade_cache_text_encoder_outputs.py`でtext encoderの出力をキャッシュしてください。"
)
# Text Encodes are eval and no grad
with torch.no_grad(), accelerator.autocast():
train_dataset_group.cache_text_encoder_outputs(
(tokenizer),
(text_encoder1),
(tokenizer,),
(text_encoder1,),
accelerator.device,
None,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
)
accelerator.wait_for_everyone()
@@ -339,20 +338,13 @@ def train(args):
input_scaler=sc.VPScaler(),
target=sc.EpsilonTarget(),
noise_cond=sc.CosineTNoiseCond(),
loss_weight=sc.AdaptiveLossWeight(),
loss_weight=sc.AdaptiveLossWeight() if args.adaptive_loss_weight else sc.P2LossWeight(),
)
# the following two variables seem to just stay at their defaults
# gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
# gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
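For context, a rough pure-Python sketch of the adaptive-loss-weight idea (bucketed running losses over logSNR, inverse-loss weighting). The bucket count, range, and EMA momentum here are illustrative assumptions, not StableCascade's actual values:

```python
import bisect

# Sketch: keep a running loss estimate per logSNR bucket and weight new
# samples by the inverse of their bucket's typical loss, so regions of the
# noise schedule that are already easy get down-weighted.
class AdaptiveLossWeightSketch:
    def __init__(self, num_buckets=10, lo=-10.0, hi=10.0, momentum=0.99):
        # interior edges partition [lo, hi] into num_buckets bins
        self.edges = [lo + (hi - lo) * i / num_buckets for i in range(1, num_buckets)]
        self.losses = [1.0] * num_buckets
        self.momentum = momentum

    def _bucket(self, log_snr):
        return bisect.bisect_right(self.edges, log_snr)

    def update_buckets(self, log_snrs, losses):
        # EMA update of the per-bucket running loss
        for s, l in zip(log_snrs, losses):
            b = self._bucket(s)
            self.losses[b] = self.momentum * self.losses[b] + (1 - self.momentum) * l

    def weight(self, log_snr):
        return 1.0 / self.losses[self._bucket(log_snr)]

alw = AdaptiveLossWeightSketch()
alw.update_buckets([0.0], [4.0])  # one observed loss of 4.0 near logSNR 0
print(alw.weight(0.0))  # slightly below 1.0 after the update
```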
# noise_scheduler = DDPMScheduler(
# beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
# )
# prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
# if args.zero_terminal_snr:
# custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
@@ -361,18 +353,8 @@ def train(args):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
# # For --sample_at_first
# sdxl_train_util.sample_images(
# accelerator,
# args,
# 0,
# global_step,
# accelerator.device,
# effnet,
# [tokenizer1, tokenizer2],
# [text_encoder1, text_encoder2],
# stage_c,
# )
# For --sample_at_first
sc_utils.sample_images(accelerator, args, 0, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -404,11 +386,20 @@ def train(args):
# TODO support weighted captions
input_ids1 = input_ids1.to(accelerator.device)
# unwrap_model is fine for models not wrapped by accelerator
encoder_hidden_states, pool = sc.get_clip_conditions(None, input_ids1, tokenizer, text_encoder1)
encoder_hidden_states, pool = train_util.get_hidden_states_stable_cascade(
args.max_token_length,
input_ids1,
tokenizer,
text_encoder1,
None if not args.full_fp16 else weight_dtype,
accelerator,
)
else:
encoder_hidden_states = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
pool = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
pool = pool.unsqueeze(1) # add extra dimension b,1280 -> b,1,1280
# FORWARD PASS
with torch.no_grad():
noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(latents, shift=1, loss_shift=1)
@@ -421,7 +412,8 @@ def train(args):
loss = torch.nn.functional.mse_loss(pred, target, reduction="none").mean(dim=[1, 2, 3])
loss_adjusted = (loss * loss_weight).mean()
gdf.loss_weight.update_buckets(logSNR, loss)
if args.adaptive_loss_weight:
gdf.loss_weight.update_buckets(logSNR, loss) # use loss instead of loss_adjusted
accelerator.backward(loss_adjusted)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -439,17 +431,7 @@ def train(args):
progress_bar.update(1)
global_step += 1
# sdxl_train_util.sample_images(
# accelerator,
# args,
# None,
# global_step,
# accelerator.device,
# effnet,
# [tokenizer1, tokenizer2],
# [text_encoder1, text_encoder2],
# stage_c,
# )
sc_utils.sample_images(accelerator, args, None, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
# save the model every specified number of steps
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -467,7 +449,7 @@ def train(args):
accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
)
current_loss = loss.detach().item() # this is the mean, so batch size should not matter
current_loss = loss_adjusted.detach().item() # this is the mean, so batch size should not matter
if args.logging_dir is not None:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
@@ -502,17 +484,7 @@ def train(args):
accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
)
# sdxl_train_util.sample_images(
# accelerator,
# args,
# epoch + 1,
# global_step,
# accelerator.device,
# effnet,
# [tokenizer1, tokenizer2],
# [text_encoder1, text_encoder2],
# stage_c,
# )
sc_utils.sample_images(accelerator, args, epoch + 1, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
is_main_process = accelerator.is_main_process
# if is_main_process:
@@ -540,6 +512,8 @@ def setup_parser() -> argparse.ArgumentParser:
sc_utils.add_effnet_arguments(parser)
sc_utils.add_stage_c_arguments(parser)
sc_utils.add_text_model_arguments(parser)
sc_utils.add_previewer_arguments(parser)
sc_utils.add_training_arguments(parser)
train_util.add_tokenizer_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)


@@ -145,17 +145,17 @@ def cache_to_disk(args: argparse.Namespace) -> None:
image_info.image = image
image_info.bucket_reso = bucket_reso
image_info.resized_size = resized_size
image_info.latents_npz = os.path.splitext(absolute_path)[0] + sc_utils.LATENTS_CACHE_SUFFIX
image_info.latents_npz = os.path.splitext(absolute_path)[0] + train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX
if args.skip_existing:
if sc_utils.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug, 32):
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
continue
image_infos.append(image_info)
if len(image_infos) > 0:
sc_utils.cache_batch_latents(effnet, True, image_infos, flip_aug, random_crop, accelerator.device, effnet_dtype)
train_util.cache_batch_latents(effnet, True, image_infos, flip_aug, random_crop)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
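The hunk above switches the latents cache path to `train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX`. A minimal sketch of how such a cache path is derived (the suffix value below is a hypothetical stand-in, not the constant's actual value):

```python
import os

# hypothetical stand-in for train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX
STABLE_CASCADE_LATENTS_CACHE_SUFFIX = "_sc_latents.npz"

def latents_cache_path(image_path: str) -> str:
    # mirrors the diff: drop the image extension, append the cache suffix
    return os.path.splitext(image_path)[0] + STABLE_CASCADE_LATENTS_CACHE_SUFFIX

print(latents_cache_path("/data/imgs/cat.png"))
```

With `--skip_existing`, the script checks whether a file at this path already holds latents of the expected resolution before re-encoding the image.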


@@ -149,8 +149,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
if len(image_infos) > 0:
b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
sc_utils.cache_batch_text_encoder_outputs(
image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, weight_dtype
train_util.cache_batch_text_encoder_outputs(
image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, None, weight_dtype
)
accelerator.wait_for_everyone()
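For context, the `--adaptive_loss_weight` toggle introduced in the first file can be sketched in isolation. The classes below are minimal hypothetical stand-ins for `sc.AdaptiveLossWeight` and `sc.P2LossWeight` (the real implementations live in the StableCascade `gdf` module); only the weight selection and the `update_buckets` gating mirror the diff, and the bucket math is a simplification:

```python
import bisect
import math

class P2LossWeight:
    """Hypothetical stand-in: down-weight high-SNR samples."""
    def __call__(self, logSNR):
        return [1.0 / (1.0 + math.exp(s)) for s in logSNR]

class AdaptiveLossWeight:
    """Hypothetical stand-in: keep an EMA of the loss per logSNR bucket
    and weight each sample by the inverse of its bucket's running loss."""
    def __init__(self, n_buckets=10, lo=-10.0, hi=10.0):
        step = (hi - lo) / (n_buckets - 1)
        self.bucket_ranges = [lo + i * step for i in range(n_buckets - 1)]
        self.bucket_losses = [1.0] * n_buckets

    def _bucket(self, s):
        return bisect.bisect_right(self.bucket_ranges, s)

    def update_buckets(self, logSNR, loss, beta=0.99):
        for s, l in zip(logSNR, loss):
            i = self._bucket(s)
            self.bucket_losses[i] = beta * self.bucket_losses[i] + (1 - beta) * l

    def __call__(self, logSNR):
        return [1.0 / max(self.bucket_losses[self._bucket(s)], 1e-8) for s in logSNR]

adaptive = True  # stand-in for args.adaptive_loss_weight
loss_weight_fn = AdaptiveLossWeight() if adaptive else P2LossWeight()

logSNR = [0.5, -2.0]
loss = [0.8, 1.2]  # per-sample MSE, as in the diff
weights = loss_weight_fn(logSNR)
loss_adjusted = sum(l * w for l, w in zip(loss, weights)) / len(loss)
if adaptive:
    # the diff feeds the raw per-sample loss, not loss_adjusted
    loss_weight_fn.update_buckets(logSNR, loss)
print(round(loss_adjusted, 4))
```

The gating shown in the hunk matters because `update_buckets` only exists on the adaptive variant; calling it unconditionally would fail when `P2LossWeight` is selected.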