use original ControlNet instead of Diffusers

This commit is contained in:
Kohya S
2024-09-29 23:07:34 +09:00
parent e0c3630203
commit 8919b31145
5 changed files with 526 additions and 237 deletions

View File

@@ -30,7 +30,7 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from .utils import setup_logging
from library.utils import setup_logging
setup_logging()
import logging
@@ -1156,9 +1156,9 @@ class InferSdxlUNet2DConditionModel:
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
r"""
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink and ControlNet.
"""
_self = self.delegate
@@ -1209,6 +1209,8 @@ class InferSdxlUNet2DConditionModel:
hs.append(h)
h = call_module(_self.middle_block, h, emb, context)
if mid_add is not None:
h = h + mid_add
for module in _self.output_blocks:
# Deep Shrink
@@ -1217,7 +1219,11 @@ class InferSdxlUNet2DConditionModel:
# print("upsample", h.shape, hs[-1].shape)
h = resize_like(h, hs[-1])
h = torch.cat([h, hs.pop()], dim=1)
resi = hs.pop()
if input_resi_add is not None:
resi = resi + input_resi_add.pop()
h = torch.cat([h, resi], dim=1)
h = call_module(module, h, emb, context)
# Deep Shrink: in case of depth 0