mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
format by black
This commit is contained in:
@@ -57,7 +57,7 @@ def load_control_net(v2, unet, model):
|
|||||||
if model_util.is_safetensors(model):
|
if model_util.is_safetensors(model):
|
||||||
ctrl_sd_sd = load_file(model)
|
ctrl_sd_sd = load_file(model)
|
||||||
else:
|
else:
|
||||||
ctrl_sd_sd = torch.load(model, map_location='cpu')
|
ctrl_sd_sd = torch.load(model, map_location="cpu")
|
||||||
ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
|
ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
|
||||||
|
|
||||||
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
|
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
|
||||||
@@ -75,7 +75,7 @@ def load_control_net(v2, unet, model):
|
|||||||
zero_conv_sd = {}
|
zero_conv_sd = {}
|
||||||
for key in list(ctrl_sd_sd.keys()):
|
for key in list(ctrl_sd_sd.keys()):
|
||||||
if key.startswith("control_"):
|
if key.startswith("control_"):
|
||||||
unet_key = "model.diffusion_" + key[len("control_"):]
|
unet_key = "model.diffusion_" + key[len("control_") :]
|
||||||
if unet_key not in ctrl_unet_sd_sd: # zero conv
|
if unet_key not in ctrl_unet_sd_sd: # zero conv
|
||||||
zero_conv_sd[key] = ctrl_sd_sd[key]
|
zero_conv_sd[key] = ctrl_sd_sd[key]
|
||||||
continue
|
continue
|
||||||
@@ -115,6 +115,7 @@ def load_preprocess(prep_type: str):
|
|||||||
def canny(img):
|
def canny(img):
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
||||||
return cv2.Canny(img, th1, th2)
|
return cv2.Canny(img, th1, th2)
|
||||||
|
|
||||||
return canny
|
return canny
|
||||||
|
|
||||||
print("Unsupported prep type:", prep_type)
|
print("Unsupported prep type:", prep_type)
|
||||||
@@ -156,7 +157,17 @@ def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_siz
|
|||||||
return guided_hints
|
return guided_hints
|
||||||
|
|
||||||
|
|
||||||
def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states):
|
def call_unet_and_control_net(
|
||||||
|
step,
|
||||||
|
num_latent_input,
|
||||||
|
original_unet,
|
||||||
|
control_nets: List[ControlNetInfo],
|
||||||
|
guided_hints,
|
||||||
|
current_ratio,
|
||||||
|
sample,
|
||||||
|
timestep,
|
||||||
|
encoder_hidden_states,
|
||||||
|
):
|
||||||
# ControlNet
|
# ControlNet
|
||||||
# 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
|
# 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
|
||||||
cnet_cnt = len(control_nets)
|
cnet_cnt = len(control_nets)
|
||||||
@@ -204,7 +215,16 @@ def call_unet_and_control_net(step, num_latent_input, original_unet, control_net
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states):
|
def unet_forward(
|
||||||
|
is_control_net,
|
||||||
|
control_net: ControlNet,
|
||||||
|
unet: UNet2DConditionModel,
|
||||||
|
guided_hint,
|
||||||
|
ctrl_outs,
|
||||||
|
sample,
|
||||||
|
timestep,
|
||||||
|
encoder_hidden_states,
|
||||||
|
):
|
||||||
# copy from UNet2DConditionModel
|
# copy from UNet2DConditionModel
|
||||||
default_overall_up_factor = 2**unet.num_upsamplers
|
default_overall_up_factor = 2**unet.num_upsamplers
|
||||||
|
|
||||||
@@ -285,13 +305,13 @@ def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionM
|
|||||||
for i, upsample_block in enumerate(unet.up_blocks):
|
for i, upsample_block in enumerate(unet.up_blocks):
|
||||||
is_final_block = i == len(unet.up_blocks) - 1
|
is_final_block = i == len(unet.up_blocks) - 1
|
||||||
|
|
||||||
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||||
|
|
||||||
if not is_control_net and len(ctrl_outs) > 0:
|
if not is_control_net and len(ctrl_outs) > 0:
|
||||||
res_samples = list(res_samples)
|
res_samples = list(res_samples)
|
||||||
apply_ctrl_outs = ctrl_outs[-len(res_samples):]
|
apply_ctrl_outs = ctrl_outs[-len(res_samples) :]
|
||||||
ctrl_outs = ctrl_outs[:-len(res_samples)]
|
ctrl_outs = ctrl_outs[: -len(res_samples)]
|
||||||
for j in range(len(res_samples)):
|
for j in range(len(res_samples)):
|
||||||
res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
|
res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
|
||||||
res_samples = tuple(res_samples)
|
res_samples = tuple(res_samples)
|
||||||
|
|||||||
Reference in New Issue
Block a user