mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
Add support for specifying rank for each layer in FLUX.1
README.md (+61 lines)
@@ -11,6 +11,10 @@ The command to install PyTorch is as follows:
### Recent Updates
Sep 14, 2024:
- You can now specify the rank for each layer in FLUX.1. See [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1) for details.
- OFT is now supported with FLUX.1. See [FLUX.1 OFT training](#flux1-oft-training) for details.
Sep 11, 2024:
Logging to wandb is improved. See PR [#1576](https://github.com/kohya-ss/sd-scripts/pull/1576) for details. Thanks to p1atdev!
@@ -46,6 +50,7 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `
- [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training)
- [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model)
- [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training)
- [FLUX.1 OFT training](#flux1-oft-training)
- [FLUX.1 fine-tuning](#flux1-fine-tuning)
- [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning)
- [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models)
@@ -191,6 +196,62 @@ In the implementation of Black Forest Labs' model, the projection layers of q/k/
The compatibility of the saved model (state dict) is ensured by concatenating the weights of multiple LoRAs. However, because some parts of the concatenated weights are zero, the model size becomes large.
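The concatenation idea can be illustrated with a toy numpy sketch (the sizes and variable names here are illustrative, not the actual implementation): separate q/k/v LoRAs are merged into one pair of matrices for the fused projection, with the off-diagonal blocks of the up weight left as zeros — which is why the saved model grows.

```python
import numpy as np

d_model, rank = 8, 2  # toy sizes; FLUX.1 uses much larger dimensions
downs = [np.random.randn(rank, d_model) for _ in range(3)]  # q, k, v down weights
ups = [np.random.randn(d_model, rank) for _ in range(3)]    # q, k, v up weights

# concatenate the down weights along the rank axis
down = np.concatenate(downs, axis=0)  # (3*rank, d_model)

# build a block-diagonal up weight; off-diagonal blocks stay zero
up = np.zeros((3 * d_model, 3 * rank))
for i in range(3):
    up[i * d_model:(i + 1) * d_model, i * rank:(i + 1) * rank] = ups[i]

# the merged pair behaves like the three separate LoRAs on the fused qkv output
delta = up @ down  # (3*d_model, d_model)
for i in range(3):
    assert np.allclose(delta[i * d_model:(i + 1) * d_model], ups[i] @ downs[i])
```

Two thirds of the stored `up` matrix is zero in this sketch, matching the note above about the larger model size.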
#### Specify rank for each layer in FLUX.1
You can specify the rank for each layer in FLUX.1 by specifying the following network_args. If you specify `0`, LoRA will not be applied to that layer.
When network_args is not specified, the default value (`network_dim`) is applied, same as before.
|network_args|target layer|
|---|---|
|img_attn_dim|img_attn in DoubleStreamBlock|
|txt_attn_dim|txt_attn in DoubleStreamBlock|
|img_mlp_dim|img_mlp in DoubleStreamBlock|
|txt_mlp_dim|txt_mlp in DoubleStreamBlock|
|img_mod_dim|img_mod in DoubleStreamBlock|
|txt_mod_dim|txt_mod in DoubleStreamBlock|
|single_dim|linear1 and linear2 in SingleStreamBlock|
|single_mod_dim|modulation in SingleStreamBlock|

example:
```
--network_args "img_attn_dim=4" "img_mlp_dim=8" "txt_attn_dim=2" "txt_mlp_dim=2"
"img_mod_dim=2" "txt_mod_dim=2" "single_dim=4" "single_mod_dim=2"
```
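The per-layer selection can be sketched as follows. The `rank_for` helper and its signature are mine; the substring patterns are taken from the `identifier` table in this commit's code, and the match order matters (e.g. `single_dim` is checked before `single_mod_dim`):

```python
# map each network_args key to the substring pattern(s) matched against the
# generated LoRA module name, in the same order as the commit's identifier table
TYPE_PATTERNS = [
    ("img_attn_dim", ("img_attn",)),
    ("txt_attn_dim", ("txt_attn",)),
    ("img_mlp_dim", ("img_mlp",)),
    ("txt_mlp_dim", ("txt_mlp",)),
    ("img_mod_dim", ("img_mod",)),
    ("txt_mod_dim", ("txt_mod",)),
    ("single_dim", ("single_blocks", "linear")),
    ("single_mod_dim", ("modulation",)),
]

def rank_for(lora_name: str, args: dict, default_dim: int) -> int:
    """Return the rank applied to a module: first matching pattern wins,
    otherwise fall back to network_dim."""
    for key, patterns in TYPE_PATTERNS:
        dim = args.get(key)
        if dim is not None and all(p in lora_name for p in patterns):
            return int(dim)
    return default_dim

args = {"img_attn_dim": "4", "img_mlp_dim": "8", "txt_attn_dim": "2", "txt_mlp_dim": "2"}
print(rank_for("lora_unet_double_blocks_0_img_attn_qkv", args, 16))  # 4
print(rank_for("lora_unet_single_blocks_0_linear1", args, 16))       # 16 (single_dim unset)
```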
You can apply LoRA to the conditioning layers of Flux by specifying `in_dims` in network_args. When specifying, be sure to specify 5 numbers in `[]` as a comma-separated list.

example:
```
--network_args "in_dims=[4,2,2,2,4]"
```
Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt_in`. The above example applies LoRA to all conditioning layers, with rank 4 for `img_in`, 2 for `time_in`, `vector_in`, `guidance_in`, and 4 for `txt_in`.
If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`.
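The accepted format can be sketched with a small parser that mirrors the parsing logic this commit adds (the `parse_in_dims` helper name is mine):

```python
def parse_in_dims(value: str) -> list:
    """Parse an in_dims network_args value like "[4,2,2,2,4]" into five ints."""
    value = value.strip()
    if value.startswith("[") and value.endswith("]"):
        value = value[1:-1]
    dims = [int(d) for d in value.split(",")]
    assert len(dims) == 5, f"invalid in_dims: {dims}, must be 5 dimensions (img, time, vector, guidance, txt)"
    return dims

# zero means "no LoRA for that layer", so [4,0,0,0,4] targets only img_in and txt_in
layers = ["img_in", "time_in", "vector_in", "guidance_in", "txt_in"]
ranks = parse_in_dims("[4,0,0,0,4]")
applied = {layer: r for layer, r in zip(layers, ranks) if r != 0}
print(applied)  # {'img_in': 4, 'txt_in': 4}
```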
### FLUX.1 OFT training
You can train OFT with almost the same options as LoRA, such as `--timestep_sampling`. The following points are different.
- Change `--network_module` from `networks.lora_flux` to `networks.oft_flux`.
- `--network_dim` is the number of OFT blocks. Unlike LoRA rank, the smaller the dim, the larger the model. We recommend about 64 or 128. Please make the output dimension of the target layer of OFT divisible by the value of `--network_dim` (an error will occur if it is not divisible). Valid values are 64, 128, 256, 512, 1024, etc.
- `--network_alpha` is treated as a constraint for OFT. We recommend about 1e-2 to 1e-4. The default value when omitted is 1, which is too large, so be sure to specify it.
- CLIP/T5XXL is not supported. Specify `--network_train_unet_only`.
- `--network_args` specifies the hyperparameters of OFT. The following are valid:
- Specify `enable_all_linear=True` to target all linear connections in the MLP layer. The default is False, which targets only attention.
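The divisibility requirement for `--network_dim` can be checked quickly. This sketch assumes a target-layer output dimension of 3072 (FLUX.1's hidden size; actual target layers vary but are multiples of it):

```python
# For OFT, --network_dim is the number of blocks, so the layer's output
# dimension must be divisible by it (otherwise training raises an error).
out_dim = 3072  # assumed FLUX.1 hidden size
for network_dim in (64, 128, 256, 512, 1024):
    assert out_dim % network_dim == 0, f"{network_dim} would raise an error"
    print(network_dim, "blocks ->", out_dim // network_dim, "dims per block")
```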
Currently, other inference environments do not support FLUX.1 OFT. Inference is only possible with `flux_minimal_inference.py` (specify the OFT model with `--lora`).
A sample command is below. It works on 24GB VRAM GPUs with a batch size of 1.
```
--network_module networks.oft_flux --network_dim 128 --network_alpha 1e-3
--network_args "enable_all_linear=True" --learning_rate 1e-5
```
Training is possible on 16GB VRAM GPUs without the `enable_all_linear` option and with the Adafactor optimizer.
### Inference for FLUX.1 with LoRA model
The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.
@@ -316,6 +316,44 @@ def create_network(
    else:
        conv_alpha = float(conv_alpha)

    # attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv
    img_attn_dim = kwargs.get("img_attn_dim", None)
    txt_attn_dim = kwargs.get("txt_attn_dim", None)
    img_mlp_dim = kwargs.get("img_mlp_dim", None)
    txt_mlp_dim = kwargs.get("txt_mlp_dim", None)
    img_mod_dim = kwargs.get("img_mod_dim", None)
    txt_mod_dim = kwargs.get("txt_mod_dim", None)
    single_dim = kwargs.get("single_dim", None)  # SingleStreamBlock
    single_mod_dim = kwargs.get("single_mod_dim", None)  # SingleStreamBlock
    if img_attn_dim is not None:
        img_attn_dim = int(img_attn_dim)
    if txt_attn_dim is not None:
        txt_attn_dim = int(txt_attn_dim)
    if img_mlp_dim is not None:
        img_mlp_dim = int(img_mlp_dim)
    if txt_mlp_dim is not None:
        txt_mlp_dim = int(txt_mlp_dim)
    if img_mod_dim is not None:
        img_mod_dim = int(img_mod_dim)
    if txt_mod_dim is not None:
        txt_mod_dim = int(txt_mod_dim)
    if single_dim is not None:
        single_dim = int(single_dim)
    if single_mod_dim is not None:
        single_mod_dim = int(single_mod_dim)
    type_dims = [img_attn_dim, txt_attn_dim, img_mlp_dim, txt_mlp_dim, img_mod_dim, txt_mod_dim, single_dim, single_mod_dim]
    if all([d is None for d in type_dims]):
        type_dims = None

    # in_dims [img, time, vector, guidance, txt]
    in_dims = kwargs.get("in_dims", None)
    if in_dims is not None:
        in_dims = in_dims.strip()
        if in_dims.startswith("[") and in_dims.endswith("]"):
            in_dims = in_dims[1:-1]
        in_dims = [int(d) for d in in_dims.split(",")]  # is it better to use ast.literal_eval?
        assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)"

    # rank/module dropout
    rank_dropout = kwargs.get("rank_dropout", None)
    if rank_dropout is not None:
@@ -339,6 +377,11 @@ def create_network(
    if train_t5xxl is not None:
        train_t5xxl = True if train_t5xxl == "True" else False

    # verbose
    verbose = kwargs.get("verbose", False)
    if verbose is not None:
        verbose = True if verbose == "True" else False

    # there are so many arguments ( ^ω^)...
    network = LoRANetwork(
        text_encoders,
@@ -354,7 +397,9 @@ def create_network(
        train_blocks=train_blocks,
        split_qkv=split_qkv,
        train_t5xxl=train_t5xxl,
        type_dims=type_dims,
        in_dims=in_dims,
        verbose=verbose,
    )

    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
@@ -462,7 +507,9 @@ class LoRANetwork(torch.nn.Module):
        train_blocks: Optional[str] = None,
        split_qkv: bool = False,
        train_t5xxl: bool = False,
        type_dims: Optional[List[int]] = None,
        in_dims: Optional[List[int]] = None,
        verbose: Optional[bool] = False,
    ) -> None:
        super().__init__()
        self.multiplier = multiplier
@@ -478,12 +525,17 @@ class LoRANetwork(torch.nn.Module):
        self.split_qkv = split_qkv
        self.train_t5xxl = train_t5xxl

        self.type_dims = type_dims
        self.in_dims = in_dims

        self.loraplus_lr_ratio = None
        self.loraplus_unet_lr_ratio = None
        self.loraplus_text_encoder_lr_ratio = None

        if modules_dim is not None:
            logger.info(f"create LoRA network from weights")
            self.in_dims = [0] * 5  # create in_dims
            # verbose = True
        else:
            logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
            logger.info(
@@ -502,7 +554,12 @@ class LoRANetwork(torch.nn.Module):

        # create module instances
        def create_modules(
            is_flux: bool,
            text_encoder_idx: Optional[int],
            root_module: torch.nn.Module,
            target_replace_modules: List[str],
            filter: Optional[str] = None,
            default_dim: Optional[int] = None,
        ) -> List[LoRAModule]:
            prefix = (
                self.LORA_PREFIX_FLUX
@@ -513,16 +570,22 @@ class LoRANetwork(torch.nn.Module):
            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
                    if target_replace_modules is None:  # dirty hack for all modules
                        module = root_module  # search all modules

                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear"
                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if is_linear or is_conv2d:
                            lora_name = prefix + "." + (name + "." if name else "") + child_name
                            lora_name = lora_name.replace(".", "_")

                            if filter is not None and not filter in lora_name:
                                continue

                            dim = None
                            alpha = None
@@ -534,8 +597,25 @@ class LoRANetwork(torch.nn.Module):
                            else:
                                # normally, all modules are targeted
                                if is_linear or is_conv2d_1x1:
                                    dim = default_dim if default_dim is not None else self.lora_dim
                                    alpha = self.alpha

                                    if type_dims is not None:
                                        identifier = [
                                            ("img_attn",),
                                            ("txt_attn",),
                                            ("img_mlp",),
                                            ("txt_mlp",),
                                            ("img_mod",),
                                            ("txt_mod",),
                                            ("single_blocks", "linear"),
                                            ("modulation",),
                                        ]
                                        for i, d in enumerate(type_dims):
                                            if d is not None and all([id in lora_name for id in identifier[i]]):
                                                dim = d
                                                break

                                elif self.conv_lora_dim is not None:
                                    dim = self.conv_lora_dim
                                    alpha = self.conv_alpha
@@ -566,6 +646,9 @@ class LoRANetwork(torch.nn.Module):
                                split_dims=split_dims,
                            )
                            loras.append(lora)

                if target_replace_modules is None:
                    break  # all modules are searched
            return loras, skipped

        # create LoRA for text encoder
@@ -594,10 +677,20 @@ class LoRANetwork(torch.nn.Module):

        self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
        self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules)

        # img, time, vector, guidance, txt
        if self.in_dims:
            for filter, in_dim in zip(["_img_in", "_time_in", "_vector_in", "_guidance_in", "_txt_in"], self.in_dims):
                loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim)
                self.unet_loras.extend(loras)

        logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.")
        if verbose:
            for lora in self.unet_loras:
                logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")

        skipped = skipped_te + skipped_un
        if verbose and len(skipped) > 0:
            logger.warning(
                f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
            )