mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add experimental split mode for FLUX
This commit is contained in:
@@ -252,6 +252,11 @@ def create_network(
|
||||
if module_dropout is not None:
|
||||
module_dropout = float(module_dropout)
|
||||
|
||||
# single or double blocks
|
||||
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
|
||||
if train_blocks is not None:
|
||||
assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}"
|
||||
|
||||
# すごく引数が多いな ( ^ω^)・・・
|
||||
network = LoRANetwork(
|
||||
text_encoders,
|
||||
@@ -264,6 +269,7 @@ def create_network(
|
||||
module_dropout=module_dropout,
|
||||
conv_lora_dim=conv_dim,
|
||||
conv_alpha=conv_alpha,
|
||||
train_blocks=train_blocks,
|
||||
varbose=True,
|
||||
)
|
||||
|
||||
@@ -314,9 +320,11 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
|
||||
|
||||
|
||||
class LoRANetwork(torch.nn.Module):
|
||||
FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
|
||||
# FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
|
||||
FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"]
|
||||
FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible
|
||||
LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible
|
||||
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
|
||||
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2"
|
||||
|
||||
@@ -335,6 +343,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
module_class: Type[object] = LoRAModule,
|
||||
modules_dim: Optional[Dict[str, int]] = None,
|
||||
modules_alpha: Optional[Dict[str, int]] = None,
|
||||
train_blocks: Optional[str] = None,
|
||||
varbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -347,6 +356,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
self.train_blocks = train_blocks if train_blocks is not None else "all"
|
||||
|
||||
self.loraplus_lr_ratio = None
|
||||
self.loraplus_unet_lr_ratio = None
|
||||
@@ -360,7 +370,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
||||
)
|
||||
if self.conv_lora_dim is not None:
|
||||
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||
logger.info(
|
||||
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
|
||||
)
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
@@ -434,9 +446,17 @@ class LoRANetwork(torch.nn.Module):
|
||||
skipped_te += skipped
|
||||
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
|
||||
# create LoRA for U-Net
|
||||
if self.train_blocks == "all":
|
||||
target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE
|
||||
elif self.train_blocks == "single":
|
||||
target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE
|
||||
elif self.train_blocks == "double":
|
||||
target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE
|
||||
|
||||
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.FLUX_TARGET_REPLACE_MODULE)
|
||||
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules)
|
||||
logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.")
|
||||
|
||||
skipped = skipped_te + skipped_un
|
||||
if varbose and len(skipped) > 0:
|
||||
|
||||
Reference in New Issue
Block a user