mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
format by black
This commit is contained in:
@@ -57,7 +57,7 @@ def load_control_net(v2, unet, model):
|
|||||||
if model_util.is_safetensors(model):
|
if model_util.is_safetensors(model):
|
||||||
ctrl_sd_sd = load_file(model)
|
ctrl_sd_sd = load_file(model)
|
||||||
else:
|
else:
|
||||||
ctrl_sd_sd = torch.load(model, map_location='cpu')
|
ctrl_sd_sd = torch.load(model, map_location="cpu")
|
||||||
ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
|
ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
|
||||||
|
|
||||||
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
|
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
|
||||||
@@ -115,6 +115,7 @@ def load_preprocess(prep_type: str):
|
|||||||
def canny(img):
|
def canny(img):
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
||||||
return cv2.Canny(img, th1, th2)
|
return cv2.Canny(img, th1, th2)
|
||||||
|
|
||||||
return canny
|
return canny
|
||||||
|
|
||||||
print("Unsupported prep type:", prep_type)
|
print("Unsupported prep type:", prep_type)
|
||||||
@@ -156,7 +157,17 @@ def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_siz
|
|||||||
return guided_hints
|
return guided_hints
|
||||||
|
|
||||||
|
|
||||||
def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states):
|
def call_unet_and_control_net(
|
||||||
|
step,
|
||||||
|
num_latent_input,
|
||||||
|
original_unet,
|
||||||
|
control_nets: List[ControlNetInfo],
|
||||||
|
guided_hints,
|
||||||
|
current_ratio,
|
||||||
|
sample,
|
||||||
|
timestep,
|
||||||
|
encoder_hidden_states,
|
||||||
|
):
|
||||||
# ControlNet
|
# ControlNet
|
||||||
# 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
|
# 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
|
||||||
cnet_cnt = len(control_nets)
|
cnet_cnt = len(control_nets)
|
||||||
@@ -204,7 +215,16 @@ def call_unet_and_control_net(step, num_latent_input, original_unet, control_net
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states):
|
def unet_forward(
|
||||||
|
is_control_net,
|
||||||
|
control_net: ControlNet,
|
||||||
|
unet: UNet2DConditionModel,
|
||||||
|
guided_hint,
|
||||||
|
ctrl_outs,
|
||||||
|
sample,
|
||||||
|
timestep,
|
||||||
|
encoder_hidden_states,
|
||||||
|
):
|
||||||
# copy from UNet2DConditionModel
|
# copy from UNet2DConditionModel
|
||||||
default_overall_up_factor = 2**unet.num_upsamplers
|
default_overall_up_factor = 2**unet.num_upsamplers
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user