diff --git a/networks/lora.py b/networks/lora.py index 00d21b0e..79dc6ec0 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -757,6 +757,9 @@ def get_block_index(lora_name: str, is_sdxl: bool = False) -> int: # Create network from weights for inference, weights are not loaded here (because can be merged) def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel) + if weights_sd is None: if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file, safe_open @@ -792,7 +795,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh ) # block lr - block_lr_weight = parse_block_lr_kwargs(kwargs) + block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs) if block_lr_weight is not None: network.set_block_lr_weight(block_lr_weight)