mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 14:45:19 +00:00
support Diffusers' based SDXL LoRA key for inference
This commit is contained in:
@@ -755,6 +755,52 @@ def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
|
|||||||
return block_idx
|
return block_idx
|
||||||
|
|
||||||
|
|
||||||
|
def convert_diffusers_to_sai_if_needed(weights_sd):
|
||||||
|
# only supports U-Net LoRA modules
|
||||||
|
|
||||||
|
found_up_down_blocks = False
|
||||||
|
for k in list(weights_sd.keys()):
|
||||||
|
if "down_blocks" in k:
|
||||||
|
found_up_down_blocks = True
|
||||||
|
break
|
||||||
|
if "up_blocks" in k:
|
||||||
|
found_up_down_blocks = True
|
||||||
|
break
|
||||||
|
if not found_up_down_blocks:
|
||||||
|
return
|
||||||
|
|
||||||
|
from library.sdxl_model_util import make_unet_conversion_map
|
||||||
|
|
||||||
|
unet_conversion_map = make_unet_conversion_map()
|
||||||
|
unet_conversion_map = {hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
|
||||||
|
|
||||||
|
# # add extra conversion
|
||||||
|
# unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv"
|
||||||
|
|
||||||
|
logger.info(f"Converting LoRA keys from Diffusers to SAI")
|
||||||
|
lora_unet_prefix = "lora_unet_"
|
||||||
|
for k in list(weights_sd.keys()):
|
||||||
|
if not k.startswith(lora_unet_prefix):
|
||||||
|
continue
|
||||||
|
|
||||||
|
unet_module_name = k[len(lora_unet_prefix) :].split(".")[0]
|
||||||
|
|
||||||
|
# search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small
|
||||||
|
for hf_module_name, sd_module_name in unet_conversion_map.items():
|
||||||
|
if hf_module_name in unet_module_name:
|
||||||
|
new_key = (
|
||||||
|
lora_unet_prefix
|
||||||
|
+ unet_module_name.replace(hf_module_name, sd_module_name)
|
||||||
|
+ k[len(lora_unet_prefix) + len(unet_module_name) :]
|
||||||
|
)
|
||||||
|
weights_sd[new_key] = weights_sd.pop(k)
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not found:
|
||||||
|
logger.warning(f"Key {k} is not found in unet_conversion_map")
|
||||||
|
|
||||||
|
|
||||||
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
||||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
||||||
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
|
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
|
||||||
@@ -768,6 +814,9 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
|||||||
else:
|
else:
|
||||||
weights_sd = torch.load(file, map_location="cpu")
|
weights_sd = torch.load(file, map_location="cpu")
|
||||||
|
|
||||||
|
# if keys are Diffusers based, convert to SAI based
|
||||||
|
convert_diffusers_to_sai_if_needed(weights_sd)
|
||||||
|
|
||||||
# get dim/alpha mapping
|
# get dim/alpha mapping
|
||||||
modules_dim = {}
|
modules_dim = {}
|
||||||
modules_alpha = {}
|
modules_alpha = {}
|
||||||
|
|||||||
Reference in New Issue
Block a user