From 2a4ae88f1823c95f5c342bf4b085872dfb874754 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 30 Jul 2023 14:03:54 +0900
Subject: [PATCH] format by black

---
 tools/original_control_net.py | 464 ++++++++++++++++++----------------
 1 file changed, 242 insertions(+), 222 deletions(-)

diff --git a/tools/original_control_net.py b/tools/original_control_net.py
index 582794de..347e8902 100644
--- a/tools/original_control_net.py
+++ b/tools/original_control_net.py
@@ -11,169 +11,180 @@ import library.model_util as model_util
 
 
 class ControlNetInfo(NamedTuple):
-  unet: Any
-  net: Any
-  prep: Any
-  weight: float
-  ratio: float
+    unet: Any
+    net: Any
+    prep: Any
+    weight: float
+    ratio: float
 
 
 class ControlNet(torch.nn.Module):
-  def __init__(self) -> None:
-    super().__init__()
+    def __init__(self) -> None:
+        super().__init__()
 
-    # make control model
-    self.control_model = torch.nn.Module()
+        # make control model
+        self.control_model = torch.nn.Module()
 
-    dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
-    zero_convs = torch.nn.ModuleList()
-    for i, dim in enumerate(dims):
-      sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
-      zero_convs.append(sub_list)
-    self.control_model.add_module("zero_convs", zero_convs)
+        dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
+        zero_convs = torch.nn.ModuleList()
+        for i, dim in enumerate(dims):
+            sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
+            zero_convs.append(sub_list)
+        self.control_model.add_module("zero_convs", zero_convs)
 
-    middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
-    self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))
+        middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
+        self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))
 
-    dims = [16, 16, 32, 32, 96, 96, 256, 320]
-    strides = [1, 1, 2, 1, 2, 1, 2, 1]
-    prev_dim = 3
-    input_hint_block = torch.nn.Sequential()
-    for i, (dim, stride) in enumerate(zip(dims, strides)):
-      input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
-      if i < len(dims) - 1:
-        input_hint_block.append(torch.nn.SiLU())
-      prev_dim = dim
-    self.control_model.add_module("input_hint_block", input_hint_block)
+        dims = [16, 16, 32, 32, 96, 96, 256, 320]
+        strides = [1, 1, 2, 1, 2, 1, 2, 1]
+        prev_dim = 3
+        input_hint_block = torch.nn.Sequential()
+        for i, (dim, stride) in enumerate(zip(dims, strides)):
+            input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
+            if i < len(dims) - 1:
+                input_hint_block.append(torch.nn.SiLU())
+            prev_dim = dim
+        self.control_model.add_module("input_hint_block", input_hint_block)
 
 
 def load_control_net(v2, unet, model):
-  device = unet.device
+    device = unet.device
 
-  # convert keys from the control SD checkpoint, extract only the U-Net part, and load it into the Diffusers U-Net
-  # load the state dict
-  print(f"ControlNet: loading control SD model : {model}")
+    # convert keys from the control SD checkpoint, extract only the U-Net part, and load it into the Diffusers U-Net
+    # load the state dict
+    print(f"ControlNet: loading control SD model : {model}")
 
-  if model_util.is_safetensors(model):
-    ctrl_sd_sd = load_file(model)
-  else:
-    ctrl_sd_sd = torch.load(model, map_location='cpu')
-    ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
+    if model_util.is_safetensors(model):
+        ctrl_sd_sd = load_file(model)
+    else:
+        ctrl_sd_sd = torch.load(model, map_location="cpu")
+        ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
 
-  # make the weights loadable into the U-Net; the ControlNet checkpoint is an SD-format state dict, so load it as such
-  is_difference = "difference" in ctrl_sd_sd
-  print("ControlNet: loading difference:", is_difference)
+    # make the weights loadable into the U-Net; the ControlNet checkpoint is an SD-format state dict, so load it as such
+    is_difference = "difference" in ctrl_sd_sd
+    print("ControlNet: loading difference:", is_difference)
 
-  # some keys do not exist in the ControlNet checkpoint, so first build all the SD-format keys from the current U-Net
-  # they also serve as the base weights for Transfer Control
-  ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
+    # some keys do not exist in the ControlNet checkpoint, so first build all the SD-format keys from the current U-Net
+    # they also serve as the base weights for Transfer Control
+    ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
 
-  # copy so the original U-Net is not affected; the prefix is missing, so add it
-  for key in list(ctrl_unet_sd_sd.keys()):
-    ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()
+    # copy so the original U-Net is not affected; the prefix is missing, so add it
+    for key in list(ctrl_unet_sd_sd.keys()):
+        ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()
 
-  zero_conv_sd = {}
-  for key in list(ctrl_sd_sd.keys()):
-    if key.startswith("control_"):
-      unet_key = "model.diffusion_" + key[len("control_"):]
-      if unet_key not in ctrl_unet_sd_sd:  # zero conv
-        zero_conv_sd[key] = ctrl_sd_sd[key]
-        continue
-      if is_difference:  # Transfer Control
-        ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
-      else:
-        ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)
+    zero_conv_sd = {}
+    for key in list(ctrl_sd_sd.keys()):
+        if key.startswith("control_"):
+            unet_key = "model.diffusion_" + key[len("control_") :]
+            if unet_key not in ctrl_unet_sd_sd:  # zero conv
+                zero_conv_sd[key] = ctrl_sd_sd[key]
+                continue
+            if is_difference:  # Transfer Control
+                ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
+            else:
+                ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)
 
-  unet_config = model_util.create_unet_diffusers_config(v2)
-  ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config)  # state dict of the Diffusers-format ControlNet
+    unet_config = model_util.create_unet_diffusers_config(v2)
+    ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config)  # state dict of the Diffusers-format ControlNet
 
-  # create the U-Net for the ControlNet
-  ctrl_unet = UNet2DConditionModel(**unet_config)
-  info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
-  print("ControlNet: loading Control U-Net:", info)
+    # create the U-Net for the ControlNet
+    ctrl_unet = UNet2DConditionModel(**unet_config)
+    info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
+    print("ControlNet: loading Control U-Net:", info)
 
-  # create the parts of the ControlNet other than the U-Net
-  # TODO support middle only
-  ctrl_net = ControlNet()
-  info = ctrl_net.load_state_dict(zero_conv_sd)
-  print("ControlNet: loading ControlNet:", info)
+    # create the parts of the ControlNet other than the U-Net
+    # TODO support middle only
+    ctrl_net = ControlNet()
+    info = ctrl_net.load_state_dict(zero_conv_sd)
+    print("ControlNet: loading ControlNet:", info)
 
-  ctrl_unet.to(unet.device, dtype=unet.dtype)
-  ctrl_net.to(unet.device, dtype=unet.dtype)
-  return ctrl_unet, ctrl_net
+    ctrl_unet.to(unet.device, dtype=unet.dtype)
+    ctrl_net.to(unet.device, dtype=unet.dtype)
+    return ctrl_unet, ctrl_net
 
 
 def load_preprocess(prep_type: str):
-  if prep_type is None or prep_type.lower() == "none":
+    if prep_type is None or prep_type.lower() == "none":
+        return None
+
+    if prep_type.startswith("canny"):
+        args = prep_type.split("_")
+        th1 = int(args[1]) if len(args) >= 2 else 63
+        th2 = int(args[2]) if len(args) >= 3 else 191
+
+        def canny(img):
+            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+            return cv2.Canny(img, th1, th2)
+
+        return canny
+
+    print("Unsupported prep type:", prep_type)
     return None
 
-  if prep_type.startswith("canny"):
-    args = prep_type.split("_")
-    th1 = int(args[1]) if len(args) >= 2 else 63
-    th2 = int(args[2]) if len(args) >= 3 else 191
-
-    def canny(img):
-      img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
-      return cv2.Canny(img, th1, th2)
-    return canny
-
-  print("Unsupported prep type:", prep_type)
-  return None
-
 
 def preprocess_ctrl_net_hint_image(image):
-  image = np.array(image).astype(np.float32) / 255.0
-  # the ControlNet samples use cv2, but the image here is loaded via Gradio, so it is actually RGB
-  # image = image[:, :, ::-1].copy()  # rgb to bgr
-  image = image[None].transpose(0, 3, 1, 2)  # nchw
-  image = torch.from_numpy(image)
-  return image  # 0 to 1
+    image = np.array(image).astype(np.float32) / 255.0
+    # the ControlNet samples use cv2, but the image here is loaded via Gradio, so it is actually RGB
+    # image = image[:, :, ::-1].copy()  # rgb to bgr
+    image = image[None].transpose(0, 3, 1, 2)  # nchw
+    image = torch.from_numpy(image)
+    return image  # 0 to 1
 
 
 def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints):
-  guided_hints = []
-  for i, cnet_info in enumerate(control_nets):
-    # hints must be ordered: cnet1 for image 1, cnet2 for image 1, cnet3 for image 1, cnet1 for image 2, cnet2 for image 2 ...
-    b_hints = []
-    if len(hints) == 1:  # use the same image as the hint for everything
-      hint = hints[0]
-      if cnet_info.prep is not None:
-        hint = cnet_info.prep(hint)
-      hint = preprocess_ctrl_net_hint_image(hint)
-      b_hints = [hint for _ in range(b_size)]
-    else:
-      for bi in range(b_size):
-        hint = hints[(bi * len(control_nets) + i) % len(hints)]
-        if cnet_info.prep is not None:
-          hint = cnet_info.prep(hint)
-        hint = preprocess_ctrl_net_hint_image(hint)
-        b_hints.append(hint)
-    b_hints = torch.cat(b_hints, dim=0)
-    b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)
+    guided_hints = []
+    for i, cnet_info in enumerate(control_nets):
+        # hints must be ordered: cnet1 for image 1, cnet2 for image 1, cnet3 for image 1, cnet1 for image 2, cnet2 for image 2 ...
+        b_hints = []
+        if len(hints) == 1:  # use the same image as the hint for everything
+            hint = hints[0]
+            if cnet_info.prep is not None:
+                hint = cnet_info.prep(hint)
+            hint = preprocess_ctrl_net_hint_image(hint)
+            b_hints = [hint for _ in range(b_size)]
+        else:
+            for bi in range(b_size):
+                hint = hints[(bi * len(control_nets) + i) % len(hints)]
+                if cnet_info.prep is not None:
+                    hint = cnet_info.prep(hint)
+                hint = preprocess_ctrl_net_hint_image(hint)
+                b_hints.append(hint)
+        b_hints = torch.cat(b_hints, dim=0)
+        b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)
 
-    guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
-    guided_hints.append(guided_hint)
-  return guided_hints
+        guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
+        guided_hints.append(guided_hint)
+    return guided_hints
 
 
-def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states):
-  # ControlNet
-  # with multiple ControlNets, apply them alternately instead of merging their outputs
-  cnet_cnt = len(control_nets)
-  cnet_idx = step % cnet_cnt
-  cnet_info = control_nets[cnet_idx]
+def call_unet_and_control_net(
+    step,
+    num_latent_input,
+    original_unet,
+    control_nets: List[ControlNetInfo],
+    guided_hints,
+    current_ratio,
+    sample,
+    timestep,
+    encoder_hidden_states,
+):
+    # ControlNet
+    # with multiple ControlNets, apply them alternately instead of merging their outputs
+    cnet_cnt = len(control_nets)
+    cnet_idx = step % cnet_cnt
+    cnet_info = control_nets[cnet_idx]
 
-  # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
-  if cnet_info.ratio < current_ratio:
-    return original_unet(sample, timestep, encoder_hidden_states)
+    # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
+    if cnet_info.ratio < current_ratio:
+        return original_unet(sample, timestep, encoder_hidden_states)
 
-  guided_hint = guided_hints[cnet_idx]
-  guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
-  outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
-  outs = [o * cnet_info.weight for o in outs]
+    guided_hint = guided_hints[cnet_idx]
+    guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
+    outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
+    outs = [o * cnet_info.weight for o in outs]
 
-  # U-Net
-  return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)
+    # U-Net
+    return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)
 
 
 """
@@ -204,118 +215,127 @@ def call_unet_and_control_net(step, num_latent_input, original_unet, control_net
 """
 
 
-def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states):
-  # copy from UNet2DConditionModel
-  default_overall_up_factor = 2**unet.num_upsamplers
+def unet_forward(
+    is_control_net,
+    control_net: ControlNet,
+    unet: UNet2DConditionModel,
+    guided_hint,
+    ctrl_outs,
+    sample,
+    timestep,
+    encoder_hidden_states,
+):
+    # copy from UNet2DConditionModel
+    default_overall_up_factor = 2**unet.num_upsamplers
 
-  forward_upsample_size = False
-  upsample_size = None
+    forward_upsample_size = False
+    upsample_size = None
 
-  if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-    print("Forward upsample size to force interpolation output size.")
-    forward_upsample_size = True
+    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+        print("Forward upsample size to force interpolation output size.")
+        forward_upsample_size = True
 
-  # 0. center input if necessary
-  if unet.config.center_input_sample:
-    sample = 2 * sample - 1.0
+    # 0. center input if necessary
+    if unet.config.center_input_sample:
+        sample = 2 * sample - 1.0
 
-  # 1. time
-  timesteps = timestep
-  if not torch.is_tensor(timesteps):
-    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
-    # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = sample.device.type == "mps"
-    if isinstance(timestep, float):
-      dtype = torch.float32 if is_mps else torch.float64
-    else:
-      dtype = torch.int32 if is_mps else torch.int64
-    timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
-  elif len(timesteps.shape) == 0:
-    timesteps = timesteps[None].to(sample.device)
+    # 1. time
+    timesteps = timestep
+    if not torch.is_tensor(timesteps):
+        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+        # This would be a good case for the `match` statement (Python 3.10+)
+        is_mps = sample.device.type == "mps"
+        if isinstance(timestep, float):
+            dtype = torch.float32 if is_mps else torch.float64
+        else:
+            dtype = torch.int32 if is_mps else torch.int64
+        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+    elif len(timesteps.shape) == 0:
+        timesteps = timesteps[None].to(sample.device)
 
-  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-  timesteps = timesteps.expand(sample.shape[0])
+    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+    timesteps = timesteps.expand(sample.shape[0])
 
-  t_emb = unet.time_proj(timesteps)
+    t_emb = unet.time_proj(timesteps)
 
-  # timesteps does not contain any weights and will always return f32 tensors
-  # but time_embedding might actually be running in fp16. so we need to cast here.
-  # there might be better ways to encapsulate this.
-  t_emb = t_emb.to(dtype=unet.dtype)
-  emb = unet.time_embedding(t_emb)
+    # timesteps does not contain any weights and will always return f32 tensors
+    # but time_embedding might actually be running in fp16. so we need to cast here.
+    # there might be better ways to encapsulate this.
+    t_emb = t_emb.to(dtype=unet.dtype)
+    emb = unet.time_embedding(t_emb)
 
-  outs = []  # output of ControlNet
-  zc_idx = 0
+    outs = []  # output of ControlNet
+    zc_idx = 0
 
-  # 2. pre-process
-  sample = unet.conv_in(sample)
-  if is_control_net:
-    sample += guided_hint
-    outs.append(control_net.control_model.zero_convs[zc_idx][0](sample))  # , emb, encoder_hidden_states))
-    zc_idx += 1
-
-  # 3. down
-  down_block_res_samples = (sample,)
-  for downsample_block in unet.down_blocks:
-    if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
-      sample, res_samples = downsample_block(
-          hidden_states=sample,
-          temb=emb,
-          encoder_hidden_states=encoder_hidden_states,
-      )
-    else:
-      sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+    # 2. pre-process
+    sample = unet.conv_in(sample)
     if is_control_net:
-      for rs in res_samples:
-        outs.append(control_net.control_model.zero_convs[zc_idx][0](rs))  # , emb, encoder_hidden_states))
+        sample += guided_hint
+        outs.append(control_net.control_model.zero_convs[zc_idx][0](sample))  # , emb, encoder_hidden_states))
         zc_idx += 1
 
-    down_block_res_samples += res_samples
+    # 3. down
+    down_block_res_samples = (sample,)
+    for downsample_block in unet.down_blocks:
+        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+            sample, res_samples = downsample_block(
+                hidden_states=sample,
+                temb=emb,
+                encoder_hidden_states=encoder_hidden_states,
+            )
+        else:
+            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+        if is_control_net:
+            for rs in res_samples:
+                outs.append(control_net.control_model.zero_convs[zc_idx][0](rs))  # , emb, encoder_hidden_states))
+                zc_idx += 1
 
-  # 4. mid
-  sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
-  if is_control_net:
-    outs.append(control_net.control_model.middle_block_out[0](sample))
-    return outs
+        down_block_res_samples += res_samples
 
-  if not is_control_net:
-    sample += ctrl_outs.pop()
+    # 4. mid
+    sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
+    if is_control_net:
+        outs.append(control_net.control_model.middle_block_out[0](sample))
+        return outs
 
-  # 5. up
-  for i, upsample_block in enumerate(unet.up_blocks):
-    is_final_block = i == len(unet.up_blocks) - 1
+    if not is_control_net:
+        sample += ctrl_outs.pop()
 
-    res_samples = down_block_res_samples[-len(upsample_block.resnets):]
-    down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+    # 5. up
+    for i, upsample_block in enumerate(unet.up_blocks):
+        is_final_block = i == len(unet.up_blocks) - 1
 
-    if not is_control_net and len(ctrl_outs) > 0:
-      res_samples = list(res_samples)
-      apply_ctrl_outs = ctrl_outs[-len(res_samples):]
-      ctrl_outs = ctrl_outs[:-len(res_samples)]
-      for j in range(len(res_samples)):
-        res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
-      res_samples = tuple(res_samples)
+        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
 
-    # if we have not reached the final block and need to forward the
-    # upsample size, we do it here
-    if not is_final_block and forward_upsample_size:
-      upsample_size = down_block_res_samples[-1].shape[2:]
+        if not is_control_net and len(ctrl_outs) > 0:
+            res_samples = list(res_samples)
+            apply_ctrl_outs = ctrl_outs[-len(res_samples) :]
+            ctrl_outs = ctrl_outs[: -len(res_samples)]
+            for j in range(len(res_samples)):
+                res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
+            res_samples = tuple(res_samples)
 
-    if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
-      sample = upsample_block(
-          hidden_states=sample,
-          temb=emb,
-          res_hidden_states_tuple=res_samples,
-          encoder_hidden_states=encoder_hidden_states,
-          upsample_size=upsample_size,
-      )
-    else:
-      sample = upsample_block(
-          hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
-      )
-  # 6. post-process
-  sample = unet.conv_norm_out(sample)
-  sample = unet.conv_act(sample)
-  sample = unet.conv_out(sample)
+        # if we have not reached the final block and need to forward the
+        # upsample size, we do it here
+        if not is_final_block and forward_upsample_size:
+            upsample_size = down_block_res_samples[-1].shape[2:]
 
-  return UNet2DConditionOutput(sample=sample)
+        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+            sample = upsample_block(
+                hidden_states=sample,
+                temb=emb,
+                res_hidden_states_tuple=res_samples,
+                encoder_hidden_states=encoder_hidden_states,
+                upsample_size=upsample_size,
+            )
+        else:
+            sample = upsample_block(
+                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+            )
+    # 6. post-process
+    sample = unet.conv_norm_out(sample)
+    sample = unet.conv_act(sample)
+    sample = unet.conv_out(sample)
+
+    return UNet2DConditionOutput(sample=sample)
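
Usage sketch (appended for review, not part of the patch): the functions above are wired together by a caller such as the image generation script, roughly as below. This is a minimal illustration under stated assumptions — classifier-free guidance (num_latent_input=2), an SD 1.x base model (v2=False), and placeholder names (pipe_unet, b_size, control_image, latent_model_input, text_embeddings, timesteps, and the checkpoint filename) that are not defined in this file:

    import tools.original_control_net as original_control_net
    from tools.original_control_net import ControlNetInfo

    # load the control U-Net plus the zero-conv / input-hint modules from an
    # SD-format ControlNet checkpoint; "difference" (Transfer Control)
    # checkpoints are detected automatically via the "difference" key
    ctrl_unet, ctrl_net = original_control_net.load_control_net(False, pipe_unet, "control_sd15_canny.safetensors")

    # "none" disables preprocessing (hints are used as-is); a string such as
    # "canny_63_191" would instead build a cv2.Canny preprocessor with
    # thresholds 63 and 191 (the defaults when the suffix is omitted)
    prep = original_control_net.load_preprocess("none")

    # weight scales the ControlNet residuals; ratio=0.8 stops applying
    # control once 80% of the sampling steps have run
    control_nets = [ControlNetInfo(ctrl_unet, ctrl_net, prep, 1.0, 0.8)]

    # hints are RGB control images; a single hint is broadcast to the batch
    hints = [control_image]
    guided_hints = original_control_net.get_guided_hints(control_nets, num_latent_input, b_size, hints)

    for step, t in enumerate(timesteps):
        current_ratio = step / len(timesteps)  # fraction of sampling completed
        noise_pred = original_control_net.call_unet_and_control_net(
            step, num_latent_input, pipe_unet, control_nets, guided_hints,
            current_ratio, latent_model_input, t, text_embeddings,
        ).sample
        # ... guidance and scheduler step on noise_pred as in a normal loop ...

Note that with several ControlNetInfo entries, call_unet_and_control_net applies ControlNet number step % N at each step (alternating), rather than merging the outputs, as the comment in that function states.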