mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
feat: block swap for inference and initial impl for HunyuanImage LoRA (not working)
This commit is contained in:
@@ -713,6 +713,10 @@ class LoRANetwork(torch.nn.Module):
|
||||
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
|
||||
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible
|
||||
|
||||
@classmethod
|
||||
def get_qkv_mlp_split_dims(cls) -> List[int]:
|
||||
return [3072] * 3 + [12288]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoders: Union[List[CLIPTextModel], CLIPTextModel],
|
||||
@@ -842,7 +846,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
break
|
||||
|
||||
# if modules_dim is None, we use default lora_dim. if modules_dim is not None, we use the specified dim (no default)
|
||||
if dim is None and modules_dim is None:
|
||||
if dim is None and modules_dim is None:
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = default_dim if default_dim is not None else self.lora_dim
|
||||
alpha = self.alpha
|
||||
@@ -901,9 +905,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
split_dims = None
|
||||
if is_flux and split_qkv:
|
||||
if "double" in lora_name and "qkv" in lora_name:
|
||||
split_dims = [3072] * 3
|
||||
(split_dims,) = self.get_qkv_mlp_split_dims()[:3] # qkv only
|
||||
elif "single" in lora_name and "linear1" in lora_name:
|
||||
split_dims = [3072] * 3 + [12288]
|
||||
split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp
|
||||
|
||||
lora = module_class(
|
||||
lora_name,
|
||||
@@ -1036,9 +1040,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
# split qkv
|
||||
for key in list(state_dict.keys()):
|
||||
if "double" in key and "qkv" in key:
|
||||
split_dims = [3072] * 3
|
||||
split_dims = self.get_qkv_mlp_split_dims()[:3] # qkv only
|
||||
elif "single" in key and "linear1" in key:
|
||||
split_dims = [3072] * 3 + [12288]
|
||||
split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp
|
||||
else:
|
||||
continue
|
||||
|
||||
@@ -1092,9 +1096,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
new_state_dict = {}
|
||||
for key in list(state_dict.keys()):
|
||||
if "double" in key and "qkv" in key:
|
||||
split_dims = [3072] * 3
|
||||
split_dims = self.get_qkv_mlp_split_dims()[:3] # qkv only
|
||||
elif "single" in key and "linear1" in key:
|
||||
split_dims = [3072] * 3 + [12288]
|
||||
split_dims = self.get_qkv_mlp_split_dims() # qkv + mlp
|
||||
else:
|
||||
new_state_dict[key] = state_dict[key]
|
||||
continue
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user