Merge pull request #495 from kohya-ss/dev

dim from weights, fix multires noise, update gen script etc.
This commit is contained in:
Kohya S
2023-05-11 22:19:33 +09:00
committed by GitHub
11 changed files with 815 additions and 56 deletions

View File

@@ -138,6 +138,28 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History
### 11 May 2023, 2023/05/11
- Added an option `--dim_from_weights` to `train_network.py` to automatically determine the dim(rank) from the weight file. [PR #491](https://github.com/kohya-ss/sd-scripts/pull/491) Thanks to AI-Casanova!
- It is useful in combination with `resize_lora.py`. Please see the PR for details. A rough sketch of how the rank can be read back from a weight file is shown after this list.
- Fixed a bug where the noise resolution was incorrect with Multires noise. [PR #489](https://github.com/kohya-ss/sd-scripts/pull/489) Thanks to sdbds!
- Please see the PR for details.
- The image generation scripts can now use img2img and highres fix at the same time.
- Fixed a bug where the hint image of ControlNet was incorrectly BGR instead of RGB in the image generation scripts.
- Added support for a memory-efficient VAE to the image generation scripts.
- If you specify a number with the `--vae_slices` option, the memory-efficient VAE will be used. The maximum output size will be larger, but it will be slower. Please specify a value of about `16` or `32`.
- The implementation of the VAE is in `library/slicing_vae.py`.
- `train_network.py`にdim(rank)を重みファイルから自動決定するオプション`--dim_from_weights`が追加されました。 [PR #491](https://github.com/kohya-ss/sd-scripts/pull/491) AI-Casanova氏に感謝します。
- `resize_lora.py`と組み合わせると有用です。詳細はPRもご参照ください。
- Multires noiseでノイズ解像度が正しくない不具合が修正されました。 [PR #489](https://github.com/kohya-ss/sd-scripts/pull/489) sdbds氏に感謝します。
- 詳細は当該PRをご参照ください。
- 生成スクリプトでimg2imgとhighres fixを同時に使用できるようにしました。
- 生成スクリプトでControlNetのhint画像が誤ってBGRだったのをRGBに修正しました。
- 生成スクリプトで省メモリ化VAEを使えるよう機能追加しました。
- `--vae_slices`オプションに数値を指定すると、省メモリ化VAEを用います。出力可能な最大サイズが大きくなりますが、遅くなります。`16`または`32`程度の値を指定してください。
- VAEの実装は`library/slicing_vae.py`にあります。
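
The rank can be recovered from a saved LoRA file because each `lora_down.weight` tensor carries the rank as its first dimension. A minimal sketch of the idea, assuming the usual sd-scripts LoRA layout (this is not the exact code used in the PR, and the file name is only a placeholder):

```python
# Sketch only: infer the LoRA rank from a saved weight file, assuming every module
# stores a "lora_down.weight" tensor of shape (rank, ...).
from safetensors.torch import load_file

def guess_dim_from_weights(path: str):
    state_dict = load_file(path)  # e.g. "my_resized_lora.safetensors"
    dims = {k: v.shape[0] for k, v in state_dict.items() if k.endswith("lora_down.weight")}
    return max(dims.values()) if dims else None
```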
### 7 May 2023, 2023/05/07
- The documentation has been moved to the `docs` folder. If you have links, please change them.

View File

@@ -955,7 +955,7 @@ class PipelineLike:
if torch.cuda.is_available():
torch.cuda.empty_cache()
init_latents = []
for i in tqdm(range(0, batch_size, vae_batch_size)):
for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
init_latent_dist = self.vae.encode(
init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)
).latent_dist
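# (the clamp above keeps the VAE encode loop within the number of init images actually
#  provided, instead of assuming that len(init_image) always equals batch_size)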
@@ -2091,7 +2091,7 @@ def main(args):
dtype = torch.float32
highres_fix = args.highres_fix_scale is not None
assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
if args.v_parameterization and not args.v2:
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
@@ -2250,7 +2250,27 @@ def main(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps" is not taken into account
# create the pipeline (a copy adapted from the custom pipeline)
if args.vae_slices:
from library.slicing_vae import SlicingAutoencoderKL
sli_vae = SlicingAutoencoderKL(
act_fn="silu",
block_out_channels=(128, 256, 512, 512),
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
in_channels=3,
latent_channels=4,
layers_per_block=2,
norm_num_groups=32,
out_channels=3,
sample_size=512,
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
num_slices=args.vae_slices,
)
sli_vae.load_state_dict(vae.state_dict()) # copy the parameters from the original VAE
vae = sli_vae
del sli_vae
vae.to(dtype).to(device)
text_encoder.to(dtype).to(device)
unet.to(dtype).to(device)
if clip_model is not None:
@@ -2262,7 +2282,7 @@ def main(args):
if args.network_module:
networks = []
network_default_muls = []
network_pre_calc=args.network_pre_calc
network_pre_calc = args.network_pre_calc
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
@@ -2592,12 +2612,18 @@ def main(args):
# resize if an image size is specified via the options
if args.W is not None and args.H is not None:
# take highres fix into account
w, h = args.W, args.H
if highres_fix:
w = int(w * args.highres_fix_scale + 0.5)
h = int(h * args.highres_fix_scale + 0.5)
if init_images is not None:
print(f"resize img2img source images to {args.W}*{args.H}")
init_images = resize_images(init_images, (args.W, args.H))
print(f"resize img2img source images to {w}*{h}")
init_images = resize_images(init_images, (w, h))
if mask_images is not None:
print(f"resize img2img mask images to {args.W}*{args.H}")
mask_images = resize_images(mask_images, (args.W, args.H))
print(f"resize img2img mask images to {w}*{h}")
mask_images = resize_images(mask_images, (w, h))
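# e.g. with --W 1024 --H 1024, --highres_fix_scale 0.5 and img2img sources, the sources are
# now resized to 512*512 to match the 1st stage rather than to the final 1024*1024
# (assuming the usual usage where the scale gives the 1st-stage size as a fraction of W/H)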
regional_network = False
if networks and mask_images:
@@ -2671,13 +2697,15 @@ def main(args):
width_1st = width_1st - width_1st % 32
height_1st = height_1st - height_1st % 32
strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength
ext_1st = BatchDataExt(
width_1st,
height_1st,
args.highres_fix_steps,
ext.scale,
ext.negative_scale,
ext.strength,
strength_1st,
ext.network_muls,
ext.num_sub_prompts,
)
@@ -2827,7 +2855,7 @@ def main(args):
n.set_multiplier(m)
if regional_network:
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
if not regional_network and network_pre_calc:
for n in networks:
n.restore_weights()
@@ -3032,14 +3060,16 @@ def main(args):
if init_images is not None:
init_image = init_images[global_step % len(init_images)]
# for img2img, generate at the size of the source image; with highres fix the sources are already resized according to args.W, args.H and the scale, so ignore them here
# the sources are resized to a multiple of 32, so follow that behavior here as well
width, height = init_image.size
width = width - width % 32
height = height - height % 32
if width != init_image.size[0] or height != init_image.size[1]:
print(
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
)
if not highres_fix:
width, height = init_image.size
width = width - width % 32
height = height - height % 32
if width != init_image.size[0] or height != init_image.size[1]:
print(
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
)
if mask_images is not None:
mask_image = mask_images[global_step % len(mask_images)]
@@ -3141,6 +3171,13 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率",
)
parser.add_argument(
"--vae_slices",
type=int,
default=None,
help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しないデフォルト、指定すると遅くなる。16か32程度を推奨",
)
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
parser.add_argument(
"--sampler",
@@ -3218,7 +3255,9 @@ def setup_parser() -> argparse.ArgumentParser:
)
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument("--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する")
parser.add_argument(
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
)
parser.add_argument(
"--textual_inversion_embeddings",
type=str,
@@ -3276,6 +3315,12 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数"
)
parser.add_argument(
"--highres_fix_strength",
type=float,
default=None,
help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ",
)
parser.add_argument(
"--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する"
)
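# putting the new options together (hypothetical invocation; paths and values are placeholders):
#   python gen_img_diffusers.py --ckpt model.safetensors --outdir outputs --prompt "..." \
#       --image_path src.png --W 1024 --H 1024 --highres_fix_scale 0.5 --highres_fix_strength 0.6 \
#       --vae_slices 16
# i.e. img2img and highres fix combined, with the memory-efficient VAE enabled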

View File

@@ -346,14 +346,14 @@ def get_weighted_text_embeddings(
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.3):
b, c, w, h = noise.shape
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
for i in range(iterations):
r = random.random() * 2 + 2 # Rather than always going 2x,
w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
noise += u(torch.randn(b, c, w, h).to(device)) * discount**i
if w == 1 or h == 1:
wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
if wn == 1 or hn == 1:
break # Lowest resolution is 1x1
return noise / noise.std() # Scaled back to roughly unit variance
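# why the rename matters: in the old version w and h were overwritten every iteration, so the
# divisor r**i compounded and the pyramid collapsed to 1x1 almost immediately; wn/hn are always
# derived from the original size. A toy schedule with fixed r = 3 and base width 64 (ignoring
# the early break) shows the difference:
#   old: 64, 21, 2, 1, 1, 1   (each step divides the previous width by r**i)
#   new: 64, 21, 7, 2, 1, 1   (each step divides the base width by r**i)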

665
library/slicing_vae.py Normal file
View File

@@ -0,0 +1,665 @@
# Modified from Diffusers to reduce VRAM usage
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block, ResnetBlock2D
from diffusers.models.vae import DecoderOutput, Encoder, AutoencoderKLOutput, DiagonalGaussianDistribution
def slice_h(x, num_slices):
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
# Conv2dのpaddingの副作用を排除するために、両側にpad 1しながらHをスライスする
# works with either NCHW or NHWC
size = (x.shape[2] + num_slices - 1) // num_slices
sliced = []
for i in range(num_slices):
if i == 0:
sliced.append(x[:, :, : size + 1, :])
else:
end = size * (i + 1) + 1
if x.shape[2] - end < 3: # if the last slice is too small, use the rest of the tensor 最後が細すぎるとconv2dできないので全部使う
end = x.shape[2]
sliced.append(x[:, :, size * i - 1 : end, :])
if end >= x.shape[2]:
break
return sliced
def cat_h(sliced):
# concatenate the slices, excluding the padding rows
cat = []
for i, x in enumerate(sliced):
if i == 0:
cat.append(x[:, :, :-1, :])
elif i == len(sliced) - 1:
cat.append(x[:, :, 1:, :])
else:
cat.append(x[:, :, 1:-1, :])
del x
x = torch.cat(cat, dim=2)
return x
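# quick sanity check for the pair above: slicing with one row of overlap per side and then
# trimming that overlap in cat_h round-trips exactly, e.g. with H=20 and 4 slices:
#   x = torch.arange(2 * 3 * 20 * 5, dtype=torch.float32).reshape(2, 3, 20, 5)
#   assert torch.equal(cat_h(slice_h(x, 4)), x)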
def resblock_forward(_self, num_slices, input_tensor, temb):
assert _self.upsample is None and _self.downsample is None
assert _self.norm1.num_groups == _self.norm2.num_groups
assert temb is None
# make sure norms are on cpu
org_device = input_tensor.device
cpu_device = torch.device("cpu")
_self.norm1.to(cpu_device)
_self.norm2.to(cpu_device)
# move all tensors to CPU
input_tensor = input_tensor.to(cpu_device)
hidden_states = input_tensor
# somehow this seems to produce different results...
# def sliced_norm1(norm, x):
# num_div = 4 if up_block_idx <= 2 else x.shape[1] // norm.num_groups
# sliced_tensor = torch.chunk(x, num_div, dim=1)
# sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
# sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
# print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
# normed_tensor = []
# for i in range(num_div):
# n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
# normed_tensor.append(n)
# del n
# x = torch.cat(normed_tensor, dim=1)
# return num_div, x
# splitting the norm changes the result, so only this part is left unsliced; running it on GPU runs out of VRAM, so it is computed on CPU (fortunately not too slow even on CPU)
hidden_states = _self.norm1(hidden_states) # run on cpu
sliced = slice_h(hidden_states, num_slices)
del hidden_states
for i in range(len(sliced)):
x = sliced[i]
sliced[i] = None
# move only the part being processed to GPU (same pattern below)
x = x.to(org_device)
x = _self.nonlinearity(x)
x = _self.conv1(x)
x = x.to(cpu_device)
sliced[i] = x
del x
hidden_states = cat_h(sliced)
del sliced
hidden_states = _self.norm2(hidden_states) # run on cpu
sliced = slice_h(hidden_states, num_slices)
del hidden_states
for i in range(len(sliced)):
x = sliced[i]
sliced[i] = None
x = x.to(org_device)
x = _self.nonlinearity(x)
x = _self.dropout(x)
x = _self.conv2(x)
x = x.to(cpu_device)
sliced[i] = x
del x
hidden_states = cat_h(sliced)
del sliced
# make shortcut
if _self.conv_shortcut is not None:
sliced = list(torch.chunk(input_tensor, num_slices, dim=2)) # no padding in conv_shortcut パディングがないので普通にスライスする
del input_tensor
for i in range(len(sliced)):
x = sliced[i]
sliced[i] = None
x = x.to(org_device)
x = _self.conv_shortcut(x)
x = x.to(cpu_device)
sliced[i] = x
del x
input_tensor = torch.cat(sliced, dim=2)
del sliced
output_tensor = (input_tensor + hidden_states) / _self.output_scale_factor
output_tensor = output_tensor.to(org_device) # the next layer computes on the GPU
return output_tensor
class SlicingEncoder(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
double_z=True,
num_slices=2,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
self.mid_block = None
self.down_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=self.layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
add_downsample=not is_final_block,
resnet_eps=1e-6,
downsample_padding=0,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=None,
temb_channels=None,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
resnet_groups=norm_num_groups,
temb_channels=None,
)
self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # use Diffusers' xformers for now
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
# replace forward of ResBlocks
def wrapper(func, module, num_slices):
def forward(*args, **kwargs):
return func(module, num_slices, *args, **kwargs)
return forward
self.num_slices = num_slices
div = num_slices / (2 ** (len(self.down_blocks) - 1)) # deeper layers do not need as many slices, so reduce the count accordingly
# print(f"initial divisor: {div}")
if div >= 2:
div = int(div)
for resnet in self.mid_block.resnets:
resnet.forward = wrapper(resblock_forward, resnet, div)
# midblock doesn't have downsample
for i, down_block in enumerate(self.down_blocks[::-1]):
if div >= 2:
div = int(div)
# print(f"down block: {i} divisor: {div}")
for resnet in down_block.resnets:
resnet.forward = wrapper(resblock_forward, resnet, div)
if down_block.downsamplers is not None:
# print("has downsample")
for downsample in down_block.downsamplers:
downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
div *= 2
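# note: assigning resnet.forward / downsample.forward above makes nn.Module.__call__ dispatch
# into the sliced implementations with a fixed num_slices; since every downsample halves H,
# the slice count is halved per level so the deeper, smaller feature maps are split less finely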
def forward(self, x):
sample = x
del x
org_device = sample.device
cpu_device = torch.device("cpu")
# sample = self.conv_in(sample)
sample = sample.to(cpu_device)
sliced = slice_h(sample, self.num_slices)
del sample
for i in range(len(sliced)):
x = sliced[i]
sliced[i] = None
x = x.to(org_device)
x = self.conv_in(x)
x = x.to(cpu_device)
sliced[i] = x
del x
sample = cat_h(sliced)
del sliced
sample = sample.to(org_device)
# down
for down_block in self.down_blocks:
sample = down_block(sample)
# middle
sample = self.mid_block(sample)
# post-process
# this could also be memory-optimized, but it probably does not use much memory, so it is skipped
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
def downsample_forward(self, _self, num_slices, hidden_states):
assert hidden_states.shape[1] == _self.channels
assert _self.use_conv and _self.padding == 0
print("downsample forward", num_slices, hidden_states.shape)
org_device = hidden_states.device
cpu_device = torch.device("cpu")
hidden_states = hidden_states.to(cpu_device)
pad = (0, 1, 0, 1)
hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
# slice with even number because of stride 2
# strideが2なので偶数でスライスする
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
size = (hidden_states.shape[2] + num_slices - 1) // num_slices
size = size + 1 if size % 2 == 1 else size
sliced = []
for i in range(num_slices):
if i == 0:
sliced.append(hidden_states[:, :, : size + 1, :])
else:
end = size * (i + 1) + 1
if hidden_states.shape[2] - end < 4: # if the last slice is too small, use the rest of the tensor
end = hidden_states.shape[2]
sliced.append(hidden_states[:, :, size * i - 1 : end, :])
if end >= hidden_states.shape[2]:
break
del hidden_states
for i in range(len(sliced)):
x = sliced[i]
sliced[i] = None
x = x.to(org_device)
x = _self.conv(x)
x = x.to(cpu_device)
# this part looks different from the rest; blame Copilot
if i == 0:
hidden_states = x
else:
hidden_states = torch.cat([hidden_states, x], dim=2)
hidden_states = hidden_states.to(org_device)
# print("downsample forward done", hidden_states.shape)
return hidden_states
class SlicingDecoder(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
num_slices=2,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
self.mid_block = None
self.up_blocks = nn.ModuleList([])
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
resnet_groups=norm_num_groups,
temb_channels=None,
)
self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # use Diffusers' xformers for now
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
prev_output_channel=None,
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=None,
temb_channels=None,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
# replace forward of ResBlocks
def wrapper(func, module, num_slices):
def forward(*args, **kwargs):
return func(module, num_slices, *args, **kwargs)
return forward
self.num_slices = num_slices
div = num_slices / (2 ** (len(self.up_blocks) - 1))
print(f"initial divisor: {div}")
if div >= 2:
div = int(div)
for resnet in self.mid_block.resnets:
resnet.forward = wrapper(resblock_forward, resnet, div)
# midblock doesn't have upsample
for i, up_block in enumerate(self.up_blocks):
if div >= 2:
div = int(div)
# print(f"up block: {i} divisor: {div}")
for resnet in up_block.resnets:
resnet.forward = wrapper(resblock_forward, resnet, div)
if up_block.upsamplers is not None:
# print("has upsample")
for upsample in up_block.upsamplers:
upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
div *= 2
def forward(self, z):
sample = z
del z
sample = self.conv_in(sample)
# middle
sample = self.mid_block(sample)
# up
for i, up_block in enumerate(self.up_blocks):
sample = up_block(sample)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
# conv_out with slicing because of VRAM usage
# conv_outはとてもVRAM使うのでスライスして対応
org_device = sample.device
cpu_device = torch.device("cpu")
sample = sample.to(cpu_device)
sliced = slice_h(sample, self.num_slices)
del sample
for i in range(self.num_slices):
x = sliced[i]
sliced[i] = None
x = x.to(org_device)
x = self.conv_out(x)
x = x.to(cpu_device)
sliced[i] = x
sample = cat_h(sliced)
del sliced
sample = sample.to(org_device)
return sample
def upsample_forward(self, _self, num_slices, hidden_states, output_size=None):
assert hidden_states.shape[1] == _self.channels
assert _self.use_conv_transpose == False and _self.use_conv
org_dtype = hidden_states.dtype
org_device = hidden_states.device
cpu_device = torch.device("cpu")
hidden_states = hidden_states.to(cpu_device)
sliced = slice_h(hidden_states, num_slices)
del hidden_states
for i in range(num_slices):
x = sliced[i]
sliced[i] = None
x = x.to(org_device)
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
# https://github.com/pytorch/pytorch/issues/86679
# maybe this will be fixed in PyTorch 2...
if org_dtype == torch.bfloat16:
x = x.to(torch.float32)
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if org_dtype == torch.bfloat16:
x = x.to(org_dtype)
x = _self.conv(x)
# the tensor has been upsampled 2x, so the overlap padding to trim is 2 rows
if i == 0:
x = x[:, :, :-2, :]
elif i == num_slices - 1:
x = x[:, :, 2:, :]
else:
x = x[:, :, 2:-2, :]
x = x.to(cpu_device)
sliced[i] = x
del x
hidden_states = torch.cat(sliced, dim=2)
# print("us hidden_states", hidden_states.shape)
del sliced
hidden_states = hidden_states.to(org_device)
return hidden_states
class SlicingAutoencoderKL(ModelMixin, ConfigMixin):
r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
and Max Welling.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(64,)`): Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): TODO
"""
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
num_slices: int = 16,
):
super().__init__()
# pass init params to Encoder
self.encoder = SlicingEncoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
num_slices=num_slices,
)
# pass init params to Decoder
self.decoder = SlicingDecoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
num_slices=num_slices,
)
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
self.use_slicing = False
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
z = self.post_quant_conv(z)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
# note: this is slicing along the batch dimension (confusingly, unrelated to num_slices)
def enable_slicing(self):
r"""
Enable sliced VAE decoding.
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
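# because SlicingAutoencoderKL keeps AutoencoderKL's public interface (encode/decode/forward and
# the diffusers config mixins) and the same parameter layout, the generation script can build it
# with the SD VAE config and simply load_state_dict() the weights of the original VAE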

View File

@@ -94,7 +94,7 @@ def split(args):
filename, ext = os.path.splitext(args.save_to)
model_file_name = filename + f"-{new_rank:04d}{ext}"
print(f"\nsaving model to: {model_file_name}")
print(f"saving model to: {model_file_name}")
save_to_file(model_file_name, state_dict, new_metadata)

View File

@@ -73,7 +73,7 @@ class LoRAModule(torch.nn.Module):
class LoRAInfModule(LoRAModule):
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
self.org_module_ref = [org_module] # keep a reference so it can be accessed later
self.enabled = True
@@ -319,6 +319,35 @@ class LoRAInfModule(LoRAModule):
return out
def parse_block_lr_kwargs(nw_kwargs):
down_lr_weight = nw_kwargs.get("down_lr_weight", None)
mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
up_lr_weight = nw_kwargs.get("up_lr_weight", None)
# if none of the above is set, treat block-wise lr as disabled and return None
if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
return None, None, None
# extract learning rate weight for each block
if down_lr_weight is not None:
# if some parameters are not set, use zero
if "," in down_lr_weight:
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
if mid_lr_weight is not None:
mid_lr_weight = float(mid_lr_weight)
if up_lr_weight is not None:
if "," in up_lr_weight:
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
)
return down_lr_weight, mid_lr_weight, up_lr_weight
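# e.g. a --network_args entry like "down_lr_weight=1,1,...,0.5" is split on commas into a list of
# floats (empty fields fall back to 0.0), "mid_lr_weight=2" becomes the float 2.0, and anything not
# given stays None; get_block_lr_weight then post-processes all three with block_lr_zero_threshold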
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
if network_dim is None:
network_dim = 4 # default
@@ -337,9 +366,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
# block dim/alpha/lr
block_dims = kwargs.get("block_dims", None)
down_lr_weight = kwargs.get("down_lr_weight", None)
mid_lr_weight = kwargs.get("mid_lr_weight", None)
up_lr_weight = kwargs.get("up_lr_weight", None)
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
# if any of the above is specified, enable per-block dim (rank)
if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
@@ -351,22 +378,6 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
)
# extract learning rate weight for each block
if down_lr_weight is not None:
# if some parameters are not set, use zero
if "," in down_lr_weight:
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
if mid_lr_weight is not None:
mid_lr_weight = float(mid_lr_weight)
if up_lr_weight is not None:
if "," in up_lr_weight:
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
down_lr_weight, mid_lr_weight, up_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0))
)
# remove block dim/alpha without learning rate
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
@@ -634,6 +645,12 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
network = LoRANetwork(
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
)
# block lr
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
return network, weights_sd
@@ -835,7 +852,7 @@ class LoRANetwork(torch.nn.Module):
print(f"weights are merged")
# define per-block multipliers applied to the learning rate for layer-wise learning rates
# define per-block multipliers applied to the learning rate for layer-wise learning rates (the argument order is reversed, but leave it as is for now)
def set_block_lr_weight(
self,
up_lr_weight: List[float] = None,

View File

@@ -193,12 +193,12 @@ def merge(args):
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
print(f"\nsaving SD model to: {args.save_to}")
print(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae)
else:
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"\nsaving model to: {args.save_to}")
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)

View File

@@ -326,7 +326,7 @@ def resize(args):
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
print(f"\nsaving model to: {args.save_to}")
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)

View File

@@ -160,7 +160,7 @@ def merge(args):
new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
print(f"\nsaving model to: {args.save_to}")
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype)

View File

@@ -62,7 +62,7 @@ def load_control_net(v2, unet, model):
# make the weights loadable into the U-Net; the ControlNet is an SD-format state dict, so load it as such
is_difference = "difference" in ctrl_sd_sd
print("ControlNet: loading difference")
print("ControlNet: loading difference:", is_difference)
# some keys do not exist in the ControlNet, so first build the full set of SD-format keys from the current U-Net
# these also serve as the base weights for Transfer Control
@@ -123,7 +123,8 @@ def load_preprocess(prep_type: str):
def preprocess_ctrl_net_hint_image(image):
image = np.array(image).astype(np.float32) / 255.0
image = image[:, :, ::-1].copy() # rgb to bgr
# the ControlNet sample code uses cv2 (BGR), but images here are loaded via Gradio, so they are actually RGB
# image = image[:, :, ::-1].copy() # rgb to bgr
image = image[None].transpose(0, 3, 1, 2) # nchw
image = torch.from_numpy(image)
return image # 0 to 1

View File

@@ -176,7 +176,10 @@ def train(args):
net_kwargs[key] = value
# if a new network is added in future, add if ~ then blocks for each network (;'∀')
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
if args.dim_from_weights:
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
else:
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
if network is None:
return
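# typical flow, as described in the README: shrink an existing LoRA with resize_lora.py, then run
# training with --network_weights resized.safetensors --dim_from_weights (file name is a placeholder)
# so the network is rebuilt with the ranks stored in that file instead of --network_dim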
@@ -522,10 +525,11 @@ def train(args):
loss_total = 0.0
del train_dataset_group
# if hasattr(network, "on_step_start"):
# on_step_start = network.on_step_start
# else:
# on_step_start = lambda *args, **kwargs: None
# callback for step start
if hasattr(network, "on_step_start"):
on_step_start = network.on_step_start
else:
on_step_start = lambda *args, **kwargs: None
# function for saving/removing
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
@@ -560,7 +564,7 @@ def train(args):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(network):
# on_step_start(text_encoder, unet)
on_step_start(text_encoder, unet)
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
@@ -760,6 +764,11 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列"
)
parser.add_argument(
"--dim_from_weights",
action="store_true",
help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
)
return parser