add experimental split mode for FLUX

Kohya S
2024-08-13 22:28:39 +09:00
parent 9711c96f96
commit 56d7651f08
4 changed files with 304 additions and 23 deletions


@@ -252,6 +252,11 @@ def create_network(
     if module_dropout is not None:
         module_dropout = float(module_dropout)
 
+    # single or double blocks
+    train_blocks = kwargs.get("train_blocks", None)  # None (default), "all" (same as None), "single", "double"
+    if train_blocks is not None:
+        assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}"
+
     # quite a lot of arguments here ( ^ω^)...
     network = LoRANetwork(
         text_encoders,
@@ -264,6 +269,7 @@ def create_network(
         module_dropout=module_dropout,
         conv_lora_dim=conv_dim,
         conv_alpha=conv_alpha,
+        train_blocks=train_blocks,
         varbose=True,
     )
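For context, a minimal sketch (not part of this commit) of how a network kwarg like this typically reaches create_network: sd-scripts-style trainers pass "key=value" pairs from the command line through as **kwargs, so the new mode would be selected with train_blocks=single or train_blocks=double. The parsing helper below is an assumption for illustration only:

# Hypothetical illustration: turn "key=value" strings (e.g. from a
# --network_args style option) into the kwargs dict create_network reads.
# parse_network_args and its argument format are assumptions, not code
# from this commit.
def parse_network_args(args):
    kwargs = {}
    for arg in args:  # e.g. ["train_blocks=single"]
        key, value = arg.split("=", 1)
        kwargs[key] = value
    return kwargs

kwargs = parse_network_args(["train_blocks=single"])
train_blocks = kwargs.get("train_blocks", None)
assert train_blocks in ["all", "single", "double"]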
@@ -314,9 +320,11 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
class LoRANetwork(torch.nn.Module):
FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
# FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"]
FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible
LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible
LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1"
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te2"
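The split into separate DOUBLE/SINGLE lists works because LoRA targets in this file are matched by module class name, so listing only one of the two FLUX block types restricts which blocks get LoRA modules. A rough sketch of that matching idea (illustrative only, not the actual create_modules):

# Illustrative only: yield submodules whose class name appears in the
# target list; this is how a ["SingleStreamBlock"]-style list limits the
# scan to one block type.
import torch.nn as nn

def find_target_modules(root: nn.Module, target_replace_modules):
    for name, module in root.named_modules():
        if module.__class__.__name__ in target_replace_modules:
            yield name, module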
@@ -335,6 +343,7 @@ class LoRANetwork(torch.nn.Module):
         module_class: Type[object] = LoRAModule,
         modules_dim: Optional[Dict[str, int]] = None,
         modules_alpha: Optional[Dict[str, int]] = None,
+        train_blocks: Optional[str] = None,
         varbose: Optional[bool] = False,
     ) -> None:
         super().__init__()
@@ -347,6 +356,7 @@ class LoRANetwork(torch.nn.Module):
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.train_blocks = train_blocks if train_blocks is not None else "all"
 
         self.loraplus_lr_ratio = None
         self.loraplus_unet_lr_ratio = None
@@ -360,7 +370,9 @@ class LoRANetwork(torch.nn.Module):
             f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
         )
         if self.conv_lora_dim is not None:
-            logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
+            logger.info(
+                f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
+            )
 
         # create module instances
         def create_modules(
@@ -434,9 +446,17 @@ class LoRANetwork(torch.nn.Module):
                 skipped_te += skipped
             logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
 
         # create LoRA for U-Net
+        if self.train_blocks == "all":
+            target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE
+        elif self.train_blocks == "single":
+            target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE
+        elif self.train_blocks == "double":
+            target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE
+
         self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
-        self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.FLUX_TARGET_REPLACE_MODULE)
-        logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
+        self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules)
+        logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.")
 
         skipped = skipped_te + skipped_un
         if varbose and len(skipped) > 0:
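One caveat worth noting: the if/elif chain above has no else branch, so constructing LoRANetwork directly with an unexpected train_blocks value would leave target_replace_modules unbound and raise a NameError; the assert in create_network is what guards this path. A defensive variant could fail fast instead (a sketch, not the committed code):

# Sketch of a stricter version of the same selection logic; the class
# constants mirror those added above, and the explicit ValueError is the
# only addition.
def select_targets(train_blocks: str):
    if train_blocks == "all":
        return LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE
    elif train_blocks == "single":
        return LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE
    elif train_blocks == "double":
        return LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE
    else:
        raise ValueError(f"invalid train_blocks: {train_blocks}")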