mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add support for specifying blocks in FLUX.1 LoRA training
This commit is contained in:
@@ -24,6 +24,10 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
NUM_DOUBLE_BLOCKS = 19
|
||||
NUM_SINGLE_BLOCKS = 38
|
||||
|
||||
|
||||
class LoRAModule(torch.nn.Module):
|
||||
"""
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
@@ -354,6 +358,50 @@ def create_network(
|
||||
in_dims = [int(d) for d in in_dims.split(",")] # is it better to use ast.literal_eval?
|
||||
assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)"
|
||||
|
||||
# double/single train blocks
|
||||
def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
|
||||
"""
|
||||
Parse a block selection string and return a list of booleans.
|
||||
|
||||
Args:
|
||||
selection (str): A string specifying which blocks to select.
|
||||
total_blocks (int): The total number of blocks available.
|
||||
|
||||
Returns:
|
||||
List[bool]: A list of booleans indicating which blocks are selected.
|
||||
"""
|
||||
if selection == "all":
|
||||
return [True] * total_blocks
|
||||
if selection == "none" or selection == "":
|
||||
return [False] * total_blocks
|
||||
|
||||
selected = [False] * total_blocks
|
||||
ranges = selection.split(",")
|
||||
|
||||
for r in ranges:
|
||||
if "-" in r:
|
||||
start, end = map(str.strip, r.split("-"))
|
||||
start = int(start)
|
||||
end = int(end)
|
||||
assert 0 <= start < total_blocks, f"invalid start index: {start}"
|
||||
assert 0 <= end < total_blocks, f"invalid end index: {end}"
|
||||
assert start <= end, f"invalid range: {start}-{end}"
|
||||
for i in range(start, end + 1):
|
||||
selected[i] = True
|
||||
else:
|
||||
index = int(r)
|
||||
assert 0 <= index < total_blocks, f"invalid index: {index}"
|
||||
selected[index] = True
|
||||
|
||||
return selected
|
||||
|
||||
train_double_block_indices = kwargs.get("train_double_block_indices", None)
|
||||
train_single_block_indices = kwargs.get("train_single_block_indices", None)
|
||||
if train_double_block_indices is not None:
|
||||
train_double_block_indices = parse_block_selection(train_double_block_indices, NUM_DOUBLE_BLOCKS)
|
||||
if train_single_block_indices is not None:
|
||||
train_single_block_indices = parse_block_selection(train_single_block_indices, NUM_SINGLE_BLOCKS)
|
||||
|
||||
# rank/module dropout
|
||||
rank_dropout = kwargs.get("rank_dropout", None)
|
||||
if rank_dropout is not None:
|
||||
@@ -399,6 +447,8 @@ def create_network(
|
||||
train_t5xxl=train_t5xxl,
|
||||
type_dims=type_dims,
|
||||
in_dims=in_dims,
|
||||
train_double_block_indices=train_double_block_indices,
|
||||
train_single_block_indices=train_single_block_indices,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
@@ -509,6 +559,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
train_t5xxl: bool = False,
|
||||
type_dims: Optional[List[int]] = None,
|
||||
in_dims: Optional[List[int]] = None,
|
||||
train_double_block_indices: Optional[List[bool]] = None,
|
||||
train_single_block_indices: Optional[List[bool]] = None,
|
||||
verbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -527,6 +579,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
self.type_dims = type_dims
|
||||
self.in_dims = in_dims
|
||||
self.train_double_block_indices = train_double_block_indices
|
||||
self.train_single_block_indices = train_single_block_indices
|
||||
|
||||
self.loraplus_lr_ratio = None
|
||||
self.loraplus_unet_lr_ratio = None
|
||||
@@ -600,7 +654,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
dim = default_dim if default_dim is not None else self.lora_dim
|
||||
alpha = self.alpha
|
||||
|
||||
if type_dims is not None:
|
||||
if is_flux and type_dims is not None:
|
||||
identifier = [
|
||||
("img_attn",),
|
||||
("txt_attn",),
|
||||
@@ -613,9 +667,33 @@ class LoRANetwork(torch.nn.Module):
|
||||
]
|
||||
for i, d in enumerate(type_dims):
|
||||
if d is not None and all([id in lora_name for id in identifier[i]]):
|
||||
dim = d
|
||||
dim = d # may be 0 for skip
|
||||
break
|
||||
|
||||
if (
|
||||
is_flux
|
||||
and dim
|
||||
and (
|
||||
self.train_double_block_indices is not None
|
||||
or self.train_single_block_indices is not None
|
||||
)
|
||||
and ("double" in lora_name or "single" in lora_name)
|
||||
):
|
||||
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
|
||||
block_index = int(lora_name.split("_")[4]) # bit dirty
|
||||
if (
|
||||
"double" in lora_name
|
||||
and self.train_double_block_indices is not None
|
||||
and not self.train_double_block_indices[block_index]
|
||||
):
|
||||
dim = 0
|
||||
elif (
|
||||
"single" in lora_name
|
||||
and self.train_single_block_indices is not None
|
||||
and not self.train_single_block_indices[block_index]
|
||||
):
|
||||
dim = 0
|
||||
|
||||
elif self.conv_lora_dim is not None:
|
||||
dim = self.conv_lora_dim
|
||||
alpha = self.conv_alpha
|
||||
|
||||
Reference in New Issue
Block a user