add support for specifying blocks in FLUX.1 LoRA training

Kohya S
2024-09-16 23:14:09 +09:00
parent 96c677b459
commit d8d15f1a7e
2 changed files with 103 additions and 3 deletions


@@ -11,6 +11,10 @@ The command to install PyTorch is as follows:
### Recent Updates
Sep 16, 2024:
Added `train_double_block_indices` and `train_single_block_indices` to the LoRA training script to specify the indices of the blocks to train. See [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) for details.
Sep 15, 2024:
Added a script `convert_diffusers_to_flux.py` to convert Diffusers format FLUX.1 models (checkpoints) to BFL format. See `--help` for usage. Only Flux models are supported. AE/CLIP/T5XXL are not supported.
@@ -54,9 +58,12 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`.
- [FLUX.1 LoRA training](#flux1-lora-training)
  - [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training)
  - [Distribution of timesteps](#distribution-of-timesteps)
  - [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training)
- [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1)
- [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training)
- [FLUX.1 OFT training](#flux1-oft-training)
- [Inference for FLUX.1 with LoRA model](#inference-for-flux1-with-lora-model)
- [FLUX.1 fine-tuning](#flux1-fine-tuning)
  - [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning)
- [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models)
@@ -239,6 +246,21 @@ Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt
If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`.
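For illustration, a minimal sketch (not part of this commit) of passing the `[4,0,0,0,4]` example above, assuming `in_dims` uses the same key=value syntax as the other `network_args` options:
```
--network_args "in_dims=[4,0,0,0,4]"
```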
#### Specify blocks to train in FLUX.1 LoRA training
You can specify the blocks to train in FLUX.1 LoRA training with `train_double_block_indices` and `train_single_block_indices` in `network_args`. The indices are 0-based, and the default (when omitted) is to train all blocks. Indices are given as a comma-separated list of integers and/or integer ranges, like `0,1,5,8` or `0,1,4-5,7`. There are 19 double blocks and 38 single blocks, so the valid ranges are 0-18 and 0-37, respectively. `all` is also available to train all blocks, and `none` to train no blocks.

Example:
```
--network_args "train_double_block_indices=0,1,8-12,18" "train_single_block_indices=3,10,20-25,37"
```
```
--network_args "train_double_block_indices=none" "train_single_block_indices=10-15"
```
If you specify only one of `train_double_block_indices` and `train_single_block_indices`, the blocks of the other type will be trained as usual.
### FLUX.1 OFT training
You can train OFT with almost the same options as LoRA, such as `--timestep_sampling`. The following points are different.


@@ -24,6 +24,10 @@ import logging
logger = logging.getLogger(__name__)
NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38
class LoRAModule(torch.nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -354,6 +358,50 @@ def create_network(
        in_dims = [int(d) for d in in_dims.split(",")]  # is it better to use ast.literal_eval?
        assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)"

    # double/single train blocks
    def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
        """
        Parse a block selection string and return a list of booleans.

        Args:
            selection (str): A string specifying which blocks to select.
            total_blocks (int): The total number of blocks available.

        Returns:
            List[bool]: A list of booleans indicating which blocks are selected.
        """
        if selection == "all":
            return [True] * total_blocks
        if selection == "none" or selection == "":
            return [False] * total_blocks

        selected = [False] * total_blocks
        ranges = selection.split(",")

        for r in ranges:
            if "-" in r:
                start, end = map(str.strip, r.split("-"))
                start = int(start)
                end = int(end)
                assert 0 <= start < total_blocks, f"invalid start index: {start}"
                assert 0 <= end < total_blocks, f"invalid end index: {end}"
                assert start <= end, f"invalid range: {start}-{end}"
                for i in range(start, end + 1):
                    selected[i] = True
            else:
                index = int(r)
                assert 0 <= index < total_blocks, f"invalid index: {index}"
                selected[index] = True

        return selected

    train_double_block_indices = kwargs.get("train_double_block_indices", None)
    train_single_block_indices = kwargs.get("train_single_block_indices", None)
    if train_double_block_indices is not None:
        train_double_block_indices = parse_block_selection(train_double_block_indices, NUM_DOUBLE_BLOCKS)
    if train_single_block_indices is not None:
        train_single_block_indices = parse_block_selection(train_single_block_indices, NUM_SINGLE_BLOCKS)
    # rank/module dropout
    rank_dropout = kwargs.get("rank_dropout", None)
    if rank_dropout is not None:
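For reference, a minimal sketch (not part of this commit) of how the selection strings documented in the README resolve to boolean masks via `parse_block_selection`; the small `total_blocks` values are illustrative only, the real calls pass `NUM_DOUBLE_BLOCKS` (19) and `NUM_SINGLE_BLOCKS` (38):

```
# illustrative totals; real calls use NUM_DOUBLE_BLOCKS / NUM_SINGLE_BLOCKS
assert parse_block_selection("0,1,4-5", 8) == [True, True, False, False, True, True, False, False]
assert parse_block_selection("all", 3) == [True] * 3
assert parse_block_selection("none", 3) == [False] * 3
```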
@@ -399,6 +447,8 @@ def create_network(
        train_t5xxl=train_t5xxl,
        type_dims=type_dims,
        in_dims=in_dims,
        train_double_block_indices=train_double_block_indices,
        train_single_block_indices=train_single_block_indices,
        verbose=verbose,
    )
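As a usage sketch (assuming the usual `--network_module networks.lora_flux` invocation of this repository), these kwargs reach `create_network` through `--network_args` on the training command line:

```
--network_module networks.lora_flux --network_dim 4 --network_args "train_double_block_indices=0-5" "train_single_block_indices=all"
```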
@@ -509,6 +559,8 @@ class LoRANetwork(torch.nn.Module):
        train_t5xxl: bool = False,
        type_dims: Optional[List[int]] = None,
        in_dims: Optional[List[int]] = None,
        train_double_block_indices: Optional[List[bool]] = None,
        train_single_block_indices: Optional[List[bool]] = None,
        verbose: Optional[bool] = False,
    ) -> None:
        super().__init__()
@@ -527,6 +579,8 @@ class LoRANetwork(torch.nn.Module):
        self.type_dims = type_dims
        self.in_dims = in_dims
        self.train_double_block_indices = train_double_block_indices
        self.train_single_block_indices = train_single_block_indices

        self.loraplus_lr_ratio = None
        self.loraplus_unet_lr_ratio = None
@@ -600,7 +654,7 @@ class LoRANetwork(torch.nn.Module):
            dim = default_dim if default_dim is not None else self.lora_dim
            alpha = self.alpha

            if is_flux and type_dims is not None:
                identifier = [
                    ("img_attn",),
                    ("txt_attn",),
@@ -613,9 +667,33 @@ class LoRANetwork(torch.nn.Module):
                ]
                for i, d in enumerate(type_dims):
                    if d is not None and all([id in lora_name for id in identifier[i]]):
                        dim = d  # may be 0 for skip
                        break
            if (
                is_flux
                and dim  # dim may already be 0 (module skipped via type_dims above)
                and (
                    self.train_double_block_indices is not None
                    or self.train_single_block_indices is not None
                )
                and ("double" in lora_name or "single" in lora_name)
            ):
                # "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
                block_index = int(lora_name.split("_")[4])  # bit dirty
                if (
                    "double" in lora_name
                    and self.train_double_block_indices is not None
                    and not self.train_double_block_indices[block_index]
                ):
                    dim = 0
                elif (
                    "single" in lora_name
                    and self.train_single_block_indices is not None
                    and not self.train_single_block_indices[block_index]
                ):
                    dim = 0
        elif self.conv_lora_dim is not None:
            dim = self.conv_lora_dim
            alpha = self.conv_alpha
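For reference, a minimal sketch (not part of this commit) of the block-index extraction above, assuming module names of the form `lora_unet_double_blocks_{i}_...` or `lora_unet_single_blocks_{i}_...`:

```
name = "lora_unet_double_blocks_7_img_attn_qkv"  # hypothetical module name
block_index = int(name.split("_")[4])  # element 4 is the block index -> 7
```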