mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Add support for specifying rank for each layer in FLUX.1
This commit is contained in:
@@ -316,6 +316,44 @@ def create_network(
|
||||
else:
|
||||
conv_alpha = float(conv_alpha)
|
||||
|
||||
# attn dim, mlp dim: separate ranks are only for DoubleStreamBlock. SingleStreamBlock does not support separate attn/mlp ranks because of combined qkv; single_dim below is used for it instead
|
||||
img_attn_dim = kwargs.get("img_attn_dim", None)
|
||||
txt_attn_dim = kwargs.get("txt_attn_dim", None)
|
||||
img_mlp_dim = kwargs.get("img_mlp_dim", None)
|
||||
txt_mlp_dim = kwargs.get("txt_mlp_dim", None)
|
||||
img_mod_dim = kwargs.get("img_mod_dim", None)
|
||||
txt_mod_dim = kwargs.get("txt_mod_dim", None)
|
||||
single_dim = kwargs.get("single_dim", None) # SingleStreamBlock
|
||||
single_mod_dim = kwargs.get("single_mod_dim", None) # SingleStreamBlock
|
||||
if img_attn_dim is not None:
|
||||
img_attn_dim = int(img_attn_dim)
|
||||
if txt_attn_dim is not None:
|
||||
txt_attn_dim = int(txt_attn_dim)
|
||||
if img_mlp_dim is not None:
|
||||
img_mlp_dim = int(img_mlp_dim)
|
||||
if txt_mlp_dim is not None:
|
||||
txt_mlp_dim = int(txt_mlp_dim)
|
||||
if img_mod_dim is not None:
|
||||
img_mod_dim = int(img_mod_dim)
|
||||
if txt_mod_dim is not None:
|
||||
txt_mod_dim = int(txt_mod_dim)
|
||||
if single_dim is not None:
|
||||
single_dim = int(single_dim)
|
||||
if single_mod_dim is not None:
|
||||
single_mod_dim = int(single_mod_dim)
|
||||
type_dims = [img_attn_dim, txt_attn_dim, img_mlp_dim, txt_mlp_dim, img_mod_dim, txt_mod_dim, single_dim, single_mod_dim]
|
||||
if all([d is None for d in type_dims]):
|
||||
type_dims = None
|
||||
|
||||
# in_dims [img, time, vector, guidance, txt]
|
||||
in_dims = kwargs.get("in_dims", None)
|
||||
if in_dims is not None:
|
||||
in_dims = in_dims.strip()
|
||||
if in_dims.startswith("[") and in_dims.endswith("]"):
|
||||
in_dims = in_dims[1:-1]
|
||||
in_dims = [int(d) for d in in_dims.split(",")] # is it better to use ast.literal_eval?
|
||||
assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)"
|
||||
|
||||
# rank/module dropout
|
||||
rank_dropout = kwargs.get("rank_dropout", None)
|
||||
if rank_dropout is not None:
|
||||
@@ -339,6 +377,11 @@ def create_network(
|
||||
if train_t5xxl is not None:
|
||||
train_t5xxl = True if train_t5xxl == "True" else False
|
||||
|
||||
# verbose
|
||||
verbose = kwargs.get("verbose", False)
|
||||
if verbose is not None:
|
||||
verbose = True if verbose == "True" else False
|
||||
|
||||
# That is an awful lot of arguments ( ^ω^)...
|
||||
network = LoRANetwork(
|
||||
text_encoders,
|
||||
@@ -354,7 +397,9 @@ def create_network(
|
||||
train_blocks=train_blocks,
|
||||
split_qkv=split_qkv,
|
||||
train_t5xxl=train_t5xxl,
|
||||
varbose=True,
|
||||
type_dims=type_dims,
|
||||
in_dims=in_dims,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
|
||||
@@ -462,7 +507,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
train_blocks: Optional[str] = None,
|
||||
split_qkv: bool = False,
|
||||
train_t5xxl: bool = False,
|
||||
varbose: Optional[bool] = False,
|
||||
type_dims: Optional[List[int]] = None,
|
||||
in_dims: Optional[List[int]] = None,
|
||||
verbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.multiplier = multiplier
|
||||
@@ -478,12 +525,17 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.split_qkv = split_qkv
|
||||
self.train_t5xxl = train_t5xxl
|
||||
|
||||
self.type_dims = type_dims
|
||||
self.in_dims = in_dims
|
||||
|
||||
self.loraplus_lr_ratio = None
|
||||
self.loraplus_unet_lr_ratio = None
|
||||
self.loraplus_text_encoder_lr_ratio = None
|
||||
|
||||
if modules_dim is not None:
|
||||
logger.info(f"create LoRA network from weights")
|
||||
self.in_dims = [0] * 5 # create in_dims
|
||||
# verbose = True
|
||||
else:
|
||||
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||
logger.info(
|
||||
@@ -502,7 +554,12 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
is_flux: bool, text_encoder_idx: Optional[int], root_module: torch.nn.Module, target_replace_modules: List[str]
|
||||
is_flux: bool,
|
||||
text_encoder_idx: Optional[int],
|
||||
root_module: torch.nn.Module,
|
||||
target_replace_modules: List[str],
|
||||
filter: Optional[str] = None,
|
||||
default_dim: Optional[int] = None,
|
||||
) -> List[LoRAModule]:
|
||||
prefix = (
|
||||
self.LORA_PREFIX_FLUX
|
||||
@@ -513,16 +570,22 @@ class LoRANetwork(torch.nn.Module):
|
||||
loras = []
|
||||
skipped = []
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
|
||||
if target_replace_modules is None: # dirty hack for all modules
|
||||
module = root_module # search all modules
|
||||
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = child_module.__class__.__name__ == "Linear"
|
||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
|
||||
if is_linear or is_conv2d:
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = prefix + "." + (name + "." if name else "") + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
if filter is not None and not filter in lora_name:
|
||||
continue
|
||||
|
||||
dim = None
|
||||
alpha = None
|
||||
|
||||
@@ -534,8 +597,25 @@ class LoRANetwork(torch.nn.Module):
|
||||
else:
|
||||
# Normally, all modules are targeted
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = self.lora_dim
|
||||
dim = default_dim if default_dim is not None else self.lora_dim
|
||||
alpha = self.alpha
|
||||
|
||||
if type_dims is not None:
|
||||
identifier = [
|
||||
("img_attn",),
|
||||
("txt_attn",),
|
||||
("img_mlp",),
|
||||
("txt_mlp",),
|
||||
("img_mod",),
|
||||
("txt_mod",),
|
||||
("single_blocks", "linear"),
|
||||
("modulation",),
|
||||
]
|
||||
for i, d in enumerate(type_dims):
|
||||
if d is not None and all([id in lora_name for id in identifier[i]]):
|
||||
dim = d
|
||||
break
|
||||
|
||||
elif self.conv_lora_dim is not None:
|
||||
dim = self.conv_lora_dim
|
||||
alpha = self.conv_alpha
|
||||
@@ -566,6 +646,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
split_dims=split_dims,
|
||||
)
|
||||
loras.append(lora)
|
||||
|
||||
if target_replace_modules is None:
|
||||
break # all modules are searched
|
||||
return loras, skipped
|
||||
|
||||
# create LoRA for text encoder
|
||||
@@ -594,10 +677,20 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules)
|
||||
|
||||
# img, time, vector, guidance, txt
|
||||
if self.in_dims:
|
||||
for filter, in_dim in zip(["_img_in", "_time_in", "_vector_in", "_guidance_in", "_txt_in"], self.in_dims):
|
||||
loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim)
|
||||
self.unet_loras.extend(loras)
|
||||
|
||||
logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.")
|
||||
if verbose:
|
||||
for lora in self.unet_loras:
|
||||
logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")
|
||||
|
||||
skipped = skipped_te + skipped_un
|
||||
if varbose and len(skipped) > 0:
|
||||
if verbose and len(skipped) > 0:
|
||||
logger.warning(
|
||||
f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user