format by black

This commit is contained in:
Kohya S
2023-07-30 14:03:54 +09:00
parent a296654c1b
commit 2a4ae88f18

View File

@@ -57,7 +57,7 @@ def load_control_net(v2, unet, model):
if model_util.is_safetensors(model): if model_util.is_safetensors(model):
ctrl_sd_sd = load_file(model) ctrl_sd_sd = load_file(model)
else: else:
ctrl_sd_sd = torch.load(model, map_location='cpu') ctrl_sd_sd = torch.load(model, map_location="cpu")
ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd) ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む # 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
@@ -115,6 +115,7 @@ def load_preprocess(prep_type: str):
def canny(img): def canny(img):
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
return cv2.Canny(img, th1, th2) return cv2.Canny(img, th1, th2)
return canny return canny
print("Unsupported prep type:", prep_type) print("Unsupported prep type:", prep_type)
@@ -156,7 +157,17 @@ def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_siz
return guided_hints return guided_hints
def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states): def call_unet_and_control_net(
step,
num_latent_input,
original_unet,
control_nets: List[ControlNetInfo],
guided_hints,
current_ratio,
sample,
timestep,
encoder_hidden_states,
):
# ControlNet # ControlNet
# 複数のControlNetの場合は、出力をマージするのではなく交互に適用する # 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
cnet_cnt = len(control_nets) cnet_cnt = len(control_nets)
@@ -204,7 +215,16 @@ def call_unet_and_control_net(step, num_latent_input, original_unet, control_net
""" """
def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states): def unet_forward(
is_control_net,
control_net: ControlNet,
unet: UNet2DConditionModel,
guided_hint,
ctrl_outs,
sample,
timestep,
encoder_hidden_states,
):
# copy from UNet2DConditionModel # copy from UNet2DConditionModel
default_overall_up_factor = 2**unet.num_upsamplers default_overall_up_factor = 2**unet.num_upsamplers