feat: block swap for inference and initial impl for HunyuanImage LoRA (not working)

This commit is contained in:
Kohya S
2025-09-11 22:15:22 +09:00
parent 5149be5a87
commit 7f983c558d
16 changed files with 1363 additions and 1303 deletions

View File

@@ -713,6 +713,10 @@ class LoRANetwork(torch.nn.Module):
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible
@classmethod
def get_qkv_mlp_split_dims(cls) -> List[int]:
return [3072] * 3 + [12288]
def __init__(
self,
text_encoders: Union[List[CLIPTextModel], CLIPTextModel],
@@ -842,7 +846,7 @@ class LoRANetwork(torch.nn.Module):
break
# if modules_dim is None, we use default lora_dim. if modules_dim is not None, we use the specified dim (no default)
if dim is None and modules_dim is None:
if dim is None and modules_dim is None:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
@@ -901,9 +905,9 @@ class LoRANetwork(torch.nn.Module):
split_dims = None
if is_flux and split_qkv:
if "double" in lora_name and "qkv" in lora_name:
split_dims = [3072] * 3
(split_dims,) = self.get_qkv_mlp_split_dims()[:3] # qkv only
elif "single" in lora_name and "linear1" in lora_name:
split_dims = [3072] * 3 + [12288]
split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp
lora = module_class(
lora_name,
@@ -1036,9 +1040,9 @@ class LoRANetwork(torch.nn.Module):
# split qkv
for key in list(state_dict.keys()):
if "double" in key and "qkv" in key:
split_dims = [3072] * 3
split_dims = self.get_qkv_mlp_split_dims()[:3] # qkv only
elif "single" in key and "linear1" in key:
split_dims = [3072] * 3 + [12288]
split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp
else:
continue
@@ -1092,9 +1096,9 @@ class LoRANetwork(torch.nn.Module):
new_state_dict = {}
for key in list(state_dict.keys()):
if "double" in key and "qkv" in key:
split_dims = [3072] * 3
split_dims = self.get_qkv_mlp_split_dims()[:3] # qkv only
elif "single" in key and "linear1" in key:
split_dims = [3072] * 3 + [12288]
split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp
else:
new_state_dict[key] = state_dict[key]
continue

File diff suppressed because it is too large Load Diff