mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add support for specifying blocks in FLUX.1 LoRA training
This commit is contained in:
24
README.md
24
README.md
@@ -11,6 +11,10 @@ The command to install PyTorch is as follows:
|
|||||||
|
|
||||||
### Recent Updates
|
### Recent Updates
|
||||||
|
|
||||||
|
Sep 16, 2024:
|
||||||
|
|
||||||
|
Added `train_double_block_indices` and `train_single_block_indices` to the LoRA training script to specify the indices of the blocks to train. See [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training) for details.
|
||||||
|
|
||||||
Sep 15, 2024:
|
Sep 15, 2024:
|
||||||
|
|
||||||
Added a script `convert_diffusers_to_flux.py` to convert Diffusers format FLUX.1 models (checkpoints) to BFL format. See `--help` for usage. Only Flux models are supported. AE/CLIP/T5XXL are not supported.
|
Added a script `convert_diffusers_to_flux.py` to convert Diffusers format FLUX.1 models (checkpoints) to BFL format. See `--help` for usage. Only Flux models are supported. AE/CLIP/T5XXL are not supported.
|
||||||
@@ -54,9 +58,12 @@ Please update `safetensors` to `0.4.4` to fix the error when using `--resume`. `
|
|||||||
|
|
||||||
- [FLUX.1 LoRA training](#flux1-lora-training)
|
- [FLUX.1 LoRA training](#flux1-lora-training)
|
||||||
- [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training)
|
- [Key Options for FLUX.1 LoRA training](#key-options-for-flux1-lora-training)
|
||||||
- [Inference for FLUX.1 LoRA model](#inference-for-flux1-lora-model)
|
- [Distribution of timesteps](#distribution-of-timesteps)
|
||||||
- [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training)
|
- [Key Features for FLUX.1 LoRA training](#key-features-for-flux1-lora-training)
|
||||||
|
- [Specify rank for each layer in FLUX.1](#specify-rank-for-each-layer-in-flux1)
|
||||||
|
- [Specify blocks to train in FLUX.1 LoRA training](#specify-blocks-to-train-in-flux1-lora-training)
|
||||||
- [FLUX.1 OFT training](#flux1-oft-training)
|
- [FLUX.1 OFT training](#flux1-oft-training)
|
||||||
|
- [Inference for FLUX.1 with LoRA model](#inference-for-flux1-with-lora-model)
|
||||||
- [FLUX.1 fine-tuning](#flux1-fine-tuning)
|
- [FLUX.1 fine-tuning](#flux1-fine-tuning)
|
||||||
- [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning)
|
- [Key Features for FLUX.1 fine-tuning](#key-features-for-flux1-fine-tuning)
|
||||||
- [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models)
|
- [Extract LoRA from FLUX.1 Models](#extract-lora-from-flux1-models)
|
||||||
@@ -239,6 +246,21 @@ Each number corresponds to `img_in`, `time_in`, `vector_in`, `guidance_in`, `txt
|
|||||||
|
|
||||||
If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`.
|
If you specify `0`, LoRA will not be applied to that layer. For example, `[4,0,0,0,4]` applies LoRA only to `img_in` and `txt_in`.
|
||||||
|
|
||||||
|
#### Specify blocks to train in FLUX.1 LoRA training
|
||||||
|
|
||||||
|
You can specify the blocks to train in FLUX.1 LoRA training by specifying `train_double_block_indices` and `train_single_block_indices` in network_args. The indices are 0-based. The default (when omitted) is to train all blocks. The indices are specified as a list of integers or a range of integers, like `0,1,5,8` or `0,1,4-5,7`. The number of double blocks is 19, and the number of single blocks is 38, so the valid range is 0-18 and 0-37, respectively. `all` is also available to train all blocks, `none` is also available to train no blocks.
|
||||||
|
|
||||||
|
example:
|
||||||
|
```
|
||||||
|
--network_args "train_double_block_indices=0,1,8-12,18" "train_single_block_indices=3,10,20-25,37"
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
--network_args "train_double_block_indices=none" "train_single_block_indices=10-15"
|
||||||
|
```
|
||||||
|
|
||||||
|
If you specify one of `train_double_block_indices` or `train_single_block_indices`, the other will be trained as usual.
|
||||||
|
|
||||||
### FLUX.1 OFT training
|
### FLUX.1 OFT training
|
||||||
|
|
||||||
You can train OFT with almost the same options as LoRA, such as `--timestep_sampling`. The following points are different.
|
You can train OFT with almost the same options as LoRA, such as `--timestep_sampling`. The following points are different.
|
||||||
|
|||||||
@@ -24,6 +24,10 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
NUM_DOUBLE_BLOCKS = 19
|
||||||
|
NUM_SINGLE_BLOCKS = 38
|
||||||
|
|
||||||
|
|
||||||
class LoRAModule(torch.nn.Module):
|
class LoRAModule(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||||
@@ -354,6 +358,50 @@ def create_network(
|
|||||||
in_dims = [int(d) for d in in_dims.split(",")] # is it better to use ast.literal_eval?
|
in_dims = [int(d) for d in in_dims.split(",")] # is it better to use ast.literal_eval?
|
||||||
assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)"
|
assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)"
|
||||||
|
|
||||||
|
# double/single train blocks
|
||||||
|
def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
    """
    Parse a block selection string and return a list of booleans.

    Args:
        selection (str): A string specifying which blocks to select.
        total_blocks (int): The total number of blocks available.

    Returns:
        List[bool]: A list of booleans indicating which blocks are selected.
    """
    # Special keywords: "all" enables every block, "none" (or empty) disables all.
    if selection == "all":
        return [True] * total_blocks
    if selection in ("none", ""):
        return [False] * total_blocks

    flags = [False] * total_blocks

    # Comma-separated tokens: each is either a single index or a "lo-hi" range.
    for token in selection.split(","):
        if "-" in token:
            lo_str, hi_str = (part.strip() for part in token.split("-"))
            lo = int(lo_str)
            hi = int(hi_str)
            assert 0 <= lo < total_blocks, f"invalid start index: {lo}"
            assert 0 <= hi < total_blocks, f"invalid end index: {hi}"
            assert lo <= hi, f"invalid range: {lo}-{hi}"
            # Range bounds are inclusive on both ends.
            flags[lo : hi + 1] = [True] * (hi - lo + 1)
        else:
            idx = int(token)
            assert 0 <= idx < total_blocks, f"invalid index: {idx}"
            flags[idx] = True

    return flags
||||||
|
|
||||||
|
train_double_block_indices = kwargs.get("train_double_block_indices", None)
|
||||||
|
train_single_block_indices = kwargs.get("train_single_block_indices", None)
|
||||||
|
if train_double_block_indices is not None:
|
||||||
|
train_double_block_indices = parse_block_selection(train_double_block_indices, NUM_DOUBLE_BLOCKS)
|
||||||
|
if train_single_block_indices is not None:
|
||||||
|
train_single_block_indices = parse_block_selection(train_single_block_indices, NUM_SINGLE_BLOCKS)
|
||||||
|
|
||||||
# rank/module dropout
|
# rank/module dropout
|
||||||
rank_dropout = kwargs.get("rank_dropout", None)
|
rank_dropout = kwargs.get("rank_dropout", None)
|
||||||
if rank_dropout is not None:
|
if rank_dropout is not None:
|
||||||
@@ -399,6 +447,8 @@ def create_network(
|
|||||||
train_t5xxl=train_t5xxl,
|
train_t5xxl=train_t5xxl,
|
||||||
type_dims=type_dims,
|
type_dims=type_dims,
|
||||||
in_dims=in_dims,
|
in_dims=in_dims,
|
||||||
|
train_double_block_indices=train_double_block_indices,
|
||||||
|
train_single_block_indices=train_single_block_indices,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -509,6 +559,8 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
train_t5xxl: bool = False,
|
train_t5xxl: bool = False,
|
||||||
type_dims: Optional[List[int]] = None,
|
type_dims: Optional[List[int]] = None,
|
||||||
in_dims: Optional[List[int]] = None,
|
in_dims: Optional[List[int]] = None,
|
||||||
|
train_double_block_indices: Optional[List[bool]] = None,
|
||||||
|
train_single_block_indices: Optional[List[bool]] = None,
|
||||||
verbose: Optional[bool] = False,
|
verbose: Optional[bool] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -527,6 +579,8 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
self.type_dims = type_dims
|
self.type_dims = type_dims
|
||||||
self.in_dims = in_dims
|
self.in_dims = in_dims
|
||||||
|
self.train_double_block_indices = train_double_block_indices
|
||||||
|
self.train_single_block_indices = train_single_block_indices
|
||||||
|
|
||||||
self.loraplus_lr_ratio = None
|
self.loraplus_lr_ratio = None
|
||||||
self.loraplus_unet_lr_ratio = None
|
self.loraplus_unet_lr_ratio = None
|
||||||
@@ -600,7 +654,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
dim = default_dim if default_dim is not None else self.lora_dim
|
dim = default_dim if default_dim is not None else self.lora_dim
|
||||||
alpha = self.alpha
|
alpha = self.alpha
|
||||||
|
|
||||||
if type_dims is not None:
|
if is_flux and type_dims is not None:
|
||||||
identifier = [
|
identifier = [
|
||||||
("img_attn",),
|
("img_attn",),
|
||||||
("txt_attn",),
|
("txt_attn",),
|
||||||
@@ -613,9 +667,33 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
]
|
]
|
||||||
for i, d in enumerate(type_dims):
|
for i, d in enumerate(type_dims):
|
||||||
if d is not None and all([id in lora_name for id in identifier[i]]):
|
if d is not None and all([id in lora_name for id in identifier[i]]):
|
||||||
dim = d
|
dim = d # may be 0 for skip
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if (
|
||||||
|
is_flux
|
||||||
|
and dim
|
||||||
|
and (
|
||||||
|
self.train_double_block_indices is not None
|
||||||
|
or self.train_single_block_indices is not None
|
||||||
|
)
|
||||||
|
and ("double" in lora_name or "single" in lora_name)
|
||||||
|
):
|
||||||
|
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
|
||||||
|
block_index = int(lora_name.split("_")[4]) # bit dirty
|
||||||
|
if (
|
||||||
|
"double" in lora_name
|
||||||
|
and self.train_double_block_indices is not None
|
||||||
|
and not self.train_double_block_indices[block_index]
|
||||||
|
):
|
||||||
|
dim = 0
|
||||||
|
elif (
|
||||||
|
"single" in lora_name
|
||||||
|
and self.train_single_block_indices is not None
|
||||||
|
and not self.train_single_block_indices[block_index]
|
||||||
|
):
|
||||||
|
dim = 0
|
||||||
|
|
||||||
elif self.conv_lora_dim is not None:
|
elif self.conv_lora_dim is not None:
|
||||||
dim = self.conv_lora_dim
|
dim = self.conv_lora_dim
|
||||||
alpha = self.conv_alpha
|
alpha = self.conv_alpha
|
||||||
|
|||||||
Reference in New Issue
Block a user