Add support for specifying rank for each layer in FLUX.1

Author: Kohya S
Date: 2024-09-14 22:17:52 +09:00
Parent: 2d8ee3c280
Commit: c9ff4de905
2 changed files with 161 additions and 7 deletions
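This change lets each layer type in FLUX.1 carry its own LoRA rank via extra network arguments. A rough sketch of how those arguments reach create_network, assuming the usual sd-scripts convention of passing --network_args "key=value" strings; the concrete ranks below are only illustrative:

    # Hypothetical per-layer ranks; every value arrives as a string and is converted with int().
    network_args = {
        "img_attn_dim": "4",       # DoubleStreamBlock img_attn
        "txt_attn_dim": "4",       # DoubleStreamBlock txt_attn
        "img_mlp_dim": "8",        # DoubleStreamBlock img_mlp
        "txt_mlp_dim": "8",        # DoubleStreamBlock txt_mlp
        "img_mod_dim": "2",        # DoubleStreamBlock img_mod
        "txt_mod_dim": "2",        # DoubleStreamBlock txt_mod
        "single_dim": "4",         # SingleStreamBlock linear layers
        "single_mod_dim": "2",     # SingleStreamBlock modulation
        "in_dims": "[4,2,2,2,4]",  # img_in, time_in, vector_in, guidance_in, txt_in
    }
    # On the command line these would be passed as --network_args "img_attn_dim=4" ... "in_dims=[4,2,2,2,4]".
    # Any *_dim left out falls back to the base network_dim; if in_dims is omitted, the input layers
    # get no LoRA modules (they are only covered by the extra filter pass added later in this commit).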

@@ -316,6 +316,44 @@ def create_network(
     else:
         conv_alpha = float(conv_alpha)
 
+    # attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv
+    img_attn_dim = kwargs.get("img_attn_dim", None)
+    txt_attn_dim = kwargs.get("txt_attn_dim", None)
+    img_mlp_dim = kwargs.get("img_mlp_dim", None)
+    txt_mlp_dim = kwargs.get("txt_mlp_dim", None)
+    img_mod_dim = kwargs.get("img_mod_dim", None)
+    txt_mod_dim = kwargs.get("txt_mod_dim", None)
+    single_dim = kwargs.get("single_dim", None)  # SingleStreamBlock
+    single_mod_dim = kwargs.get("single_mod_dim", None)  # SingleStreamBlock
+    if img_attn_dim is not None:
+        img_attn_dim = int(img_attn_dim)
+    if txt_attn_dim is not None:
+        txt_attn_dim = int(txt_attn_dim)
+    if img_mlp_dim is not None:
+        img_mlp_dim = int(img_mlp_dim)
+    if txt_mlp_dim is not None:
+        txt_mlp_dim = int(txt_mlp_dim)
+    if img_mod_dim is not None:
+        img_mod_dim = int(img_mod_dim)
+    if txt_mod_dim is not None:
+        txt_mod_dim = int(txt_mod_dim)
+    if single_dim is not None:
+        single_dim = int(single_dim)
+    if single_mod_dim is not None:
+        single_mod_dim = int(single_mod_dim)
+    type_dims = [img_attn_dim, txt_attn_dim, img_mlp_dim, txt_mlp_dim, img_mod_dim, txt_mod_dim, single_dim, single_mod_dim]
+    if all([d is None for d in type_dims]):
+        type_dims = None
+
+    # in_dims [img, time, vector, guidance, txt]
+    in_dims = kwargs.get("in_dims", None)
+    if in_dims is not None:
+        in_dims = in_dims.strip()
+        if in_dims.startswith("[") and in_dims.endswith("]"):
+            in_dims = in_dims[1:-1]
+        in_dims = [int(d) for d in in_dims.split(",")]  # is it better to use ast.literal_eval?
+        assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)"
+
     # rank/module dropout
     rank_dropout = kwargs.get("rank_dropout", None)
     if rank_dropout is not None:
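The hand-rolled parsing above accepts both "[4,2,2,2,4]" and a bare "4,2,2,2,4"; the inline comment wonders whether ast.literal_eval would be cleaner. A minimal sketch of that alternative, not part of the commit (a bare "4,2,2,2,4" parses as a tuple, so both forms still work):

    import ast

    def parse_in_dims(value):
        # literal_eval turns "[4, 2, 2, 2, 4]" into a real list without eval()'s risks
        dims = ast.literal_eval(value.strip())
        assert isinstance(dims, (list, tuple)) and len(dims) == 5, f"invalid in_dims: {value}"
        return [int(d) for d in dims]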
@@ -339,6 +377,11 @@ def create_network(
     if train_t5xxl is not None:
         train_t5xxl = True if train_t5xxl == "True" else False
 
+    # verbose
+    verbose = kwargs.get("verbose", False)
+    if verbose is not None:
+        verbose = True if verbose == "True" else False
+
     # there are a lot of arguments here ( ^ω^)...
     network = LoRANetwork(
         text_encoders,
@@ -354,7 +397,9 @@ def create_network(
         train_blocks=train_blocks,
         split_qkv=split_qkv,
         train_t5xxl=train_t5xxl,
-        varbose=True,
+        type_dims=type_dims,
+        in_dims=in_dims,
+        verbose=verbose,
     )
 
     loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
@@ -462,7 +507,9 @@ class LoRANetwork(torch.nn.Module):
         train_blocks: Optional[str] = None,
         split_qkv: bool = False,
         train_t5xxl: bool = False,
-        varbose: Optional[bool] = False,
+        type_dims: Optional[List[int]] = None,
+        in_dims: Optional[List[int]] = None,
+        verbose: Optional[bool] = False,
     ) -> None:
         super().__init__()
         self.multiplier = multiplier
@@ -478,12 +525,17 @@ class LoRANetwork(torch.nn.Module):
         self.split_qkv = split_qkv
         self.train_t5xxl = train_t5xxl
 
+        self.type_dims = type_dims
+        self.in_dims = in_dims
+
         self.loraplus_lr_ratio = None
         self.loraplus_unet_lr_ratio = None
         self.loraplus_text_encoder_lr_ratio = None
 
         if modules_dim is not None:
             logger.info(f"create LoRA network from weights")
+            self.in_dims = [0] * 5  # create in_dims
+            # verbose = True
         else:
             logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
             logger.info(
@@ -502,7 +554,12 @@ class LoRANetwork(torch.nn.Module):
 
         # create module instances
         def create_modules(
-            is_flux: bool, text_encoder_idx: Optional[int], root_module: torch.nn.Module, target_replace_modules: List[str]
+            is_flux: bool,
+            text_encoder_idx: Optional[int],
+            root_module: torch.nn.Module,
+            target_replace_modules: List[str],
+            filter: Optional[str] = None,
+            default_dim: Optional[int] = None,
         ) -> List[LoRAModule]:
             prefix = (
                 self.LORA_PREFIX_FLUX
@@ -513,16 +570,22 @@ class LoRANetwork(torch.nn.Module):
 
             loras = []
             skipped = []
             for name, module in root_module.named_modules():
-                if module.__class__.__name__ in target_replace_modules:
+                if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
+                    if target_replace_modules is None:  # dirty hack for all modules
+                        module = root_module  # search all modules
+
                     for child_name, child_module in module.named_modules():
                         is_linear = child_module.__class__.__name__ == "Linear"
                         is_conv2d = child_module.__class__.__name__ == "Conv2d"
                         is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
 
                         if is_linear or is_conv2d:
-                            lora_name = prefix + "." + name + "." + child_name
+                            lora_name = prefix + "." + (name + "." if name else "") + child_name
                             lora_name = lora_name.replace(".", "_")
 
+                            if filter is not None and not filter in lora_name:
+                                continue
+
                             dim = None
                             alpha = None
@@ -534,8 +597,25 @@ class LoRANetwork(torch.nn.Module):
                             else:
                                 # usually, all modules are targeted
                                 if is_linear or is_conv2d_1x1:
-                                    dim = self.lora_dim
+                                    dim = default_dim if default_dim is not None else self.lora_dim
                                     alpha = self.alpha
+
+                                    if type_dims is not None:
+                                        identifier = [
+                                            ("img_attn",),
+                                            ("txt_attn",),
+                                            ("img_mlp",),
+                                            ("txt_mlp",),
+                                            ("img_mod",),
+                                            ("txt_mod",),
+                                            ("single_blocks", "linear"),
+                                            ("modulation",),
+                                        ]
+                                        for i, d in enumerate(type_dims):
+                                            if d is not None and all([id in lora_name for id in identifier[i]]):
+                                                dim = d
+                                                break
+
                                 elif self.conv_lora_dim is not None:
                                     dim = self.conv_lora_dim
                                     alpha = self.conv_alpha
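The identifier table above ties substrings of the generated lora_name to the entries of type_dims; the first tuple whose parts all appear in the name decides the rank. A small standalone sketch of the same lookup (pick_dim and the sample names are only for illustration):

    identifier = [
        ("img_attn",), ("txt_attn",), ("img_mlp",), ("txt_mlp",),
        ("img_mod",), ("txt_mod",), ("single_blocks", "linear"), ("modulation",),
    ]

    def pick_dim(lora_name, type_dims, base_dim):
        # first tuple whose substrings all occur in the name wins; unmatched names keep the base rank
        for ids, d in zip(identifier, type_dims):
            if d is not None and all(s in lora_name for s in ids):
                return d
        return base_dim

    # e.g. "lora_unet_double_blocks_0_img_attn_qkv"   -> img_attn_dim
    #      "lora_unet_single_blocks_3_linear1"        -> single_dim
    #      "lora_unet_single_blocks_3_modulation_lin" -> single_mod_dim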
@@ -566,6 +646,9 @@ class LoRANetwork(torch.nn.Module):
                                 split_dims=split_dims,
                             )
                             loras.append(lora)
+
+                if target_replace_modules is None:
+                    break  # all modules are searched
             return loras, skipped
 
         # create LoRA for text encoder
@@ -594,10 +677,20 @@ class LoRANetwork(torch.nn.Module):
         self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
         self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules)
 
+        # img, time, vector, guidance, txt
+        if self.in_dims:
+            for filter, in_dim in zip(["_img_in", "_time_in", "_vector_in", "_guidance_in", "_txt_in"], self.in_dims):
+                loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim)
+                self.unet_loras.extend(loras)
+
         logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.")
 
+        if verbose:
+            for lora in self.unet_loras:
+                logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")
+
         skipped = skipped_te + skipped_un
-        if varbose and len(skipped) > 0:
+        if verbose and len(skipped) > 0:
             logger.warning(
                 f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
             )