mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support for controlnet in sample output
This commit is contained in:
@@ -731,8 +731,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
||||
def controlnet_conversion_map():
|
||||
unet_conversion_map = [
|
||||
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
||||
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
||||
@@ -792,6 +791,12 @@ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
||||
sd_prefix = f"zero_convs.{i}.0."
|
||||
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
||||
|
||||
return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
|
||||
|
||||
|
||||
def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
||||
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
|
||||
|
||||
mapping = {k: k for k in controlnet_state_dict.keys()}
|
||||
for sd_name, diffusers_name in unet_conversion_map:
|
||||
mapping[diffusers_name] = sd_name
|
||||
@@ -807,6 +812,23 @@ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
||||
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
||||
return new_state_dict
|
||||
|
||||
def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
|
||||
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
|
||||
|
||||
mapping = {k: k for k in controlnet_state_dict.keys()}
|
||||
for sd_name, diffusers_name in unet_conversion_map:
|
||||
mapping[sd_name] = diffusers_name
|
||||
for k, v in mapping.items():
|
||||
for sd_part, diffusers_part in unet_conversion_map_layer:
|
||||
v = v.replace(sd_part, diffusers_part)
|
||||
mapping[k] = v
|
||||
for k, v in mapping.items():
|
||||
if "resnets" in v:
|
||||
for sd_part, diffusers_part in unet_conversion_map_resnet:
|
||||
v = v.replace(sd_part, diffusers_part)
|
||||
mapping[k] = v
|
||||
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
||||
return new_state_dict
|
||||
|
||||
# ================#
|
||||
# VAE Conversion #
|
||||
@@ -928,7 +950,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
||||
|
||||
|
||||
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
||||
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False):
|
||||
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
|
||||
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
|
||||
|
||||
# Convert the UNet2DConditionModel model.
|
||||
|
||||
Reference in New Issue
Block a user