mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix controlnet to work with gradual latent
This commit is contained in:
@@ -179,8 +179,21 @@ def call_unet_and_control_net(
|
|||||||
return original_unet(sample, timestep, encoder_hidden_states)
|
return original_unet(sample, timestep, encoder_hidden_states)
|
||||||
|
|
||||||
guided_hint = guided_hints[cnet_idx]
|
guided_hint = guided_hints[cnet_idx]
|
||||||
|
|
||||||
|
# gradual latent support: match the size of guided_hint to the size of sample
|
||||||
|
if guided_hint.shape[-2:] != sample.shape[-2:]:
|
||||||
|
# print(f"guided_hint.shape={guided_hint.shape}, sample.shape={sample.shape}")
|
||||||
|
org_dtype = guided_hint.dtype
|
||||||
|
if org_dtype == torch.bfloat16:
|
||||||
|
guided_hint = guided_hint.to(torch.float32)
|
||||||
|
guided_hint = torch.nn.functional.interpolate(guided_hint, size=sample.shape[-2:], mode="bicubic")
|
||||||
|
if org_dtype == torch.bfloat16:
|
||||||
|
guided_hint = guided_hint.to(org_dtype)
|
||||||
|
|
||||||
guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
|
guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
|
||||||
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net)
|
outs = unet_forward(
|
||||||
|
True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net
|
||||||
|
)
|
||||||
outs = [o * cnet_info.weight for o in outs]
|
outs = [o * cnet_info.weight for o in outs]
|
||||||
|
|
||||||
# U-Net
|
# U-Net
|
||||||
|
|||||||
Reference in New Issue
Block a user