From 0eacadfa99c51103f8daf559b06c61a5ba8fcf96 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 30 Jul 2023 14:09:43 +0900 Subject: [PATCH] fix ControlNet not working --- tools/original_control_net.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tools/original_control_net.py b/tools/original_control_net.py index 347e8902..161bf322 100644 --- a/tools/original_control_net.py +++ b/tools/original_control_net.py @@ -4,8 +4,7 @@ import cv2 import torch from safetensors.torch import load_file -from diffusers import UNet2DConditionModel -from diffusers.models.unet_2d_condition import UNet2DConditionOutput +from library.original_unet import UNet2DConditionModel, SampleOutput import library.model_util as model_util @@ -235,10 +234,6 @@ def unet_forward( print("Forward upsample size to force interpolation output size.") forward_upsample_size = True - # 0. center input if necessary - if unet.config.center_input_sample: - sample = 2 * sample - 1.0 - # 1. time timesteps = timestep if not torch.is_tensor(timesteps): @@ -277,7 +272,7 @@ def unet_forward( # 3. down down_block_res_samples = (sample,) for downsample_block in unet.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + if downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, @@ -321,7 +316,7 @@ def unet_forward( if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] - if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + if upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, @@ -338,4 +333,4 @@ def unet_forward( sample = unet.conv_act(sample) sample = unet.conv_out(sample) - return UNet2DConditionOutput(sample=sample) + return SampleOutput(sample=sample)