mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
update stable cascade stage C training #1119
This commit is contained in:

85	README.md
@@ -1,3 +1,88 @@
# Training Stable Cascade Stage C

This is an experimental feature. There may be bugs.

## Usage

Training is run with `stable_cascade_train_stage_c.py`.

The main options are the same as `sdxl_train.py`. The following options have been added.

- `--effnet_checkpoint_path`: Specifies the path to the EfficientNetEncoder weights.
- `--stage_c_checkpoint_path`: Specifies the path to the Stage C weights.
- `--text_model_checkpoint_path`: Specifies the path to the Text Encoder weights. If omitted, the model from Hugging Face will be used.
- `--save_text_model`: Saves the model downloaded from Hugging Face to `--text_model_checkpoint_path`.
- `--previewer_checkpoint_path`: Specifies the path to the Previewer weights. Used to generate sample images during training.
- `--adaptive_loss_weight`: Uses [Adaptive Loss Weight](https://github.com/Stability-AI/StableCascade/blob/master/gdf/loss_weights.py). If omitted, P2LossWeight is used. The official settings use Adaptive Loss Weight.

The learning rate is set to 1e-4 in the official settings.

The first time, specify `--text_model_checkpoint_path` and `--save_text_model` to save the Text Encoder weights. From then on, specify `--text_model_checkpoint_path` to load the saved weights.

Sample image generation during training is done with the Previewer. The Previewer is a simple decoder that converts EfficientNetEncoder latents to images.

Some of the options for SDXL are simply ignored or cause an error (especially noise-related options such as `--noise_offset`). `--vae_batch_size` and `--no_half_vae` are applied directly to the EfficientNetEncoder (when `bf16` is specified for mixed precision, `--no_half_vae` is not necessary).

Options for latents and Text Encoder output caches can be used as is, but since the EfficientNetEncoder is much lighter than the VAE, you may not need the cache unless memory is particularly tight.

`--gradient_checkpointing`, `--full_bf16`, and `--full_fp16` (untested) can be used as is to reduce memory consumption.

A scale of about 4 is suitable for sample image generation.

Since the official settings use `bf16` for training, training with `fp16` may be unstable.

The code for training the Text Encoder is also written, but it is untested.

### About the dataset for fine tuning

If latents cache files for SD/SDXL exist (extension `*.npz`), they will be read and an error will occur during training. Please move them to another location in advance.

After that, run `finetune/prepare_buckets_latents.py` with the `--stable_cascade` option to create latents cache files for Stable Cascade (the suffix `_sc_latents.npz` is added).
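The cache-file naming rule described above can be sketched in a few lines. This is a hedged, stdlib-only illustration, not the script's actual implementation: the helper name `latents_cache_filename` is hypothetical, while the `.npz` / `_sc_latents.npz` suffixes come from this document.

```python
import os

STABLE_CASCADE_LATENTS_CACHE_SUFFIX = "_sc_latents.npz"  # suffix described above


def latents_cache_filename(image_path: str, stable_cascade: bool) -> str:
    # SD/SDXL caches end in ".npz"; Stable Cascade caches end in "_sc_latents.npz"
    base = os.path.splitext(image_path)[0]
    ext = STABLE_CASCADE_LATENTS_CACHE_SUFFIX if stable_cascade else ".npz"
    return base + ext


print(latents_cache_filename("train/img001.png", True))   # train/img001_sc_latents.npz
print(latents_cache_filename("train/img001.png", False))  # train/img001.npz
```

Because the suffixes differ, the two kinds of cache files can coexist only by mistake; this is why stale SD/SDXL `*.npz` files must be moved away before training.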
---

__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).

This repository contains training, generation and utility scripts for Stable Diffusion.
@@ -11,15 +11,19 @@ import cv2

 import torch
 from library.device_utils import init_ipex, get_preferred_device

 init_ipex()

 from torchvision import transforms

 import library.model_util as model_util
+import library.stable_cascade_utils as sc_utils
 import library.train_util as train_util
 from library.utils import setup_logging

 setup_logging()
 import logging

 logger = logging.getLogger(__name__)

 DEVICE = get_preferred_device()
@@ -42,7 +46,7 @@ def collate_fn_remove_corrupted(batch):
     return batch


-def get_npz_filename(data_dir, image_key, is_full_path, recursive):
+def get_npz_filename(data_dir, image_key, is_full_path, recursive, stable_cascade):
     if is_full_path:
         base_name = os.path.splitext(os.path.basename(image_key))[0]
         relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
@@ -50,10 +54,11 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive):
         base_name = image_key
         relative_path = ""

+    ext = ".npz" if not stable_cascade else train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX
     if recursive and relative_path:
-        return os.path.join(data_dir, relative_path, base_name) + ".npz"
+        return os.path.join(data_dir, relative_path, base_name) + ext
     else:
-        return os.path.join(data_dir, base_name) + ".npz"
+        return os.path.join(data_dir, base_name) + ext
def main(args):
@@ -83,13 +88,20 @@ def main(args):
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

-    vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
+    if not args.stable_cascade:
+        vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
+        divisor = 8
+    else:
+        vae = sc_utils.load_effnet(args.model_name_or_path, DEVICE)
+        divisor = 32
     vae.eval()
     vae.to(DEVICE, dtype=weight_dtype)

     # calculate the bucket sizes
     max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
-    assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
+    assert (
+        len(max_reso) == 2
+    ), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"

     bucket_manager = train_util.BucketManager(
         args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
@@ -154,6 +166,10 @@ def main(args):
         # the resolution recorded in the metadata is in latent units, so round down to a multiple of 8
         metadata[image_key]["train_resolution"] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)

+        # record additional information
+        metadata[image_key]["original_size"] = (image.width, image.height)
+        metadata[image_key]["train_resized_size"] = resized_size
+
         if not args.bucket_no_upscale:
             # when not upscaling, confirm that either the width or the height of the resized image matches the bucket size
             assert (
@@ -168,9 +184,9 @@ def main(args):
             ), f"internal error resized size is small: {resized_size}, {reso}"

         # if a cache file already exists, check its shape etc. and skip if identical
-        npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive)
+        npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive, args.stable_cascade)
         if args.skip_existing:
-            if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug):
+            if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug, divisor):
                 continue

         # add to the batch
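The `divisor` passed to the cache check reflects the two downscaling factors involved: the SD/SDXL VAE produces latents at 1/8 of the image resolution, while the EfficientNetEncoder produces them at 1/32. A small stdlib-only sketch of the expected latent shape (the helper name `expected_latents_hw` is hypothetical; the hunk above only shows that a divisor is threaded through):

```python
def expected_latents_hw(bucket_reso, divisor):
    # bucket_reso is (width, height); cached latents are stored as (H/divisor, W/divisor)
    width, height = bucket_reso
    return (height // divisor, width // divisor)


print(expected_latents_hw((1024, 768), 8))   # SD/SDXL VAE: (96, 128)
print(expected_latents_hw((1024, 768), 32))  # EfficientNetEncoder: (24, 32)
```

A cache file whose stored latents do not match this shape for the bucket is treated as stale and regenerated rather than reused.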
@@ -208,7 +224,14 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
     parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
     parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
-    parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
+    parser.add_argument(
+        "--stable_cascade",
+        action="store_true",
+        help="prepare EffNet latents for stable cascade / stable cascade用のEffNetのlatentsを準備する",
+    )
+    parser.add_argument(
+        "--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)"
+    )
     parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
     parser.add_argument(
         "--max_data_loader_n_workers",
@@ -231,10 +254,16 @@ def setup_parser() -> argparse.ArgumentParser:
         help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
     )
     parser.add_argument(
-        "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
+        "--bucket_no_upscale",
+        action="store_true",
+        help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
     )
     parser.add_argument(
-        "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help="use mixed precision / 混合精度を使う場合、その精度",
     )
     parser.add_argument(
         "--full_path",
@@ -242,7 +271,9 @@ def setup_parser() -> argparse.ArgumentParser:
         help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
     )
     parser.add_argument(
-        "--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
+        "--flip_aug",
+        action="store_true",
+        help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する",
    )
     parser.add_argument(
         "--skip_existing",
@@ -3,10 +3,12 @@
 # https://github.com/Stability-AI/StableCascade

 import math
 from types import SimpleNamespace
 from typing import List, Optional
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint
 import torchvision


@@ -125,6 +127,23 @@ class EfficientNetEncoder(nn.Module):
     def forward(self, x):
         return self.mapper(self.backbone(x))

+    @property
+    def dtype(self) -> torch.dtype:
+        return next(self.parameters()).dtype
+
+    @property
+    def device(self) -> torch.device:
+        return next(self.parameters()).device
+
+    def encode(self, x):
+        """
+        A method so this encoder can be used like a VAE. Properly the caller should handle the two cases separately; this is a provisional workaround.
+        """
+        # latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
+        x = self(x)
+        return SimpleNamespace(latent_dist=SimpleNamespace(sample=lambda: x))
# A fairly rough implementation (;'∀')
# Training from scratch is unlikely, so this is disabled
@@ -136,6 +155,7 @@ class EfficientNetEncoder(nn.Module):
 # class Conv2d(torch.nn.Conv2d):
 #     def reset_parameters(self):
 #         return None

 from torch.nn import Conv2d
+from torch.nn import Linear
@@ -187,7 +207,12 @@ class ResBlock(nn.Module):
             Linear(c + c_skip, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), Linear(c * 4, c)
         )

-    def forward(self, x, x_skip=None):
+        self.gradient_checkpointing = False
+
+    def set_gradient_checkpointing(self, value):
+        self.gradient_checkpointing = value
+
+    def forward_body(self, x, x_skip=None):
         x_res = x
         x = self.norm(self.depthwise(x))
         if x_skip is not None:
@@ -195,6 +220,22 @@ class ResBlock(nn.Module):
         x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         return x + x_res

+    def forward(self, x, x_skip=None):
+        if self.training and self.gradient_checkpointing:
+            # logger.info("ResnetBlock2D: gradient_checkpointing")
+
+            def create_custom_forward(func):
+                def custom_forward(*inputs):
+                    return func(*inputs)
+
+                return custom_forward
+
+            x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, x_skip)
+        else:
+            x = self.forward_body(x, x_skip)
+
+        return x
class AttnBlock(nn.Module):
    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
@@ -204,11 +245,32 @@ class AttnBlock(nn.Module):
         self.attention = Attention2D(c, nhead, dropout)
         self.kv_mapper = nn.Sequential(nn.SiLU(), Linear(c_cond, c))

-    def forward(self, x, kv):
+        self.gradient_checkpointing = False
+
+    def set_gradient_checkpointing(self, value):
+        self.gradient_checkpointing = value
+
+    def forward_body(self, x, kv):
         kv = self.kv_mapper(kv)
         x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
         return x

+    def forward(self, x, kv):
+        if self.training and self.gradient_checkpointing:
+            # logger.info("AttnBlock: gradient_checkpointing")
+
+            def create_custom_forward(func):
+                def custom_forward(*inputs):
+                    return func(*inputs)
+
+                return custom_forward
+
+            x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, kv)
+        else:
+            x = self.forward_body(x, kv)
+
+        return x
class FeedForwardBlock(nn.Module):
    def __init__(self, c, dropout=0.0):
@@ -218,10 +280,31 @@ class FeedForwardBlock(nn.Module):
             Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), Linear(c * 4, c)
         )

-    def forward(self, x):
+        self.gradient_checkpointing = False
+
+    def set_gradient_checkpointing(self, value):
+        self.gradient_checkpointing = value
+
+    def forward_body(self, x):
         x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         return x

+    def forward(self, x):
+        if self.training and self.gradient_checkpointing:
+            # logger.info("FeedForwardBlock: gradient_checkpointing")
+
+            def create_custom_forward(func):
+                def custom_forward(*inputs):
+                    return func(*inputs)
+
+                return custom_forward
+
+            x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x)
+        else:
+            x = self.forward_body(x)
+
+        return x
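The `create_custom_forward` helper that recurs in these `forward` methods exists because `torch.utils.checkpoint.checkpoint` expects a plain callable over positional inputs; the wrapper just forwards its arguments to the bound `forward_body` unchanged. A torch-free sketch of the pattern (the stand-in `forward_body` here is hypothetical, only illustrating the call shape):

```python
def create_custom_forward(func):
    # wraps a bound method so it can be handed to a checkpointing API
    # that re-invokes it with positional inputs during the backward pass
    def custom_forward(*inputs):
        return func(*inputs)

    return custom_forward


def forward_body(x, x_skip=None):
    # stand-in for a block's forward computation
    return x + (x_skip if x_skip is not None else 0)


wrapped = create_custom_forward(forward_body)
print(wrapped(3, 4))  # 7
print(wrapped(3))     # 3
```

Splitting the computation into `forward_body` plus a thin `forward` lets each block toggle checkpointing at runtime without duplicating the actual math.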
class TimestepBlock(nn.Module):
    def __init__(self, c, c_timestep, conds=["sca"]):
@@ -250,9 +333,38 @@ class UpDownBlock2d(nn.Module):
         mapping = nn.Conv2d(c_in, c_out, kernel_size=1)
         self.blocks = nn.ModuleList([interpolation, mapping] if mode == "up" else [mapping, interpolation])

+        self.mode = mode
+
+        self.gradient_checkpointing = False
+
+    def set_gradient_checkpointing(self, value):
+        self.gradient_checkpointing = value
+
+    def forward_body(self, x):
+        org_dtype = x.dtype
+        for i, block in enumerate(self.blocks):
+            # The official implementation always computes in float; to save a little memory, convert to float only for bfloat16 + Upsample
+            if x.dtype == torch.bfloat16 and (self.mode == "up" and i == 0 or self.mode != "up" and i == 1):
+                x = x.float()
+            x = block(x)
+            x = x.to(org_dtype)
+        return x

     def forward(self, x):
-        for block in self.blocks:
-            x = block(x.float())
+        if self.training and self.gradient_checkpointing:
+            # logger.info("UpDownBlock2d: gradient_checkpointing")
+
+            def create_custom_forward(func):
+                def custom_forward(*inputs):
+                    return func(*inputs)
+
+                return custom_forward
+
+            x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x)
+        else:
+            x = self.forward_body(x)
+
+        return x
@@ -790,6 +902,12 @@ class StageC(nn.Module):
                 if m.bias is not None:
                     nn.init.constant_(m.bias, 0)

+    def set_gradient_checkpointing(self, value):
+        for block in self.down_blocks + self.up_blocks:
+            for layer in block:
+                if hasattr(layer, "set_gradient_checkpointing"):
+                    layer.set_gradient_checkpointing(value)

     def gen_r_embedding(self, r, max_positions=10000):
         r = r * max_positions
         half_dim = self.c_r // 2
@@ -900,8 +1018,62 @@ class StageC(nn.Module):
         for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
             self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)

+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
class Previewer(nn.Module):
    def __init__(self, c_in=16, c_hidden=512, c_out=3):
        super().__init__()
        self.blocks = nn.Sequential(
            nn.Conv2d(c_in, c_hidden, kernel_size=1),  # 16 channels to 512 channels
            nn.GELU(),
            nn.BatchNorm2d(c_hidden),
            nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden),
            nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2),  # 16 -> 32
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 2),
            nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 2),
            nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2),  # 32 -> 64
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),
            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),
            nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2),  # 64 -> 128
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),
            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),
            nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
        )

    def forward(self, x):
        return self.blocks(x)

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype
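The shape comment on `Previewer` ("16 x 24 x 24 -> 3 x 192 x 192") follows directly from the architecture: the three `ConvTranspose2d(kernel_size=2, stride=2)` layers each double the spatial size, while the 1x1 convolutions and the padded 3x3 convolutions preserve it. A quick check of the arithmetic (the helper name is hypothetical):

```python
def previewer_output_hw(latent_h, latent_w, num_upsamples=3):
    # each ConvTranspose2d(kernel_size=2, stride=2) doubles H and W;
    # the 1x1 convs and padding=1 3x3 convs keep the spatial size unchanged
    scale = 2 ** num_upsamples
    return latent_h * scale, latent_w * scale


print(previewer_output_hw(24, 24))  # (192, 192)
```

So a preview image is 8x the latent resolution, which is why a 1024px training image (24x24 Stage C latents) previews at 192x192.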
def get_clip_conditions(captions: Optional[List[str]], input_ids, tokenizer, text_model):
    # deprecated

    # self, batch: dict, tokenizer, text_model, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None
    # handling is_eval here would be awkward, so it is done elsewhere
    # handling is_unconditional here would also be awkward, so it is done elsewhere
@@ -921,39 +1093,6 @@ def get_clip_conditions(captions: Optional[List[str]], input_ids, tokenizer, tex
     # return {"clip_text": text_embeddings, "clip_text_pooled": text_pooled_embeddings}  # , "clip_img": image_embeddings}


-def get_stage_c_conditions(
-    batch: dict, effnet, effnet_preprocess, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None
-):
-    images = batch.get("images", None)
-
-    if images is not None:
-        images = images.to(self.device)
-        if is_eval and not is_unconditional:
-            effnet_embeddings = effnet(effnet_preprocess(images))
-        else:
-            if is_eval:
-                effnet_factor = 1
-            else:
-                effnet_factor = np.random.uniform(0.5, 1)  # f64 to f32
-            effnet_height, effnet_width = int(((images.size(-2) * effnet_factor) // 32) * 32), int(
-                ((images.size(-1) * effnet_factor) // 32) * 32
-            )
-
-            effnet_embeddings = torch.zeros(images.size(0), 16, effnet_height // 32, effnet_width // 32, device=self.device)
-            if not is_eval:
-                effnet_images = torchvision.transforms.functional.resize(
-                    images, (effnet_height, effnet_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
-                )
-                rand_idx = np.random.rand(len(images)) <= 0.9
-                if any(rand_idx):
-                    effnet_embeddings[rand_idx] = effnet(effnet_preprocess(effnet_images[rand_idx]))
-    else:
-        effnet_embeddings = None
-
-    return effnet_embeddings
-    # {"effnet": effnet_embeddings, "clip": conditions["clip_text_pooled"]}


# region gdf
@@ -1,28 +1,30 @@
 import argparse
 import json
 import math
 import os
 import time
 from typing import List
 import numpy as np
 import toml

 import torch
 import torchvision
 from safetensors.torch import load_file, save_file
 from tqdm import tqdm
 from transformers import CLIPTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
-from accelerate import init_empty_weights
+from accelerate import init_empty_weights, Accelerator, PartialState
 from PIL import Image

 from library import stable_cascade as sc
 from library.train_util import (
     ImageInfo,
     load_image,
     trim_and_resize_if_required,
     save_latents_to_disk,
     HIGH_VRAM,
     save_text_encoder_outputs_to_disk,
 )

 from library.sdxl_model_util import _load_state_dict_on_device
 from library.device_utils import clean_memory_on_device
-from library.train_util import save_sd_model_on_epoch_end_or_stepwise_common, save_sd_model_on_train_end_common
+from library.train_util import (
+    save_sd_model_on_epoch_end_or_stepwise_common,
+    save_sd_model_on_train_end_common,
+    line_to_prompt_dict,
+    get_hidden_states_stable_cascade,
+)
 from library import sai_model_spec
@@ -41,7 +43,22 @@ EFFNET_PREPROCESS = torchvision.transforms.Compose(
 )

 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_sc_te_outputs.npz"
+LATENTS_CACHE_SUFFIX = "_sc_latents.npz"
+
+
+def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
+    resolution_multiple = 42.67
+    latent_height = math.ceil(height / compression_factor_b)
+    latent_width = math.ceil(width / compression_factor_b)
+    stage_c_latent_shape = (batch_size, 16, latent_height, latent_width)
+
+    latent_height = math.ceil(height / compression_factor_a)
+    latent_width = math.ceil(width / compression_factor_a)
+    stage_b_latent_shape = (batch_size, 4, latent_height, latent_width)
+
+    return stage_c_latent_shape, stage_b_latent_shape


 # region load and save


 def load_effnet(effnet_checkpoint_path, loading_device="cpu") -> sc.EfficientNetEncoder:
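A worked example of `calculate_latent_sizes` at its defaults, as a stdlib-only restatement of the function above: 1024 / 42.67 ≈ 23.998, which `math.ceil` rounds up to 24, giving the 16 x 24 x 24 Stage C latents mentioned in the `Previewer` comment; Stage B latents are at 1/4 resolution.

```python
import math


def calculate_latent_sizes(height=1024, width=1024, batch_size=4,
                           compression_factor_b=42.67, compression_factor_a=4.0):
    # Stage C latents: 16 channels at roughly 1/42.67 of the image resolution
    c_h = math.ceil(height / compression_factor_b)
    c_w = math.ceil(width / compression_factor_b)
    stage_c_latent_shape = (batch_size, 16, c_h, c_w)

    # Stage B latents: 4 channels at 1/4 of the image resolution
    b_h = math.ceil(height / compression_factor_a)
    b_w = math.ceil(width / compression_factor_a)
    stage_b_latent_shape = (batch_size, 4, b_h, b_w)

    return stage_c_latent_shape, stage_b_latent_shape


print(calculate_latent_sizes())  # ((4, 16, 24, 24), (4, 4, 256, 256))
```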
@@ -165,153 +182,15 @@ def load_stage_a_model(stage_a_checkpoint_path, dtype=None, device="cpu") -> sc.
     return stage_a


-def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
-    expected_latents_size = (reso[1] // 32, reso[0] // 32)  # note: bucket_reso is WxH
-
-    if not os.path.exists(npz_path):
-        return False
-
-    npz = np.load(npz_path)
-    if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz:  # old ver?
-        return False
-    if npz["latents"].shape[1:3] != expected_latents_size:
-        return False
-
-    if flip_aug:
-        if "latents_flipped" not in npz:
-            return False
-        if npz["latents_flipped"].shape[1:3] != expected_latents_size:
-            return False
-
-    return True
-
-
-def cache_batch_latents(
-    effnet: sc.EfficientNetEncoder,
-    cache_to_disk: bool,
-    image_infos: List[ImageInfo],
-    flip_aug: bool,
-    random_crop: bool,
-    device,
-    dtype,
-) -> None:
-    r"""
-    requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
-    optionally requires image_infos to have: image
-    if cache_to_disk is True, set info.latents_npz
-        flipped latents is also saved if flip_aug is True
-    if cache_to_disk is False, set info.latents
-        latents_flipped is also set if flip_aug is True
-    latents_original_size and latents_crop_ltrb are also set
-    """
-    images = []
-    for info in image_infos:
-        image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
-        # TODO the image metadata may be corrupted, so the bucket assigned from metadata may not match the actual image size; a check should be added
-        image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
-        image = EFFNET_PREPROCESS(image)
-        images.append(image)
-
-        info.latents_original_size = original_size
-        info.latents_crop_ltrb = crop_ltrb
-
-    img_tensors = torch.stack(images, dim=0)
-    img_tensors = img_tensors.to(device=device, dtype=dtype)
-
-    with torch.no_grad():
-        latents = effnet(img_tensors).to("cpu")
-        print(latents.shape)
-
-    if flip_aug:
-        img_tensors = torch.flip(img_tensors, dims=[3])
-        with torch.no_grad():
-            flipped_latents = effnet(img_tensors).to("cpu")
-    else:
-        flipped_latents = [None] * len(latents)
-
-    for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
-        # check NaN
-        if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
-            raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
-
-        if cache_to_disk:
-            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
-        else:
-            info.latents = latent
-            if flip_aug:
-                info.latents_flipped = flipped_latent
-
-    if not HIGH_VRAM:
-        clean_memory_on_device(device)
-
-
-def cache_batch_text_encoder_outputs(image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids, dtype):
-    # more than 75 tokens is not supported
-    input_ids = input_ids.to(text_encoders[0].device)
-
-    with torch.no_grad():
-        b_hidden_state, b_pool = sc.get_clip_conditions(None, input_ids, tokenizers[0], text_encoders[0])
-
-        b_hidden_state = b_hidden_state.detach().to("cpu")  # b,n*75+2,768
-        b_pool = b_pool.detach().to("cpu")  # b,1280
-
-    for info, hidden_state, pool in zip(image_infos, b_hidden_state, b_pool):
-        if cache_to_disk:
-            save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, None, hidden_state, pool)
-        else:
-            info.text_encoder_outputs1 = hidden_state
-            info.text_encoder_pool2 = pool
-
-
-def add_effnet_arguments(parser):
-    parser.add_argument(
-        "--effnet_checkpoint_path",
-        type=str,
-        required=True,
-        help="path to EfficientNet checkpoint / EfficientNetのチェックポイントのパス",
-    )
-    return parser
-
-
-def add_text_model_arguments(parser):
-    parser.add_argument(
-        "--text_model_checkpoint_path",
-        type=str,
-        required=True,
-        help="path to CLIP text model checkpoint / CLIPテキストモデルのチェックポイントのパス",
-    )
-    parser.add_argument("--save_text_model", action="store_true", help="if specified, save text model to corresponding path")
-    return parser
-
-
-def add_stage_a_arguments(parser):
-    parser.add_argument(
-        "--stage_a_checkpoint_path",
-        type=str,
-        required=True,
-        help="path to Stage A checkpoint / Stage Aのチェックポイントのパス",
-    )
-    return parser
-
-
-def add_stage_b_arguments(parser):
-    parser.add_argument(
-        "--stage_b_checkpoint_path",
-        type=str,
-        required=True,
-        help="path to Stage B checkpoint / Stage Bのチェックポイントのパス",
-    )
-    return parser
-
-
-def add_stage_c_arguments(parser):
-    parser.add_argument(
-        "--stage_c_checkpoint_path",
-        type=str,
-        required=True,
-        help="path to Stage C checkpoint / Stage Cのチェックポイントのパス",
-    )
-    return parser
+def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") -> sc.Previewer:
+    logger.info(f"Loading Previewer from {previewer_checkpoint_path}")
+    previewer = sc.Previewer().to(device)
+    previewer_checkpoint = load_file(previewer_checkpoint_path)
+    info = previewer.load_state_dict(
+        previewer_checkpoint if "state_dict" not in previewer_checkpoint else previewer_checkpoint["state_dict"]
+    )
+    logger.info(info)
+    return previewer
def get_sai_model_spec(args):
@@ -353,7 +232,7 @@ def get_sai_model_spec(args):
 def stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata):
     state_dict = stage_c.state_dict()
     if save_dtype is not None:
-        state_dict = {k: v.to(save_dtype) for k, v in state_dict.items}
+        state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}

     save_file(state_dict, ckpt_file, metadata=sai_metadata)
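The one-character fix in `stage_c_saver_common` above is worth spelling out: `state_dict.items` without parentheses is a bound method object, so iterating over it raises a `TypeError`, while `state_dict.items()` yields the key/value pairs. A minimal reproduction with a plain dict standing in for a model state dict:

```python
state_dict = {"weight": 1.0, "bias": 2.0}

# buggy form: .items (no call) is a method object, not an iterable of pairs
try:
    broken = {k: v for k, v in state_dict.items}
except TypeError as e:
    print("raises:", type(e).__name__)  # raises: TypeError

# fixed form: call .items() to get the key/value pairs
fixed = {k: v for k, v in state_dict.items()}
print(fixed)  # {'weight': 1.0, 'bias': 2.0}
```

Because the buggy branch only runs when `save_dtype` is set, the error would surface at checkpoint-save time rather than at import time, which is why such typos can survive until the first save.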
@@ -403,112 +282,334 @@ def save_stage_c_model_on_end(
|
||||
save_sd_model_on_train_end_common(args, True, True, epoch, global_step, stage_c_saver, None)
|
||||
|
||||
|
||||
-    def cache_latents(self, effnet, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
-        # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
-        logger.info("caching latents.")
-
-        image_infos = list(self.image_data.values())
-
-        # sort by resolution
-        image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
-
-        # split by resolution
-        batches = []
-        batch = []
-        logger.info("checking cache validity...")
-        for info in tqdm(image_infos):
-            subset = self.image_to_subset[info.image_key]
-
-            if info.latents_npz is not None:  # fine tuning dataset
-                continue
-
-            # check disk cache exists and size of latents
-            if cache_to_disk:
-                info.latents_npz = os.path.splitext(info.absolute_path)[0] + LATENTS_CACHE_SUFFIX
-                if not is_main_process:  # store to info only
-                    continue
-
-                cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug)
-
-                if cache_available:  # do not add to batch
-                    continue
-
-            # if last member of batch has different resolution, flush the batch
-            if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
-                batches.append(batch)
-                batch = []
-
-            batch.append(info)
-
-            # if number of data in batch is enough, flush the batch
-            if len(batch) >= vae_batch_size:
-                batches.append(batch)
-                batch = []
-
-        if len(batch) > 0:
-            batches.append(batch)
-
-        if cache_to_disk and not is_main_process:  # if cache to disk, don't cache latents in non-main process, set to info only
-            return
-
-        # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
-        logger.info("caching latents...")
-        for batch in tqdm(batches, smoothing=1, total=len(batches)):
-            cache_batch_latents(effnet, cache_to_disk, batch, subset.flip_aug, subset.random_crop)
-
-    # weight_dtypeを指定するとText Encoderそのもの、および出力がweight_dtypeになる
-    def cache_text_encoder_outputs(self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True):
-        # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する
-        # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
-        logger.info("caching text encoder outputs.")
-        image_infos = list(self.image_data.values())
-
-        logger.info("checking cache existence...")
-        image_infos_to_cache = []
-        for info in tqdm(image_infos):
-            # subset = self.image_to_subset[info.image_key]
-            if cache_to_disk:
-                te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
-                info.text_encoder_outputs_npz = te_out_npz
-
-                if not is_main_process:  # store to info only
-                    continue
-
-                if os.path.exists(te_out_npz):
-                    continue
-
-            image_infos_to_cache.append(info)
-
-        if cache_to_disk and not is_main_process:  # if cache to disk, don't cache latents in non-main process, set to info only
-            return
-
-        # prepare tokenizers and text encoders
-        for text_encoder in text_encoders:
-            text_encoder.to(device)
-            if weight_dtype is not None:
-                text_encoder.to(dtype=weight_dtype)
-
-        # create batch
-        batch = []
-        batches = []
-        for info in image_infos_to_cache:
-            input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
-            batch.append((info, input_ids1, None))
-
-            if len(batch) >= self.batch_size:
-                batches.append(batch)
-                batch = []
-
-        if len(batch) > 0:
-            batches.append(batch)
-
-        # iterate batches: call text encoder and cache outputs for memory or disk
-        logger.info("caching text encoder outputs...")
-        for batch in tqdm(batches):
-            infos, input_ids1, input_ids2 = zip(*batch)
-            input_ids1 = torch.stack(input_ids1, dim=0)
-            input_ids2 = torch.stack(input_ids2, dim=0) if input_ids2[0] is not None else None
-            cache_batch_text_encoder_outputs(
-                infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, weight_dtype
-            )
+# endregion
+
+
+# region sample generation
+
+
+def sample_images(
+    accelerator: Accelerator,
+    args: argparse.Namespace,
+    epoch,
+    steps,
+    previewer,
+    tokenizer,
+    text_encoder,
+    stage_c,
+    gdf,
+    prompt_replacement=None,
+):
+    if steps == 0:
+        if not args.sample_at_first:
+            return
+    else:
+        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
+            return
+        if args.sample_every_n_epochs is not None:
+            # sample_every_n_steps は無視する
+            if epoch is None or epoch % args.sample_every_n_epochs != 0:
+                return
+        else:
+            if steps % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
+                return
+
+    logger.info("")
+    logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
+    if not os.path.isfile(args.sample_prompts):
+        logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
+        return
+
+    distributed_state = PartialState()  # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
+
+    # unwrap unet and text_encoder(s)
+    stage_c = accelerator.unwrap_model(stage_c)
+    text_encoder = accelerator.unwrap_model(text_encoder)
+
+    # read prompts
+    if args.sample_prompts.endswith(".txt"):
+        with open(args.sample_prompts, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
+    elif args.sample_prompts.endswith(".toml"):
+        with open(args.sample_prompts, "r", encoding="utf-8") as f:
+            data = toml.load(f)
+        prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
+    elif args.sample_prompts.endswith(".json"):
+        with open(args.sample_prompts, "r", encoding="utf-8") as f:
+            prompts = json.load(f)
+
+    save_dir = args.output_dir + "/sample"
+    os.makedirs(save_dir, exist_ok=True)
+
+    # preprocess prompts
+    for i in range(len(prompts)):
+        prompt_dict = prompts[i]
+        if isinstance(prompt_dict, str):
+            prompt_dict = line_to_prompt_dict(prompt_dict)
+            prompts[i] = prompt_dict
+        assert isinstance(prompt_dict, dict)
+
+        # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
+        prompt_dict["enum"] = i
+        prompt_dict.pop("subset", None)
+
+    # save random state to restore later
+    rng_state = torch.get_rng_state()
+    cuda_rng_state = None
+    try:
+        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
+    except Exception:
+        pass
+
+    if distributed_state.num_processes <= 1:
+        # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
+        with torch.no_grad():
+            for prompt_dict in prompts:
+                sample_image_inference(
+                    accelerator,
+                    args,
+                    tokenizer,
+                    text_encoder,
+                    stage_c,
+                    previewer,
+                    gdf,
+                    save_dir,
+                    prompt_dict,
+                    epoch,
+                    steps,
+                    prompt_replacement,
+                )
+    else:
+        # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
+        # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
+        per_process_prompts = []  # list of lists
+        for i in range(distributed_state.num_processes):
+            per_process_prompts.append(prompts[i :: distributed_state.num_processes])
+
+        with torch.no_grad():
+            with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
+                for prompt_dict in prompt_dict_lists[0]:
+                    sample_image_inference(
+                        accelerator,
+                        args,
+                        tokenizer,
+                        text_encoder,
+                        stage_c,
+                        previewer,
+                        gdf,
+                        save_dir,
+                        prompt_dict,
+                        epoch,
+                        steps,
+                        prompt_replacement,
+                    )
+
+    # I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here.
+    # with torch.cuda.device(torch.cuda.current_device()):
+    #     torch.cuda.empty_cache()
+    clean_memory_on_device(accelerator.device)
+
+    torch.set_rng_state(rng_state)
+    if cuda_rng_state is not None:
+        torch.cuda.set_rng_state(cuda_rng_state)
+
+
+def sample_image_inference(
+    accelerator: Accelerator,
+    args: argparse.Namespace,
+    tokenizer,
+    text_model,
+    stage_c,
+    previewer,
+    gdf,
+    save_dir,
+    prompt_dict,
+    epoch,
+    steps,
+    prompt_replacement,
+):
+    assert isinstance(prompt_dict, dict)
+    negative_prompt = prompt_dict.get("negative_prompt")
+    sample_steps = prompt_dict.get("sample_steps", 20)
+    width = prompt_dict.get("width", 1024)
+    height = prompt_dict.get("height", 1024)
+    scale = prompt_dict.get("scale", 4)
+    seed = prompt_dict.get("seed")
+    # controlnet_image = prompt_dict.get("controlnet_image")
+    prompt: str = prompt_dict.get("prompt", "")
+    # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
+
+    if prompt_replacement is not None:
+        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
+        if negative_prompt is not None:
+            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
+
+    if seed is not None:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+    else:
+        # True random sample image generation
+        torch.seed()
+        torch.cuda.seed()
+
+    height = max(64, height - height % 8)  # round to divisible by 8
+    width = max(64, width - width % 8)  # round to divisible by 8
+    logger.info(f"prompt: {prompt}")
+    logger.info(f"negative_prompt: {negative_prompt}")
+    logger.info(f"height: {height}")
+    logger.info(f"width: {width}")
+    logger.info(f"sample_steps: {sample_steps}")
+    logger.info(f"scale: {scale}")
+    # logger.info(f"sample_sampler: {sampler_name}")
+    if seed is not None:
+        logger.info(f"seed: {seed}")
+
+    negative_prompt = "" if negative_prompt is None else negative_prompt
+    cfg = scale
+    timesteps = sample_steps
+    shift = 2
+    t_start = 1.0
+
+    stage_c_latent_shape, _ = calculate_latent_sizes(height, width, batch_size=1)
+
+    # PREPARE CONDITIONS
+    input_ids = tokenizer(
+        [prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
+    )["input_ids"].to(text_model.device)
+    cond_text, cond_pooled = get_hidden_states_stable_cascade(tokenizer.model_max_length, input_ids, tokenizer, text_model)
+
+    input_ids = tokenizer(
+        [negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
+    )["input_ids"].to(text_model.device)
+    uncond_text, uncond_pooled = get_hidden_states_stable_cascade(tokenizer.model_max_length, input_ids, tokenizer, text_model)
+
+    device = accelerator.device
+    dtype = stage_c.dtype
+    cond_text = cond_text.to(device, dtype=dtype)
+    cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)
+
+    uncond_text = uncond_text.to(device, dtype=dtype)
+    uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype)
+
+    zero_img_emb = torch.zeros(1, 768, device=device)
+
+    # 辞書にしたくないけど GDF から先の変更が面倒だからとりあえず辞書にしておく
+    conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
+    unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": zero_img_emb}
+
+    with torch.no_grad():  # , torch.cuda.amp.autocast(dtype=dtype):
+        sampling_c = gdf.sample(
+            stage_c,
+            conditions,
+            stage_c_latent_shape,
+            unconditions,
+            device=device,
+            cfg=cfg,
+            shift=shift,
+            timesteps=timesteps,
+            t_start=t_start,
+        )
+        for sampled_c, _, _ in tqdm(sampling_c, total=timesteps):
+            sampled_c = sampled_c
+
+        sampled_c = sampled_c.to(previewer.device, dtype=previewer.dtype)
+        image = previewer(sampled_c)[0]
+    image = torch.clamp(image, 0, 1)
+    image = image.cpu().numpy().transpose(1, 2, 0)
+    image = image * 255
+    image = image.astype(np.uint8)
+    image = Image.fromarray(image)
+
+    # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
+    # but adding 'enum' to the filename should be enough
+
+    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
+    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
+    seed_suffix = "" if seed is None else f"_{seed}"
+    i: int = prompt_dict["enum"]
+    img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
+    image.save(os.path.join(save_dir, img_filename))
+
+    # wandb有効時のみログを送信
+    try:
+        wandb_tracker = accelerator.get_tracker("wandb")
+        try:
+            import wandb
+        except ImportError:  # 事前に一度確認するのでここはエラー出ないはず
+            raise ImportError("No wandb / wandb がインストールされていないようです")
+
+        wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
+    except:  # wandb 無効時
+        pass
+
+
+# endregion
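`sample_images` above reads a `.txt` prompt file (blank lines and `#` comments skipped) and, with multiple devices, deals prompts round-robin via `prompts[i::num_processes]` so image enumeration roughly tracks creation order. Both steps in isolation; the helper names are illustrative:

```python
def read_txt_prompts(text):
    # one prompt per line; skip blanks and lines starting with '#'
    lines = text.splitlines()
    return [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]


def split_round_robin(prompts, num_processes):
    # process i gets prompts i, i+N, i+2N, ...
    return [prompts[i::num_processes] for i in range(num_processes)]


text = "a cat\n# comment\n\nb dog\nc bird\n"
prompts = read_txt_prompts(text)
per_process = split_round_robin(prompts, 2)
```

With two processes, prompt 0 and 2 go to process 0 and prompt 1 to process 1, matching `distributed_state.split_between_processes` above.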
def add_effnet_arguments(parser):
    parser.add_argument(
        "--effnet_checkpoint_path",
        type=str,
        required=True,
        help="path to EfficientNet checkpoint / EfficientNetのチェックポイントのパス",
    )
    return parser


def add_text_model_arguments(parser):
    parser.add_argument(
        "--text_model_checkpoint_path",
        type=str,
        help="path to CLIP text model checkpoint / CLIPテキストモデルのチェックポイントのパス",
    )
    parser.add_argument("--save_text_model", action="store_true", help="if specified, save text model to corresponding path")
    return parser


def add_stage_a_arguments(parser):
    parser.add_argument(
        "--stage_a_checkpoint_path",
        type=str,
        required=True,
        help="path to Stage A checkpoint / Stage Aのチェックポイントのパス",
    )
    return parser


def add_stage_b_arguments(parser):
    parser.add_argument(
        "--stage_b_checkpoint_path",
        type=str,
        required=True,
        help="path to Stage B checkpoint / Stage Bのチェックポイントのパス",
    )
    return parser


def add_stage_c_arguments(parser):
    parser.add_argument(
        "--stage_c_checkpoint_path",
        type=str,
        required=True,
        help="path to Stage C checkpoint / Stage Cのチェックポイントのパス",
    )
    return parser


def add_previewer_arguments(parser):
    parser.add_argument(
        "--previewer_checkpoint_path",
        type=str,
        required=False,
        help="path to previewer checkpoint / previewerのチェックポイントのパス",
    )
    return parser


def add_training_arguments(parser):
    parser.add_argument(
        "--adaptive_loss_weight",
        action="store_true",
        help="if specified, use adaptive loss weight. if not, use P2 loss weight"
        + " / Adaptive Loss Weightを使用する。指定しない場合はP2 Loss Weightを使用する",
    )
@@ -133,6 +133,7 @@ IMAGE_TRANSFORMS = transforms.Compose(
)

TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
+STABLE_CASCADE_LATENTS_CACHE_SUFFIX = "_sc_latents.npz"


class ImageInfo:
@@ -856,7 +857,7 @@ class BaseDataset(torch.utils.data.Dataset):
        logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")

        # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
-        self.buckets_indices: List(BucketBatchIndex) = []
+        self.buckets_indices: List[BucketBatchIndex] = []
        for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
            batch_count = int(math.ceil(len(bucket) / self.batch_size))
            for batch_index in range(batch_count):
@@ -910,8 +911,8 @@ class BaseDataset(torch.utils.data.Dataset):
            ]
        )

-    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
-        # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
+    def cache_latents(self, vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor):
+        # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
        logger.info("caching latents.")

        image_infos = list(self.image_data.values())
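`cache_latents` groups images into VAE batches, flushing whenever the resolution changes or the batch is full, so every batch the encoder sees is same-sized. The grouping logic in isolation, with plain `(name, resolution)` tuples standing in for `ImageInfo` objects:

```python
def make_batches(infos, batch_size):
    """infos: list of (name, (w, h)) pairs already sorted by resolution.
    Returns batches of same-resolution items, each at most batch_size long."""
    batches, batch = [], []
    for info in infos:
        # flush when the resolution changes
        if len(batch) > 0 and batch[-1][1] != info[1]:
            batches.append(batch)
            batch = []
        batch.append(info)
        # flush when the batch is full
        if len(batch) >= batch_size:
            batches.append(batch)
            batch = []
    if len(batch) > 0:
        batches.append(batch)
    return batches


infos = [("a", (512, 512)), ("b", (512, 512)), ("c", (512, 512)), ("d", (640, 512))]
batches = make_batches(infos, 2)
```

Sorting by resolution first (as the dataset does) keeps same-resolution images adjacent, so the resolution-change flush wastes as little batch capacity as possible.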
@@ -931,11 +932,11 @@ class BaseDataset(torch.utils.data.Dataset):

            # check disk cache exists and size of latents
            if cache_to_disk:
-                info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
+                info.latents_npz = os.path.splitext(info.absolute_path)[0] + cache_file_suffix
                if not is_main_process:  # store to info only
                    continue

-                cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug)
+                cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug, divisor)

                if cache_available:  # do not add to batch
                    continue
@@ -967,9 +968,13 @@ class BaseDataset(torch.utils.data.Dataset):
    # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する
    # SD1/2に対応するにはv2のフラグを持つ必要があるので後回し
    def cache_text_encoder_outputs(
-        self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
+        self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process, cache_file_suffix
    ):
-        assert len(tokenizers) == 2, "only support SDXL"
+        """
+        最後の Text Encoder の pool がキャッシュされる。
+        The last Text Encoder's pool is cached.
+        """
+        # assert len(tokenizers) == 2, "only support SDXL"

        # latentsのキャッシュと同様に、ディスクへのキャッシュに対応する
        # またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
@@ -981,7 +986,7 @@ class BaseDataset(torch.utils.data.Dataset):
        for info in tqdm(image_infos):
            # subset = self.image_to_subset[info.image_key]
            if cache_to_disk:
-                te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
+                te_out_npz = os.path.splitext(info.absolute_path)[0] + cache_file_suffix
                info.text_encoder_outputs_npz = te_out_npz

                if not is_main_process:  # store to info only
@@ -1006,7 +1011,7 @@ class BaseDataset(torch.utils.data.Dataset):
        batches = []
        for info in image_infos_to_cache:
            input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
-            input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
+            input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) if len(tokenizers) > 1 else None
            batch.append((info, input_ids1, input_ids2))

            if len(batch) >= self.batch_size:
@@ -1021,7 +1026,7 @@ class BaseDataset(torch.utils.data.Dataset):
        for batch in tqdm(batches):
            infos, input_ids1, input_ids2 = zip(*batch)
            input_ids1 = torch.stack(input_ids1, dim=0)
-            input_ids2 = torch.stack(input_ids2, dim=0)
+            input_ids2 = torch.stack(input_ids2, dim=0) if input_ids2[0] is not None else None
            cache_batch_text_encoder_outputs(
                infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype
            )
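The hunks above make the second tokenizer optional: each batch entry carries `None` in the `input_ids2` slot for single-encoder models, and downstream code stacks that slot only when present. The pattern with plain lists standing in for tensors (`collate` is an illustrative name):

```python
def collate(batch):
    # batch: list of (info, ids1, ids2) where ids2 may be None for single-encoder models
    infos, input_ids1, input_ids2 = zip(*batch)
    stacked1 = list(input_ids1)  # stand-in for torch.stack(input_ids1, dim=0)
    # checking element 0 is enough: a batch is either all-None or all-present
    stacked2 = list(input_ids2) if input_ids2[0] is not None else None
    return infos, stacked1, stacked2


single = [("img0", [1, 2], None), ("img1", [3, 4], None)]
dual = [("img0", [1, 2], [5, 6]), ("img1", [3, 4], [7, 8])]
_, s1, s2 = collate(single)
_, d1, d2 = collate(dual)
```

Propagating `None` instead of a dummy tensor lets later stages (caching, collation, debug printing) branch on a simple `is not None` test.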
@@ -1270,7 +1275,9 @@ class BaseDataset(torch.utils.data.Dataset):
            # example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions])
            # example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions])
            example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
-            example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
+            example["text_encoder_outputs2_list"] = (
+                torch.stack(text_encoder_outputs2_list) if text_encoder_outputs2_list[0] is not None else None
+            )
            example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)

        if images[0] is not None:
@@ -1599,12 +1606,15 @@ class FineTuningDataset(BaseDataset):

                # なければnpzを探す
                if abs_path is None:
-                    if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
-                        abs_path = os.path.splitext(image_key)[0] + ".npz"
-                    else:
-                        npz_path = os.path.join(subset.image_dir, image_key + ".npz")
-                        if os.path.exists(npz_path):
-                            abs_path = npz_path
+                    abs_path = os.path.splitext(image_key)[0] + ".npz"
+                    if not os.path.exists(abs_path):
+                        abs_path = os.path.splitext(image_key)[0] + STABLE_CASCADE_LATENTS_CACHE_SUFFIX
+                        if not os.path.exists(abs_path):
+                            abs_path = os.path.join(subset.image_dir, image_key + ".npz")
+                            if not os.path.exists(abs_path):
+                                abs_path = os.path.join(subset.image_dir, image_key + STABLE_CASCADE_LATENTS_CACHE_SUFFIX)
+                                if not os.path.exists(abs_path):
+                                    abs_path = None

                assert abs_path is not None, f"no image / 画像がありません: {image_key}"
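The rewritten lookup above tries a chain of candidate cache paths (`.npz`, then the Stable Cascade suffix, each both next to the image and under `image_dir`) and falls back to `None`. A flat sketch of the same chain; `find_npz` and the injectable `exists` parameter are illustrative, added so the chain can be exercised without touching disk:

```python
import os

STABLE_CASCADE_LATENTS_CACHE_SUFFIX = "_sc_latents.npz"


def find_npz(image_key, image_dir, exists=os.path.exists):
    """Return the first existing candidate cache path, or None."""
    base = os.path.splitext(image_key)[0]
    candidates = [
        base + ".npz",                                                   # image_key is a full path
        base + STABLE_CASCADE_LATENTS_CACHE_SUFFIX,
        os.path.join(image_dir, image_key + ".npz"),                     # image_key is relative
        os.path.join(image_dir, image_key + STABLE_CASCADE_LATENTS_CACHE_SUFFIX),
    ]
    for path in candidates:
        if exists(path):
            return path
    return None


on_disk = {os.path.join("data", "cat.png" + STABLE_CASCADE_LATENTS_CACHE_SUFFIX)}
found = find_npz("cat.png", "data", exists=lambda p: p in on_disk)
```

Flattening the nested `if not os.path.exists(...)` ladder into an ordered candidate list preserves the priority order while making it easy to add further suffixes.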
@@ -1624,7 +1634,7 @@ class FineTuningDataset(BaseDataset):

            if not subset.color_aug and not subset.random_crop:
                # if npz exists, use them
-                image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
+                image_info.latents_npz = self.image_key_to_npz_file(subset, image_key)

            self.register_image(image_info, subset)

@@ -1638,7 +1648,7 @@ class FineTuningDataset(BaseDataset):
        # check existence of all npz files
        use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
        if use_npz_latents:
-            flip_aug_in_subset = False
+            # flip_aug_in_subset = False
            npz_any = False
            npz_all = True

@@ -1648,9 +1658,12 @@ class FineTuningDataset(BaseDataset):
                has_npz = image_info.latents_npz is not None
                npz_any = npz_any or has_npz

-                if subset.flip_aug:
-                    has_npz = has_npz and image_info.latents_npz_flipped is not None
-                    flip_aug_in_subset = True
+                # flip は同一の .npz 内に格納するようにした:
+                # そのためここでチェック漏れがあり実行時にエラーになる可能性があるので要検討
+                # if subset.flip_aug:
+                #     has_npz = has_npz and image_info.latents_npz_flipped is not None
+                #     flip_aug_in_subset = True

                npz_all = npz_all and has_npz

            if npz_any and not npz_all:
@@ -1664,8 +1677,8 @@ class FineTuningDataset(BaseDataset):
                logger.warning(
                    f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します"
                )
-                if flip_aug_in_subset:
-                    logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
+                # if flip_aug_in_subset:
+                #     logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
        # else:
        #     logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")

@@ -1714,34 +1727,29 @@ class FineTuningDataset(BaseDataset):
        # npz情報をきれいにしておく
        if not use_npz_latents:
            for image_info in self.image_data.values():
-                image_info.latents_npz = image_info.latents_npz_flipped = None
+                image_info.latents_npz = None  # image_info.latents_npz_flipped =

    def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
        base_name = os.path.splitext(image_key)[0]
-        npz_file_norm = base_name + ".npz"

+        npz_file_norm = base_name + ".npz"
+        if not os.path.exists(npz_file_norm):
+            npz_file_norm = base_name + STABLE_CASCADE_LATENTS_CACHE_SUFFIX
        if os.path.exists(npz_file_norm):
            # image_key is full path
-            npz_file_flip = base_name + "_flip.npz"
-            if not os.path.exists(npz_file_flip):
-                npz_file_flip = None
-            return npz_file_norm, npz_file_flip
+            return npz_file_norm

        # if not full path, check image_dir. if image_dir is None, return None
        if subset.image_dir is None:
-            return None, None
+            return None

        # image_key is relative path
        npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
-        npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")

        if not os.path.exists(npz_file_norm):
-            npz_file_norm = None
-            npz_file_flip = None
-        elif not os.path.exists(npz_file_flip):
-            npz_file_flip = None
+            npz_file_norm = os.path.join(subset.image_dir, base_name + STABLE_CASCADE_LATENTS_CACHE_SUFFIX)
+        if os.path.exists(npz_file_norm):
+            return npz_file_norm

-        return npz_file_norm, npz_file_flip
+        return None
class ControlNetDataset(BaseDataset):
@@ -1943,17 +1951,26 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
        for dataset in self.datasets:
            dataset.enable_XTI(*args, **kwargs)

-    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
+    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, cache_file_suffix=".npz", divisor=8):
        for i, dataset in enumerate(self.datasets):
            logger.info(f"[Dataset {i}]")
-            dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
+            dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor)

    def cache_text_encoder_outputs(
-        self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
+        self,
+        tokenizers,
+        text_encoders,
+        device,
+        weight_dtype,
+        cache_to_disk=False,
+        is_main_process=True,
+        cache_file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
    ):
        for i, dataset in enumerate(self.datasets):
            logger.info(f"[Dataset {i}]")
-            dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)
+            dataset.cache_text_encoder_outputs(
+                tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process, cache_file_suffix
+            )

    def set_caching_mode(self, caching_mode):
        for dataset in self.datasets:
@@ -1986,8 +2003,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
            dataset.disable_token_padding()


-def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
-    expected_latents_size = (reso[1] // 8, reso[0] // 8)  # bucket_resoはWxHなので注意
+def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, divisor: int = 8) -> bool:
+    expected_latents_size = (reso[1] // divisor, reso[0] // divisor)  # bucket_resoはWxHなので注意

    if not os.path.exists(npz_path):
        return False
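`is_disk_cached_latents_is_expected` now takes a `divisor` (8 for the SD VAE; the EfficientNetEncoder uses a different stride, whose exact value is not shown in this hunk) and compares the cached array against the expected latent shape. Note that `bucket_reso` is (W, H) while latents are stored (H, W). The size computation in isolation (`expected_latents_size` is an illustrative name):

```python
def expected_latents_size(bucket_reso, divisor=8):
    # bucket_reso is (width, height); latents are stored (height, width)
    w, h = bucket_reso
    return (h // divisor, w // divisor)


size8 = expected_latents_size((1024, 768))
size32 = expected_latents_size((1024, 768), 32)
```

Passing the divisor through `DatasetGroup.cache_latents` lets the same validity check serve both the VAE and the EfficientNetEncoder caches.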
@@ -2079,7 +2096,7 @@ def debug_dataset(train_dataset, show_input_ids=False):

                if show_input_ids:
                    logger.info(f"input ids: {iid}")
-                    if "input_ids2" in example:
+                    if "input_ids2" in example and example["input_ids2"] is not None:
                        logger.info(f"input ids2: {example['input_ids2'][j]}")
                if example["images"] is not None:
                    im = example["images"][j]
@@ -2256,7 +2273,7 @@ def trim_and_resize_if_required(


def cache_batch_latents(
-    vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
+    vae: Union[AutoencoderKL, torch.nn.Module], cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
) -> None:
    r"""
    requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
@@ -2311,23 +2328,36 @@ def cache_batch_text_encoder_outputs(
    image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
):
    input_ids1 = input_ids1.to(text_encoders[0].device)
-    input_ids2 = input_ids2.to(text_encoders[1].device)
+    input_ids2 = input_ids2.to(text_encoders[1].device) if input_ids2 is not None else None

    with torch.no_grad():
-        b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
-            max_token_length,
-            input_ids1,
-            input_ids2,
-            tokenizers[0],
-            tokenizers[1],
-            text_encoders[0],
-            text_encoders[1],
-            dtype,
-        )
+        # TODO SDXL と Stable Cascade で統一する
+        if len(tokenizers) == 1:
+            # Stable Cascade
+            b_hidden_state1, b_pool2 = get_hidden_states_stable_cascade(
+                max_token_length, input_ids1, tokenizers[0], text_encoders[0], dtype
+            )
+
+            b_hidden_state1 = b_hidden_state1.detach().to("cpu")  # b,n*75+2,768
+            b_pool2 = b_pool2.detach().to("cpu")  # b,1280
+
+            b_hidden_state2 = [None] * input_ids1.shape[0]
+        else:
+            # SDXL
+            b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
+                max_token_length,
+                input_ids1,
+                input_ids2,
+                tokenizers[0],
+                tokenizers[1],
+                text_encoders[0],
+                text_encoders[1],
+                dtype,
+            )

    # ここでcpuに移動しておかないと、上書きされてしまう
    b_hidden_state1 = b_hidden_state1.detach().to("cpu")  # b,n*75+2,768
-    b_hidden_state2 = b_hidden_state2.detach().to("cpu")  # b,n*75+2,1280
+    b_hidden_state2 = b_hidden_state2.detach().to("cpu") if b_hidden_state2[0] is not None else b_hidden_state2  # b,n*75+2,1280
    b_pool2 = b_pool2.detach().to("cpu")  # b,1280

    for info, hidden_state1, hidden_state2, pool2 in zip(image_infos, b_hidden_state1, b_hidden_state2, b_pool2):
@@ -2340,18 +2370,25 @@ def cache_batch_text_encoder_outputs(


def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
-    np.savez(
-        npz_path,
-        hidden_state1=hidden_state1.cpu().float().numpy() if hidden_state1 is not None else None,
-        hidden_state2=hidden_state2.cpu().float().numpy(),
-        pool2=pool2.cpu().float().numpy(),
-    )
+    save_kwargs = {
+        "hidden_state1": hidden_state1.cpu().float().numpy(),
+        "pool2": pool2.cpu().float().numpy(),
+    }
+    if hidden_state2 is not None:
+        save_kwargs["hidden_state2"] = hidden_state2.cpu().float().numpy()
+    np.savez(npz_path, **save_kwargs)
+    # np.savez(
+    #     npz_path,
+    #     hidden_state1=hidden_state1.cpu().float().numpy(),
+    #     hidden_state2=hidden_state2.cpu().float().numpy() if hidden_state2 is not None else None,
+    #     pool2=pool2.cpu().float().numpy(),
+    # )


def load_text_encoder_outputs_from_disk(npz_path):
    with np.load(npz_path) as f:
        hidden_state1 = torch.from_numpy(f["hidden_state1"])
-        hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f else None
+        hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f and f["hidden_state2"] is not None else None
        pool2 = torch.from_numpy(f["pool2"]) if "pool2" in f else None
        return hidden_state1, hidden_state2, pool2
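The rewritten saver builds the `np.savez` kwargs dict and includes `hidden_state2` only when it exists: passing `None` to `np.savez` stores it awkwardly (as a 0-d object array), so omitting the key and checking `"hidden_state2" in f` at load time is cleaner. The kwarg assembly with plain lists standing in for numpy arrays (`build_save_kwargs` is an illustrative name):

```python
def build_save_kwargs(hidden_state1, hidden_state2, pool2):
    # the optional second hidden state is added conditionally, so the
    # loader can branch on key presence instead of decoding a None entry
    save_kwargs = {"hidden_state1": hidden_state1, "pool2": pool2}
    if hidden_state2 is not None:
        save_kwargs["hidden_state2"] = hidden_state2
    return save_kwargs


sdxl_kwargs = build_save_kwargs([1.0], [2.0], [3.0])      # two text encoders
cascade_kwargs = build_save_kwargs([1.0], None, [3.0])    # single text encoder
```

The same `.npz` loader then serves both SDXL (two hidden states) and Stable Cascade (one) caches.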
@@ -2706,6 +2743,7 @@ def add_tokenizer_arguments(parser: argparse.ArgumentParser):
         help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
     )


 def add_sd_models_arguments(parser: argparse.ArgumentParser):
     # for pretrained models
     parser.add_argument(
@@ -3207,7 +3245,7 @@ def verify_training_args(args: argparse.Namespace):
         print("highvram is enabled / highvramが有効です")
         global HIGH_VRAM
         HIGH_VRAM = True

     if args.cache_latents_to_disk and not args.cache_latents:
         args.cache_latents = True
         logger.warning(
@@ -3219,7 +3257,9 @@ def verify_training_args(args: argparse.Namespace):
         return

     if args.v_parameterization and not args.v2:
-        logger.warning("v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません")
+        logger.warning(
+            "v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません"
+        )
     if args.v2 and args.clip_skip is not None:
         logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
@@ -4302,6 +4342,54 @@ def get_hidden_states_sdxl(
     return hidden_states1, hidden_states2, pool2


+def get_hidden_states_stable_cascade(
+    max_token_length: int,
+    input_ids2: torch.Tensor,
+    tokenizer2: CLIPTokenizer,
+    text_encoder2: CLIPTextModel,
+    weight_dtype: Optional[str] = None,
+    accelerator: Optional[Accelerator] = None,
+):
+    # It's very awkward to have Stable Cascade code here, but it's easier to understand than organizing it in a strange way, so it stays as is for now.
+
+    # input_ids: b,n,77 -> b*n, 77
+    b_size = input_ids2.size()[0]
+    input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))  # batch_size*n, 77
+
+    # text_encoder2
+    enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
+    hidden_states2 = enc_out["hidden_states"][-1]  # ** last layer **
+
+    # pool2 = enc_out["text_embeds"]
+    unwrapped_text_encoder2 = text_encoder2 if accelerator is None else accelerator.unwrap_model(text_encoder2)
+    pool2 = pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
+
+    # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
+    n_size = 1 if max_token_length is None else max_token_length // 75
+    hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
+
+    if max_token_length is not None:
+        # bs*3, 77, 768 or 1024
+
+        # v2: collapse the n chunks of <BOS>...<EOS> <PAD> ... back into a single <BOS>...<EOS> <PAD> ... sequence; honestly not sure this implementation is right
+        states_list = [hidden_states2[:, 0].unsqueeze(1)]  # <BOS>
+        for i in range(1, max_token_length, tokenizer2.model_max_length):
+            chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2]  # from after <BOS> up to before the last token
+            states_list.append(chunk)  # from after <BOS> up to before <EOS>
+        states_list.append(hidden_states2[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
+        hidden_states2 = torch.cat(states_list, dim=1)
+
+        # use the first pool of every n
+        pool2 = pool2[::n_size]
+
+    if weight_dtype is not None:
+        # this is required for additional network training
+        hidden_states2 = hidden_states2.to(weight_dtype)
+
+    return hidden_states2, pool2


 def default_if_none(value, default):
     return default if value is None else value
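The chunk-stitching in `get_hidden_states_stable_cascade` keeps one leading `<BOS>`, the 75-token body of each 77-token chunk, and the final `<EOS>`/`<PAD>`, giving `n*75 + 2` tokens total. A numpy sketch of just that index arithmetic (the helper name is illustrative):

```python
import numpy as np


def stitch_chunks(hidden_states, max_token_length, chunk_len=77):
    # hidden_states: (b, n*chunk_len, dim), i.e. n encoded chunks of
    # <BOS> + 75 tokens + <EOS> laid end to end.
    pieces = [hidden_states[:, 0:1]]  # keep a single leading <BOS>
    for i in range(1, max_token_length, chunk_len):
        pieces.append(hidden_states[:, i : i + chunk_len - 2])  # 75-token chunk body
    pieces.append(hidden_states[:, -1:])  # keep the final <EOS> or <PAD>
    return np.concatenate(pieces, axis=1)


# Example: 3 chunks of 77 tokens -> max_token_length = 3*75 + 2 = 227
hs = np.zeros((2, 3 * 77, 1280), np.float32)
out = stitch_chunks(hs, max_token_length=227)
print(out.shape)
```

The output sequence length is `1 + 3*75 + 1 = 227`, matching the `b,n*75+2,1280` comments in the caching code.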
@@ -17,19 +17,6 @@ from library import train_util
 from library.sdxl_model_util import _load_state_dict_on_device


-def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
-    resolution_multiple = 42.67
-    latent_height = math.ceil(height / compression_factor_b)
-    latent_width = math.ceil(width / compression_factor_b)
-    stage_c_latent_shape = (batch_size, 16, latent_height, latent_width)
-
-    latent_height = math.ceil(height / compression_factor_a)
-    latent_width = math.ceil(width / compression_factor_a)
-    stage_b_latent_shape = (batch_size, 4, latent_height, latent_width)
-
-    return stage_c_latent_shape, stage_b_latent_shape
-
-
 def main(args):
     device = device_utils.get_preferred_device()
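The removed local `calculate_latent_sizes` (the call sites below now use the `sc_utils` version) maps the pixel resolution to latent shapes via the two compression factors. For the defaults, the arithmetic works out as follows:

```python
import math


def calculate_latent_sizes(height=1024, width=1024, batch_size=4,
                           compression_factor_b=42.67, compression_factor_a=4.0):
    # Stage C latents: 16 channels at roughly 1/42.67 of the pixel resolution
    stage_c_latent_shape = (batch_size, 16,
                            math.ceil(height / compression_factor_b),
                            math.ceil(width / compression_factor_b))
    # Stage B latents: 4 channels at 1/4 of the pixel resolution
    stage_b_latent_shape = (batch_size, 4,
                            math.ceil(height / compression_factor_a),
                            math.ceil(width / compression_factor_a))
    return stage_c_latent_shape, stage_b_latent_shape


c, b = calculate_latent_sizes(1024, 1024, batch_size=1)
print(c, b)  # (1, 16, 24, 24) (1, 4, 256, 256)
```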
@@ -55,8 +42,6 @@ def main(args):
     generator_b.eval().requires_grad_(False).to(loading_device)

     # CLIP encoders
     print(f"Loading CLIP text model")
     tokenizer = sc_utils.load_tokenizer(args)
     text_model = sc_utils.load_clip_text_model(
@@ -74,7 +59,7 @@ def main(args):

     caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee"
     height, width = 1024, 1024
-    stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=1)
+    stage_c_latent_shape, stage_b_latent_shape = sc_utils.calculate_latent_sizes(height, width, batch_size=1)

     # the mysterious gdf class
     gdf_c = sc.GDF(
@@ -106,13 +91,25 @@ def main(args):
     # extras_b.sampling_configs["t_start"] = 1.0

     # PREPARE CONDITIONS
-    cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model)
+    # cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model)
+    input_ids = tokenizer(
+        [caption], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
+    )["input_ids"].to(text_model.device)
+    cond_text, cond_pooled = train_util.get_hidden_states_stable_cascade(
+        tokenizer.model_max_length, input_ids, tokenizer, text_model
+    )
     cond_text = cond_text.to(device, dtype=dtype)
-    cond_pooled = cond_pooled.to(device, dtype=dtype)
+    cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)

-    uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model)
+    # uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model)
+    input_ids = tokenizer([""], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt")[
+        "input_ids"
+    ].to(text_model.device)
+    uncond_text, uncond_pooled = train_util.get_hidden_states_stable_cascade(
+        tokenizer.model_max_length, input_ids, tokenizer, text_model
+    )
     uncond_text = uncond_text.to(device, dtype=dtype)
-    uncond_pooled = uncond_pooled.to(device, dtype=dtype)
+    uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype)

     zero_img_emb = torch.zeros(1, 768, device=device)
@@ -140,13 +140,16 @@ def train(args):
     stage_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=weight_dtype, device=loading_device)
     text_encoder1 = sc_utils.load_clip_text_model(args.text_model_checkpoint_path, dtype=weight_dtype, device=loading_device)

+    if args.sample_at_first or args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
+        # Previewer is small enough to be loaded on CPU
+        previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=torch.float32, device="cpu")
+        previewer.eval()
+    else:
+        previewer = None
+
     # prepare for training
     if cache_latents:
         logger.info(
             "Please make sure that the latents are cached before training with `stable_cascade_cache_latents.py`."
             + " / 学習前に`stable_cascade_cache_latents.py`でlatentをキャッシュしてください。"
         )
-        # effnet.to(accelerator.device, dtype=effnet_dtype)
+        effnet.to(accelerator.device, dtype=effnet_dtype)
         effnet.requires_grad_(False)
         effnet.eval()
         with torch.no_grad():
@@ -155,7 +158,8 @@ def train(args):
                 args.vae_batch_size,
                 args.cache_latents_to_disk,
                 accelerator.is_main_process,
-                cache_func=sc_utils.cache_batch_latents,
+                train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX,
+                32,
             )
         effnet.to("cpu")
         clean_memory_on_device(accelerator.device)
@@ -164,8 +168,8 @@ def train(args):

     # prepare for training: put the models into the appropriate state
     if args.gradient_checkpointing:
-        logger.warn("Gradient checkpointing is not supported for stage_c. Ignoring the option.")
-        # stage_c.enable_gradient_checkpointing()
+        accelerator.print("enable gradient checkpointing")
+        stage_c.set_gradient_checkpointing(True)

     train_stage_c = args.learning_rate > 0
     train_text_encoder1 = False
@@ -176,9 +180,10 @@ def train(args):
         text_encoder1.gradient_checkpointing_enable()
         lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate  # 0 means not train
         train_text_encoder1 = lr_te1 > 0
-        assert train_text_encoder1, "text_encoder1 learning rate is 0. Please set a positive value / text_encoder1の学習率が0です。正の値を設定してください。"
+        assert (
+            train_text_encoder1
+        ), "text_encoder1 learning rate is 0. Please set a positive value / text_encoder1の学習率が0です。正の値を設定してください。"

     # caching one text encoder output is not supported
     if not train_text_encoder1:
         text_encoder1.to(weight_dtype)
         text_encoder1.requires_grad_(train_text_encoder1)
@@ -190,22 +195,16 @@ def train(args):

     # cache the Text Encoder outputs
     if args.cache_text_encoder_outputs:
-        raise NotImplementedError(
-            "Caching text encoder outputs is not supported in this version / text encoderの出力のキャッシュはサポートされていません"
-        )
+        print(
+            f"Please make sure that the text encoder outputs are cached before training with `stable_cascade_cache_text_encoder_outputs.py`."
+            + " / 学習前に`stable_cascade_cache_text_encoder_outputs.py`でtext encoderの出力をキャッシュしてください。"
+        )
         # Text Encoders are eval and no grad
         with torch.no_grad(), accelerator.autocast():
             train_dataset_group.cache_text_encoder_outputs(
-                (tokenizer),
-                (text_encoder1),
+                (tokenizer,),
+                (text_encoder1,),
                 accelerator.device,
                 None,
                 args.cache_text_encoder_outputs_to_disk,
                 accelerator.is_main_process,
+                sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
             )
         accelerator.wait_for_everyone()
@@ -339,20 +338,13 @@ def train(args):
         input_scaler=sc.VPScaler(),
         target=sc.EpsilonTarget(),
         noise_cond=sc.CosineTNoiseCond(),
-        loss_weight=sc.AdaptiveLossWeight(),
+        loss_weight=sc.AdaptiveLossWeight() if args.adaptive_loss_weight else sc.P2LossWeight(),
     )

     # these two variables appear to be left at their defaults
     # gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
     # gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])

-    # noise_scheduler = DDPMScheduler(
-    #     beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
-    # )
-    # prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
-    # if args.zero_terminal_snr:
-    #     custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
-
     if accelerator.is_main_process:
         init_kwargs = {}
         if args.wandb_run_name:
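The `--adaptive_loss_weight` switch selects Stability AI's `AdaptiveLossWeight`, which buckets samples by logSNR, tracks a running loss per bucket (`bucket_ranges` / `bucket_losses` above), and down-weights noise levels that currently have high loss. A rough numpy sketch of that idea (names, bucket count, and EMA rate are illustrative, not the official `gdf/loss_weights.py` implementation):

```python
import numpy as np


class AdaptiveLossWeightSketch:
    """Weight each sample by the inverse of the running loss in its logSNR bucket."""

    def __init__(self, logsnr_range=(-10.0, 10.0), buckets=300, ema=0.99):
        # bucket_ranges are the bin edges; bucket_losses the running losses
        self.bucket_ranges = np.linspace(*logsnr_range, buckets - 1)
        self.bucket_losses = np.ones(buckets, np.float32)
        self.ema = ema

    def update_buckets(self, logsnr, loss):
        # Exponential moving average of the loss observed in each bucket
        idx = np.digitize(logsnr, self.bucket_ranges)
        for i, l in zip(idx, loss):
            self.bucket_losses[i] = self.ema * self.bucket_losses[i] + (1 - self.ema) * l

    def __call__(self, logsnr):
        # Higher recent loss in a bucket -> lower weight for that noise level
        idx = np.digitize(logsnr, self.bucket_ranges)
        return 1.0 / np.clip(self.bucket_losses[idx], 1e-4, None)


w = AdaptiveLossWeightSketch()
w.update_buckets(np.array([0.0]), np.array([2.0]))  # observe a high loss near logSNR 0
print(w(np.array([0.0])), w(np.array([5.0])))  # updated bucket weighted below untouched one
```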
@@ -361,18 +353,8 @@ def train(args):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)

-    # # For --sample_at_first
-    # sdxl_train_util.sample_images(
-    #     accelerator,
-    #     args,
-    #     0,
-    #     global_step,
-    #     accelerator.device,
-    #     effnet,
-    #     [tokenizer1, tokenizer2],
-    #     [text_encoder1, text_encoder2],
-    #     stage_c,
-    # )
+    # For --sample_at_first
+    sc_utils.sample_images(accelerator, args, 0, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)

     loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
@@ -404,11 +386,20 @@ def train(args):
                     # TODO support weighted captions
                     input_ids1 = input_ids1.to(accelerator.device)
-                    encoder_hidden_states, pool = sc.get_clip_conditions(None, input_ids1, tokenizer, text_encoder1)
+                    # unwrap_model is fine for models not wrapped by accelerator
+                    encoder_hidden_states, pool = train_util.get_hidden_states_stable_cascade(
+                        args.max_token_length,
+                        input_ids1,
+                        tokenizer,
+                        text_encoder1,
+                        None if not args.full_fp16 else weight_dtype,
+                        accelerator,
+                    )
                 else:
                     encoder_hidden_states = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
                     pool = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)

+                pool = pool.unsqueeze(1)  # add extra dimension b,1280 -> b,1,1280
+
                 # FORWARD PASS
                 with torch.no_grad():
                     noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(latents, shift=1, loss_shift=1)
@@ -421,7 +412,8 @@ def train(args):
                 loss = torch.nn.functional.mse_loss(pred, target, reduction="none").mean(dim=[1, 2, 3])
                 loss_adjusted = (loss * loss_weight).mean()

-                gdf.loss_weight.update_buckets(logSNR, loss)
+                if args.adaptive_loss_weight:
+                    gdf.loss_weight.update_buckets(logSNR, loss)  # use loss instead of loss_adjusted

                 accelerator.backward(loss_adjusted)
                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
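The loss computation above reduces the MSE over the channel/height/width dimensions to get one loss per sample, then scales by the per-sample weight before averaging over the batch (the raw per-sample `loss`, not `loss_adjusted`, feeds the adaptive buckets). A numpy sketch of that reduction (the helper name is illustrative):

```python
import numpy as np


def weighted_mse(pred, target, loss_weight):
    # Per-sample MSE over channel/height/width -> shape (b,),
    # then a weighted mean over the batch.
    loss = ((pred - target) ** 2).mean(axis=(1, 2, 3))
    return loss, (loss * loss_weight).mean()


pred = np.zeros((2, 16, 24, 24), np.float32)
target = np.ones((2, 16, 24, 24), np.float32)
loss, loss_adjusted = weighted_mse(pred, target, np.array([1.0, 3.0]))
print(loss, loss_adjusted)  # per-sample losses are 1.0; weighted mean is 2.0
```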
@@ -439,17 +431,7 @@ def train(args):
                     progress_bar.update(1)
                     global_step += 1

-                    # sdxl_train_util.sample_images(
-                    #     accelerator,
-                    #     args,
-                    #     None,
-                    #     global_step,
-                    #     accelerator.device,
-                    #     effnet,
-                    #     [tokenizer1, tokenizer2],
-                    #     [text_encoder1, text_encoder2],
-                    #     stage_c,
-                    # )
+                    sc_utils.sample_images(accelerator, args, None, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)

                     # save the model every specified number of steps
                     if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -467,7 +449,7 @@ def train(args):
                             accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
                         )

-                current_loss = loss.detach().item()  # this is a mean, so batch size shouldn't matter
+                current_loss = loss_adjusted.detach().item()  # this is a mean, so batch size shouldn't matter
                 if args.logging_dir is not None:
                     logs = {"loss": current_loss}
                     train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
@@ -502,17 +484,7 @@ def train(args):
                     accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
                 )

-        # sdxl_train_util.sample_images(
-        #     accelerator,
-        #     args,
-        #     epoch + 1,
-        #     global_step,
-        #     accelerator.device,
-        #     effnet,
-        #     [tokenizer1, tokenizer2],
-        #     [text_encoder1, text_encoder2],
-        #     stage_c,
-        # )
+        sc_utils.sample_images(accelerator, args, epoch + 1, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)

     is_main_process = accelerator.is_main_process
     # if is_main_process:
@@ -540,6 +512,8 @@ def setup_parser() -> argparse.ArgumentParser:
     sc_utils.add_effnet_arguments(parser)
     sc_utils.add_stage_c_arguments(parser)
+    sc_utils.add_text_model_arguments(parser)
     sc_utils.add_previewer_arguments(parser)
     sc_utils.add_training_arguments(parser)
+    train_util.add_tokenizer_arguments(parser)
     train_util.add_dataset_arguments(parser, True, True, True)
     train_util.add_training_arguments(parser, False)
@@ -145,17 +145,17 @@ def cache_to_disk(args: argparse.Namespace) -> None:
             image_info.image = image
             image_info.bucket_reso = bucket_reso
             image_info.resized_size = resized_size
-            image_info.latents_npz = os.path.splitext(absolute_path)[0] + sc_utils.LATENTS_CACHE_SUFFIX
+            image_info.latents_npz = os.path.splitext(absolute_path)[0] + train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX

             if args.skip_existing:
-                if sc_utils.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
+                if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug, 32):
                     logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
                     continue

             image_infos.append(image_info)

         if len(image_infos) > 0:
-            sc_utils.cache_batch_latents(effnet, True, image_infos, flip_aug, random_crop, accelerator.device, effnet_dtype)
+            train_util.cache_batch_latents(effnet, True, image_infos, flip_aug, random_crop)

     accelerator.wait_for_everyone()
     accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
@@ -149,8 +149,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:

         if len(image_infos) > 0:
             b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
-            sc_utils.cache_batch_text_encoder_outputs(
-                image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, weight_dtype
+            train_util.cache_batch_text_encoder_outputs(
+                image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, None, weight_dtype
             )

     accelerator.wait_for_everyone()