mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix ControlNet not working
This commit is contained in:
@@ -4,8 +4,7 @@ import cv2
|
|||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
from diffusers import UNet2DConditionModel
|
from library.original_unet import UNet2DConditionModel, SampleOutput
|
||||||
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
|
||||||
|
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
|
|
||||||
@@ -235,10 +234,6 @@ def unet_forward(
|
|||||||
print("Forward upsample size to force interpolation output size.")
|
print("Forward upsample size to force interpolation output size.")
|
||||||
forward_upsample_size = True
|
forward_upsample_size = True
|
||||||
|
|
||||||
# 0. center input if necessary
|
|
||||||
if unet.config.center_input_sample:
|
|
||||||
sample = 2 * sample - 1.0
|
|
||||||
|
|
||||||
# 1. time
|
# 1. time
|
||||||
timesteps = timestep
|
timesteps = timestep
|
||||||
if not torch.is_tensor(timesteps):
|
if not torch.is_tensor(timesteps):
|
||||||
@@ -277,7 +272,7 @@ def unet_forward(
|
|||||||
# 3. down
|
# 3. down
|
||||||
down_block_res_samples = (sample,)
|
down_block_res_samples = (sample,)
|
||||||
for downsample_block in unet.down_blocks:
|
for downsample_block in unet.down_blocks:
|
||||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
if downsample_block.has_cross_attention:
|
||||||
sample, res_samples = downsample_block(
|
sample, res_samples = downsample_block(
|
||||||
hidden_states=sample,
|
hidden_states=sample,
|
||||||
temb=emb,
|
temb=emb,
|
||||||
@@ -321,7 +316,7 @@ def unet_forward(
|
|||||||
if not is_final_block and forward_upsample_size:
|
if not is_final_block and forward_upsample_size:
|
||||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||||
|
|
||||||
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
if upsample_block.has_cross_attention:
|
||||||
sample = upsample_block(
|
sample = upsample_block(
|
||||||
hidden_states=sample,
|
hidden_states=sample,
|
||||||
temb=emb,
|
temb=emb,
|
||||||
@@ -338,4 +333,4 @@ def unet_forward(
|
|||||||
sample = unet.conv_act(sample)
|
sample = unet.conv_act(sample)
|
||||||
sample = unet.conv_out(sample)
|
sample = unet.conv_out(sample)
|
||||||
|
|
||||||
return UNet2DConditionOutput(sample=sample)
|
return SampleOutput(sample=sample)
|
||||||
|
|||||||
Reference in New Issue
Block a user