From 76a2b14cdb65f04722d2d6c551e11ff24c355fbd Mon Sep 17 00:00:00 2001 From: AI-Casanova Date: Sat, 6 May 2023 20:06:02 +0000 Subject: [PATCH 1/8] Instantiate size_from_weights --- train_network.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 4c4cc281..1fe9d083 100644 --- a/train_network.py +++ b/train_network.py @@ -176,7 +176,32 @@ def train(args): net_kwargs[key] = value # if a new network is added in future, add if ~ then blocks for each network (;'∀') - network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) + if args.size_from_weights: + network, weights = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet) + if net_kwargs is not None: + down_lr_weight = net_kwargs.get("down_lr_weight", None) + mid_lr_weight = net_kwargs.get("mid_lr_weight", None) + up_lr_weight = net_kwargs.get("up_lr_weight", None) + if down_lr_weight is not None: + # if some parameters are not set, use zero + if "," in down_lr_weight: + down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] + + if mid_lr_weight is not None: + mid_lr_weight = float(mid_lr_weight) + + if up_lr_weight is not None: + if "," in up_lr_weight: + up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] + + down_lr_weight, mid_lr_weight, up_lr_weight = network_module.get_block_lr_weight( + down_lr_weight, mid_lr_weight, up_lr_weight, float(net_kwargs.get("block_lr_zero_threshold", 0.0)) + ) + + if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: + network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) + else: + network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) if network is None: return @@ -760,6 +785,10 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列" ) + parser.add_argument( + "--size_from_weights", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する" + ) + return parser From 1b4bdff331e596ea5f9288d80a056a861c1d7ad1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 10 May 2023 23:09:25 +0900 Subject: [PATCH 2/8] enable i2i with highres fix, add slicing VAE --- gen_img_diffusers.py | 79 +++-- library/slicing_vae.py | 665 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 727 insertions(+), 17 deletions(-) create mode 100644 library/slicing_vae.py diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 60a24972..99e94cae 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -955,7 +955,7 @@ class PipelineLike: if torch.cuda.is_available(): torch.cuda.empty_cache() init_latents = [] - for i in tqdm(range(0, batch_size, vae_batch_size)): + for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): init_latent_dist = self.vae.encode( init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0) ).latent_dist @@ -2091,7 +2091,7 @@ def main(args): dtype = torch.float32 highres_fix = args.highres_fix_scale is not None - assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" + # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / 
highres_fixはimg2imgと同時に使えません" if args.v_parameterization and not args.v2: print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") @@ -2250,7 +2250,27 @@ def main(args): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない # custom pipelineをコピったやつを生成する + if args.vae_slices: + from library.slicing_vae import SlicingAutoencoderKL + + sli_vae = SlicingAutoencoderKL( + act_fn="silu", + block_out_channels=(128, 256, 512, 512), + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], + in_channels=3, + latent_channels=4, + layers_per_block=2, + norm_num_groups=32, + out_channels=3, + sample_size=512, + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + num_slices=args.vae_slices, + ) + sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする + vae = sli_vae + del sli_vae vae.to(dtype).to(device) + text_encoder.to(dtype).to(device) unet.to(dtype).to(device) if clip_model is not None: @@ -2262,7 +2282,7 @@ def main(args): if args.network_module: networks = [] network_default_muls = [] - network_pre_calc=args.network_pre_calc + network_pre_calc = args.network_pre_calc for i, network_module in enumerate(args.network_module): print("import network module:", network_module) @@ -2592,12 +2612,18 @@ def main(args): # 画像サイズにオプション指定があるときはリサイズする if args.W is not None and args.H is not None: + # highres fix を考慮に入れる + w, h = args.W, args.H + if highres_fix: + w = int(w * args.highres_fix_scale + 0.5) + h = int(h * args.highres_fix_scale + 0.5) + if init_images is not None: - print(f"resize img2img source images to {args.W}*{args.H}") - init_images = resize_images(init_images, (args.W, args.H)) + print(f"resize img2img source images to {w}*{h}") + init_images = resize_images(init_images, (w, h)) if mask_images is not None: - print(f"resize img2img mask images to {args.W}*{args.H}") - mask_images = resize_images(mask_images, (args.W, args.H)) + print(f"resize img2img mask images to {w}*{h}") + mask_images = resize_images(mask_images, (w, h)) regional_network = False if networks and mask_images: @@ -2671,13 +2697,15 @@ def main(args): width_1st = width_1st - width_1st % 32 height_1st = height_1st - height_1st % 32 + strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength + ext_1st = BatchDataExt( width_1st, height_1st, args.highres_fix_steps, ext.scale, ext.negative_scale, - ext.strength, + strength_1st, ext.network_muls, ext.num_sub_prompts, ) @@ -2827,7 +2855,7 @@ def main(args): n.set_multiplier(m) if regional_network: n.set_current_generation(batch_size, num_sub_prompts, width, height, shared) - + if not regional_network and network_pre_calc: for n in networks: n.restore_weights() @@ -3032,14 +3060,16 @@ def main(args): if init_images is not None: init_image = init_images[global_step % len(init_images)] + # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する # 32単位に丸めたやつにresizeされるので踏襲する - width, height = init_image.size - width = width - width % 32 - height = height - height % 32 - if width != init_image.size[0] or height != init_image.size[1]: - print( - f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" - ) + if not highres_fix: + width, height = init_image.size + width = width - width % 32 + height = height - height % 32 + if width != init_image.size[0] or height != init_image.size[1]: + print( + f"img2img 
image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" + ) if mask_images is not None: mask_image = mask_images[global_step % len(mask_images)] @@ -3141,6 +3171,13 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", ) + parser.add_argument( + "--vae_slices", + type=int, + default=None, + help= + "number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨" + ) parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") parser.add_argument( "--sampler", @@ -3218,7 +3255,9 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") - parser.add_argument("--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する") + parser.add_argument( + "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" + ) parser.add_argument( "--textual_inversion_embeddings", type=str, @@ -3276,6 +3315,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数" ) + parser.add_argument( + "--highres_fix_strength", + type=float, + default=None, + help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", + ) parser.add_argument( "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する" ) diff --git a/library/slicing_vae.py b/library/slicing_vae.py new file mode 100644 index 00000000..084bff68 --- /dev/null +++ b/library/slicing_vae.py @@ -0,0 +1,665 @@ +# Modified from Diffusers to reduce VRAM usage + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
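+
+# How the code below saves memory: intermediate activations are split along
+# the H axis and processed one slice at a time, so only a fraction of each
+# feature map has to sit on the GPU at once. Each slice keeps a 1-pixel
+# overlap with its neighbors so that Conv2d padding does not create seams at
+# the slice boundaries (see slice_h/cat_h), while GroupNorm is computed
+# unsliced on the CPU, because normalizing per slice changes the results
+# (see the note inside resblock_forward).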
+from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput +from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block, ResnetBlock2D +from diffusers.models.vae import DecoderOutput, Encoder, AutoencoderKLOutput, DiagonalGaussianDistribution + + +def slice_h(x, num_slices): + # slice with pad 1 both sides: to eliminate side effect of padding of conv2d + # Conv2dのpaddingの副作用を排除するために、両側にpad 1しながらHをスライスする + # NCHWでもNHWCでもどちらでも動く + size = (x.shape[2] + num_slices - 1) // num_slices + sliced = [] + for i in range(num_slices): + if i == 0: + sliced.append(x[:, :, : size + 1, :]) + else: + end = size * (i + 1) + 1 + if x.shape[2] - end < 3: # if the last slice is too small, use the rest of the tensor 最後が細すぎるとconv2dできないので全部使う + end = x.shape[2] + sliced.append(x[:, :, size * i - 1 : end, :]) + if end >= x.shape[2]: + break + return sliced + + +def cat_h(sliced): + # padding分を除いて結合する + cat = [] + for i, x in enumerate(sliced): + if i == 0: + cat.append(x[:, :, :-1, :]) + elif i == len(sliced) - 1: + cat.append(x[:, :, 1:, :]) + else: + cat.append(x[:, :, 1:-1, :]) + del x + x = torch.cat(cat, dim=2) + return x + + +def resblock_forward(_self, num_slices, input_tensor, temb): + assert _self.upsample is None and _self.downsample is None + assert _self.norm1.num_groups == _self.norm2.num_groups + assert temb is None + + # make sure norms are on cpu + org_device = input_tensor.device + cpu_device = torch.device("cpu") + _self.norm1.to(cpu_device) + _self.norm2.to(cpu_device) + + # すべてのテンソルをCPUに移動する + input_tensor = input_tensor.to(cpu_device) + hidden_states = input_tensor + + # どうもこれは結果が異なるようだ…… + # def sliced_norm1(norm, x): + # num_div = 4 if up_block_idx <= 2 else x.shape[1] // norm.num_groups + # sliced_tensor = torch.chunk(x, num_div, dim=1) + # sliced_weight = torch.chunk(norm.weight, num_div, dim=0) + # sliced_bias = torch.chunk(norm.bias, num_div, dim=0) + # print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape) + # normed_tensor = [] + # for i in range(num_div): + # n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps) + # normed_tensor.append(n) + # del n + # x = torch.cat(normed_tensor, dim=1) + # return num_div, x + + # normを分割すると結果が変わるので、ここだけは分割しない。GPUで計算するとVRAMが足りなくなるので、CPUで計算する。幸いCPUでもそこまで遅くない + hidden_states = _self.norm1(hidden_states) # run on cpu + + sliced = slice_h(hidden_states, num_slices) + del hidden_states + + for i in range(len(sliced)): + x = sliced[i] + sliced[i] = None + + # 計算する部分だけGPUに移動する、以下同様 + x = x.to(org_device) + x = _self.nonlinearity(x) + x = _self.conv1(x) + x = x.to(cpu_device) + sliced[i] = x + del x + + hidden_states = cat_h(sliced) + del sliced + + hidden_states = _self.norm2(hidden_states) # run on cpu + + sliced = slice_h(hidden_states, num_slices) + del hidden_states + + for i in range(len(sliced)): + x = sliced[i] + sliced[i] = None + + x = x.to(org_device) + x = _self.nonlinearity(x) + x = _self.dropout(x) + x = _self.conv2(x) + x = x.to(cpu_device) + sliced[i] = x + del x + + hidden_states = cat_h(sliced) + del sliced + + # make shortcut + if _self.conv_shortcut is not None: + sliced = list(torch.chunk(input_tensor, num_slices, dim=2)) # no padding in conv_shortcut パディングがないので普通にスライスする + del input_tensor + 
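+        # the shortcut conv uses no padding, so the plain non-overlapping
+        # chunks above are safe; each slice is moved to the GPU only while it
+        # is being convolved, the same pattern as the main path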
+ for i in range(len(sliced)): + x = sliced[i] + sliced[i] = None + + x = x.to(org_device) + x = _self.conv_shortcut(x) + x = x.to(cpu_device) + sliced[i] = x + del x + + input_tensor = torch.cat(sliced, dim=2) + del sliced + + output_tensor = (input_tensor + hidden_states) / _self.output_scale_factor + + output_tensor = output_tensor.to(org_device) # 次のレイヤーがGPUで計算する + return output_tensor + + +class SlicingEncoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + double_z=True, + num_slices=2, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) + + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attn_num_head_channels=None, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=norm_num_groups, + temb_channels=None, + ) + self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # とりあえずDiffusersのxformersを使う + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + # replace forward of ResBlocks + def wrapper(func, module, num_slices): + def forward(*args, **kwargs): + return func(module, num_slices, *args, **kwargs) + + return forward + + self.num_slices = num_slices + div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす + # print(f"initial divisor: {div}") + if div >= 2: + div = int(div) + for resnet in self.mid_block.resnets: + resnet.forward = wrapper(resblock_forward, resnet, div) + # midblock doesn't have downsample + + for i, down_block in enumerate(self.down_blocks[::-1]): + if div >= 2: + div = int(div) + # print(f"down block: {i} divisor: {div}") + for resnet in down_block.resnets: + resnet.forward = wrapper(resblock_forward, resnet, div) + if down_block.downsamplers is not None: + # print("has downsample") + for downsample in down_block.downsamplers: + downsample.forward = wrapper(self.downsample_forward, downsample, div * 2) + div *= 2 + + def forward(self, x): + sample = x + del x + + org_device = sample.device + cpu_device = torch.device("cpu") + + # sample = self.conv_in(sample) + sample = sample.to(cpu_device) + sliced = slice_h(sample, self.num_slices) + del sample + + for i in range(len(sliced)): + x = sliced[i] + sliced[i] = None + + x = x.to(org_device) + x = self.conv_in(x) + x = x.to(cpu_device) + sliced[i] = x + del x + + sample = 
cat_h(sliced) + del sliced + + sample = sample.to(org_device) + + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + # ここも省メモリ化したいが、恐らくそこまでメモリを食わないので省略 + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + def downsample_forward(self, _self, num_slices, hidden_states): + assert hidden_states.shape[1] == _self.channels + assert _self.use_conv and _self.padding == 0 + print("downsample forward", num_slices, hidden_states.shape) + + org_device = hidden_states.device + cpu_device = torch.device("cpu") + + hidden_states = hidden_states.to(cpu_device) + pad = (0, 1, 0, 1) + hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0) + + # slice with even number because of stride 2 + # strideが2なので偶数でスライスする + # slice with pad 1 both sides: to eliminate side effect of padding of conv2d + size = (hidden_states.shape[2] + num_slices - 1) // num_slices + size = size + 1 if size % 2 == 1 else size + + sliced = [] + for i in range(num_slices): + if i == 0: + sliced.append(hidden_states[:, :, : size + 1, :]) + else: + end = size * (i + 1) + 1 + if hidden_states.shape[2] - end < 4: # if the last slice is too small, use the rest of the tensor + end = hidden_states.shape[2] + sliced.append(hidden_states[:, :, size * i - 1 : end, :]) + if end >= hidden_states.shape[2]: + break + del hidden_states + + for i in range(len(sliced)): + x = sliced[i] + sliced[i] = None + + x = x.to(org_device) + x = _self.conv(x) + x = x.to(cpu_device) + + # ここだけ雰囲気が違うのはCopilotのせい + if i == 0: + hidden_states = x + else: + hidden_states = torch.cat([hidden_states, x], dim=2) + + hidden_states = hidden_states.to(org_device) + # print("downsample forward done", hidden_states.shape) + return hidden_states + + +class SlicingDecoder(nn.Module): + def __init__( + self, + in_channels=3, + out_channels=3, + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(64,), + layers_per_block=2, + norm_num_groups=32, + act_fn="silu", + num_slices=2, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attn_num_head_channels=None, + resnet_groups=norm_num_groups, + temb_channels=None, + ) + self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # とりあえずDiffusersのxformersを使う + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attn_num_head_channels=None, + temb_channels=None, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = 
nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + # replace forward of ResBlocks + def wrapper(func, module, num_slices): + def forward(*args, **kwargs): + return func(module, num_slices, *args, **kwargs) + + return forward + + self.num_slices = num_slices + div = num_slices / (2 ** (len(self.up_blocks) - 1)) + print(f"initial divisor: {div}") + if div >= 2: + div = int(div) + for resnet in self.mid_block.resnets: + resnet.forward = wrapper(resblock_forward, resnet, div) + # midblock doesn't have upsample + + for i, up_block in enumerate(self.up_blocks): + if div >= 2: + div = int(div) + # print(f"up block: {i} divisor: {div}") + for resnet in up_block.resnets: + resnet.forward = wrapper(resblock_forward, resnet, div) + if up_block.upsamplers is not None: + # print("has upsample") + for upsample in up_block.upsamplers: + upsample.forward = wrapper(self.upsample_forward, upsample, div * 2) + div *= 2 + + def forward(self, z): + sample = z + del z + sample = self.conv_in(sample) + + # middle + sample = self.mid_block(sample) + + # up + for i, up_block in enumerate(self.up_blocks): + sample = up_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + # conv_out with slicing because of VRAM usage + # conv_outはとてもVRAM使うのでスライスして対応 + org_device = sample.device + cpu_device = torch.device("cpu") + sample = sample.to(cpu_device) + + sliced = slice_h(sample, self.num_slices) + del sample + for i in range(self.num_slices): + x = sliced[i] + sliced[i] = None + + x = x.to(org_device) + x = self.conv_out(x) + x = x.to(cpu_device) + sliced[i] = x + sample = cat_h(sliced) + del sliced + + sample = sample.to(org_device) + return sample + + def upsample_forward(self, _self, num_slices, hidden_states, output_size=None): + assert hidden_states.shape[1] == _self.channels + assert _self.use_conv_transpose == False and _self.use_conv + + org_dtype = hidden_states.dtype + org_device = hidden_states.device + cpu_device = torch.device("cpu") + + hidden_states = hidden_states.to(cpu_device) + sliced = slice_h(hidden_states, num_slices) + del hidden_states + + for i in range(num_slices): + x = sliced[i] + sliced[i] = None + + x = x.to(org_device) + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + # PyTorch 2で直らないかね…… + if org_dtype == torch.bfloat16: + x = x.to(torch.float32) + + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + + if org_dtype == torch.bfloat16: + x = x.to(org_dtype) + + x = _self.conv(x) + + # upsampleされてるのでpadは2になる + if i == 0: + x = x[:, :, :-2, :] + elif i == num_slices - 1: + x = x[:, :, 2:, :] + else: + x = x[:, :, 2:-2, :] + + x = x.to(cpu_device) + sliced[i] = x + del x + + hidden_states = torch.cat(sliced, dim=2) + # print("us hidden_states", hidden_states.shape) + del sliced + + hidden_states = hidden_states.to(org_device) + return hidden_states + + +class SlicingAutoencoderKL(ModelMixin, ConfigMixin): + r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma + and Max Welling. + + This model inherits from [`ModelMixin`]. 
Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): TODO + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + num_slices: int = 16, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = SlicingEncoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + num_slices=num_slices, + ) + + # pass init params to Decoder + self.decoder = SlicingDecoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + num_slices=num_slices, + ) + + self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + self.use_slicing = False + + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + # これはバッチ方向のスライシング 紛らわしい + def enable_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing + decoding in one step. 
+ """ + self.use_slicing = False + + def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) From ed5bfda3723b857fb64260b0bc44b77768de7339 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 11 May 2023 21:12:06 +0900 Subject: [PATCH 3/8] Fix controlnet input to rgb from bgr --- tools/original_control_net.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tools/original_control_net.py b/tools/original_control_net.py index 4484ce9c..582794de 100644 --- a/tools/original_control_net.py +++ b/tools/original_control_net.py @@ -62,7 +62,7 @@ def load_control_net(v2, unet, model): # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む is_difference = "difference" in ctrl_sd_sd - print("ControlNet: loading difference") + print("ControlNet: loading difference:", is_difference) # ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく # またTransfer Controlの元weightとなる @@ -123,7 +123,8 @@ def load_preprocess(prep_type: str): def preprocess_ctrl_net_hint_image(image): image = np.array(image).astype(np.float32) / 255.0 - image = image[:, :, ::-1].copy() # rgb to bgr + # ControlNetのサンプルはcv2を使っているが、読み込みはGradioなので実はRGBになっている + # image = image[:, :, ::-1].copy() # rgb to bgr image = image[None].transpose(0, 3, 1, 2) # nchw image = torch.from_numpy(image) return image # 0 to 1 From af08c56ce057cfc153e532281f4f9e58a74b98a7 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 11 May 2023 21:20:18 +0900 Subject: [PATCH 4/8] remove unnecessary newline --- networks/extract_lora_from_dylora.py | 2 +- networks/merge_lora.py | 4 ++-- networks/resize_lora.py | 2 +- networks/svd_merge_lora.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/networks/extract_lora_from_dylora.py b/networks/extract_lora_from_dylora.py index 5aa9403a..0abee983 100644 --- a/networks/extract_lora_from_dylora.py +++ b/networks/extract_lora_from_dylora.py @@ -94,7 +94,7 @@ def split(args): filename, ext = os.path.splitext(args.save_to) model_file_name = filename + f"-{new_rank:04d}{ext}" - print(f"\nsaving model to: {model_file_name}") + print(f"saving model to: {model_file_name}") save_to_file(model_file_name, state_dict, new_metadata) diff --git a/networks/merge_lora.py b/networks/merge_lora.py index a7a0d83d..2fa8861b 100644 --- a/networks/merge_lora.py +++ b/networks/merge_lora.py @@ -193,12 +193,12 @@ def merge(args): merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype) - print(f"\nsaving SD 
model to: {args.save_to}") + print(f"saving SD model to: {args.save_to}") model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae) else: state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) - print(f"\nsaving model to: {args.save_to}") + print(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 4f7499e8..7b740634 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -326,7 +326,7 @@ def resize(args): metadata["sshs_model_hash"] = model_hash metadata["sshs_legacy_hash"] = legacy_hash - print(f"\nsaving model to: {args.save_to}") + print(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata) diff --git a/networks/svd_merge_lora.py b/networks/svd_merge_lora.py index 8cd389db..9d17efba 100644 --- a/networks/svd_merge_lora.py +++ b/networks/svd_merge_lora.py @@ -160,7 +160,7 @@ def merge(args): new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype) - print(f"\nsaving model to: {args.save_to}") + print(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, save_dtype) From 2767a0f9f23e771d477682d325dd28e903786ff4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 11 May 2023 21:47:59 +0900 Subject: [PATCH 5/8] common block lr args processing in create --- networks/lora.py | 59 +++++++++++++++++++++++++++++++----------------- train_network.py | 33 +++++---------------------- 2 files changed, 44 insertions(+), 48 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index 898ffce9..121a6281 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -73,7 +73,7 @@ class LoRAModule(torch.nn.Module): class LoRAInfModule(LoRAModule): def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1): super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) - + self.org_module_ref = [org_module] # 後から参照できるように self.enabled = True @@ -319,6 +319,35 @@ class LoRAInfModule(LoRAModule): return out +def parse_block_lr_kwargs(nw_kwargs): + down_lr_weight = nw_kwargs.get("down_lr_weight", None) + mid_lr_weight = nw_kwargs.get("mid_lr_weight", None) + up_lr_weight = nw_kwargs.get("up_lr_weight", None) + + # 以上のいずれにも設定がない場合は無効としてNoneを返す + if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None: + return None, None, None + + # extract learning rate weight for each block + if down_lr_weight is not None: + # if some parameters are not set, use zero + if "," in down_lr_weight: + down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] + + if mid_lr_weight is not None: + mid_lr_weight = float(mid_lr_weight) + + if up_lr_weight is not None: + if "," in up_lr_weight: + up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] + + down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight( + down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0)) + ) + + return down_lr_weight, mid_lr_weight, up_lr_weight + + def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): if network_dim is None: network_dim = 4 # default @@ -337,9 +366,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, 
text_encoder, un # block dim/alpha/lr block_dims = kwargs.get("block_dims", None) - down_lr_weight = kwargs.get("down_lr_weight", None) - mid_lr_weight = kwargs.get("mid_lr_weight", None) - up_lr_weight = kwargs.get("up_lr_weight", None) + down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs) # 以上のいずれかに指定があればblockごとのdim(rank)を有効にする if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None: @@ -351,22 +378,6 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha ) - # extract learning rate weight for each block - if down_lr_weight is not None: - # if some parameters are not set, use zero - if "," in down_lr_weight: - down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] - - if mid_lr_weight is not None: - mid_lr_weight = float(mid_lr_weight) - - if up_lr_weight is not None: - if "," in up_lr_weight: - up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] - - down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight( - down_lr_weight, mid_lr_weight, up_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0)) - ) # remove block dim/alpha without learning rate block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas( @@ -634,6 +645,12 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh network = LoRANetwork( text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class ) + + # block lr + down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs) + if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: + network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) + return network, weights_sd @@ -835,7 +852,7 @@ class LoRANetwork(torch.nn.Module): print(f"weights are merged") - # 層別学習率用に層ごとの学習率に対する倍率を定義する + # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない def set_block_lr_weight( self, up_lr_weight: List[float] = None, diff --git a/train_network.py b/train_network.py index bcfd657f..b5cdfea1 100644 --- a/train_network.py +++ b/train_network.py @@ -176,32 +176,10 @@ def train(args): net_kwargs[key] = value # if a new network is added in future, add if ~ then blocks for each network (;'∀') - if args.size_from_weights: - network, weights = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet) - if net_kwargs is not None: - down_lr_weight = net_kwargs.get("down_lr_weight", None) - mid_lr_weight = net_kwargs.get("mid_lr_weight", None) - up_lr_weight = net_kwargs.get("up_lr_weight", None) - if down_lr_weight is not None: - # if some parameters are not set, use zero - if "," in down_lr_weight: - down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] - - if mid_lr_weight is not None: - mid_lr_weight = float(mid_lr_weight) - - if up_lr_weight is not None: - if "," in up_lr_weight: - up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] - - down_lr_weight, mid_lr_weight, up_lr_weight = network_module.get_block_lr_weight( - down_lr_weight, mid_lr_weight, up_lr_weight, float(net_kwargs.get("block_lr_zero_threshold", 0.0)) - ) - - if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: - network.set_block_lr_weight(up_lr_weight, 
mid_lr_weight, down_lr_weight) + if args.dim_from_weights: + network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs) else: - network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) + network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) if network is None: return @@ -786,10 +764,11 @@ def setup_parser() -> argparse.ArgumentParser: "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列" ) parser.add_argument( - "--size_from_weights", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する" + "--dim_from_weights", + action="store_true", + help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する", ) - return parser From 8d562ecf48d3cea64dda509a033423e17b91c2ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Thu, 11 May 2023 20:48:51 +0800 Subject: [PATCH 6/8] fix pynoise code bug (#489) * fix pynoise * Update custom_train_functions.py for default * Update custom_train_functions.py for note * Update custom_train_functions.py for default * Revert "Update custom_train_functions.py for default" This reverts commit ca79915d7396ddb57adbeb4b78bafb9a1a884b5c. * Update custom_train_functions.py for default * Revert "Update custom_train_functions.py for default" This reverts commit 483577e137b13933ff24b6ae254f82c0a8d9f1fe. * default value change --- library/custom_train_functions.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 2d387d15..0c527c35 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -346,14 +346,14 @@ def get_weighted_text_embeddings( # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 -def pyramid_noise_like(noise, device, iterations=6, discount=0.3): - b, c, w, h = noise.shape +def pyramid_noise_like(noise, device, iterations=6, discount=0.4): + b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant! 
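+    # note: noise is NCHW, so these names are swapped (w holds H and h holds W);
+    # this is harmless because Upsample(size=(w, h)) below simply restores the
+    # input's own spatial size after each downscaled-noise octave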
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device) for i in range(iterations): r = random.random() * 2 + 2 # Rather than always going 2x, - w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i))) - noise += u(torch.randn(b, c, w, h).to(device)) * discount**i - if w == 1 or h == 1: + wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i))) + noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i + if wn == 1 or hn == 1: break # Lowest resolution is 1x1 return noise / noise.std() # Scaled back to roughly unit variance From 7889a52f959ca8d7350b0f6951690a984a68f38d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 11 May 2023 22:00:41 +0900 Subject: [PATCH 7/8] add callback for step start --- train_network.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/train_network.py b/train_network.py index b5cdfea1..b331e92c 100644 --- a/train_network.py +++ b/train_network.py @@ -525,10 +525,11 @@ def train(args): loss_total = 0.0 del train_dataset_group - # if hasattr(network, "on_step_start"): - # on_step_start = network.on_step_start - # else: - # on_step_start = lambda *args, **kwargs: None + # callback for step start + if hasattr(network, "on_step_start"): + on_step_start = network.on_step_start + else: + on_step_start = lambda *args, **kwargs: None # function for saving/removing def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): @@ -563,7 +564,7 @@ def train(args): for step, batch in enumerate(train_dataloader): current_step.value = global_step with accelerator.accumulate(network): - # on_step_start(text_encoder, unet) + on_step_start(text_encoder, unet) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: From 47b6101465a3548e9509b122baba0a9c98b2d98d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 11 May 2023 22:17:32 +0900 Subject: [PATCH 8/8] update readme --- README.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/README.md b/README.md index 87c0bf93..cd859064 100644 --- a/README.md +++ b/README.md @@ -138,6 +138,28 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### 11 May 2023, 2023/05/11 + +- Added an option `--dim_from_weights` to `train_network.py` to automatically determine the dim(rank) from the weight file. [PR #491](https://github.com/kohya-ss/sd-scripts/pull/491) Thanks to AI-Casanova! + - It is useful in combination with `resize_lora.py`. Please see the PR for details. +- Fixed a bug where the noise resolution was incorrect with Multires noise. [PR #489](https://github.com/kohya-ss/sd-scripts/pull/489) Thanks to sdbds! + - Please see the PR for details. +- The image generation scripts can now use img2img and highres fix at the same time. +- Fixed a bug where the hint image of ControlNet was incorrectly BGR instead of RGB in the image generation scripts. +- Added a feature to the image generation scripts to use the memory-efficient VAE. + - If you specify a number with the `--vae_slices` option, the memory-efficient VAE will be used. The maximum output size will be larger, but it will be slower. Please specify a value of about `16` or `32`. + - The implementation of the VAE is in `library/slicing_vae.py`. 
+ +- `train_network.py`にdim(rank)を重みファイルから自動決定するオプション`--dim_from_weights`が追加されました。 [PR #491](https://github.com/kohya-ss/sd-scripts/pull/491) AI-Casanova氏に感謝します。 + - `resize_lora.py`と組み合わせると有用です。詳細はPRもご参照ください。 +- Multires noiseでノイズ解像度が正しくない不具合が修正されました。 [PR #489](https://github.com/kohya-ss/sd-scripts/pull/489) sdbds氏に感謝します。 + - 詳細は当該PRをご参照ください。 +- 生成スクリプトでimg2imgとhighres fixを同時に使用できるようにしました。 +- 生成スクリプトでControlNetのhint画像が誤ってBGRだったのをRGBに修正しました。 +- 生成スクリプトで省メモリ化VAEを使えるよう機能追加しました。 + - `--vae_slices`オプションに数値を指定すると、省メモリ化VAEを用います。出力可能な最大サイズが大きくなりますが、遅くなります。`16`または`32`程度の値を指定してください。 + - VAEの実装は`library/slicing_vae.py`にあります。 + ### 7 May 2023, 2023/05/07 - The documentation has been moved to the `docs` folder. If you have links, please change them.
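Taken together, the series enables command lines like the following. This is an illustrative sketch only: the flags are the ones added in the patches above, but the weight and image file names are placeholders, and the remaining arguments each script requires (checkpoint path, prompt, dataset config, output destination, and so on) are omitted.

```bash
# Patches 1/5: derive the LoRA dim (rank) from an existing weight file
# instead of passing --network_dim; useful after resize_lora.py.
# my_lora.safetensors is a placeholder path.
python train_network.py \
  --network_module networks.lora \
  --network_weights my_lora.safetensors \
  --dim_from_weights

# Patch 2: img2img combined with highres fix, using the sliced VAE to reduce
# VRAM usage (16 or 32 slices are the values recommended in the help text);
# --highres_fix_strength overrides the 1st-stage img2img strength.
python gen_img_diffusers.py \
  --image_path source.png \
  --W 768 --H 768 \
  --highres_fix_scale 2.0 --highres_fix_steps 28 --highres_fix_strength 0.6 \
  --vae_slices 16
```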