feat: HunyuanImage LoRA training

Author: Kohya S
Date: 2025-09-12 21:40:42 +09:00
Parent: cbc9e1a3b1
Commit: 209c02dbb6
12 changed files with 352 additions and 149 deletions


@@ -191,9 +191,8 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
 class HunyuanImageLoRANetwork(lora_flux.LoRANetwork):
-    # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
-    FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"]
-    FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"]
+    TARGET_REPLACE_MODULE_DOUBLE = ["MMDoubleStreamBlock"]
+    TARGET_REPLACE_MODULE_SINGLE = ["MMSingleStreamBlock"]
     LORA_PREFIX_HUNYUAN_IMAGE_DIT = "lora_unet"  # make ComfyUI compatible

     @classmethod
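
Note: the renamed TARGET_REPLACE_MODULE_* lists name the HunyuanImage DiT block classes (MMDoubleStreamBlock / MMSingleStreamBlock) whose inner Linear layers receive LoRA adapters. A minimal sketch of how kohya-style networks match these class names while walking the model; find_lora_targets and the combined list are illustrative, not part of this commit:

import torch.nn as nn
from typing import List

TARGET_REPLACE_MODULES = ["MMDoubleStreamBlock", "MMSingleStreamBlock"]

def find_lora_targets(model: nn.Module) -> List[str]:
    # Match blocks by class name (as create_modules does), then collect
    # their Linear / Conv2d children as LoRA wrap candidates.
    names = []
    for block_name, block in model.named_modules():
        if block.__class__.__name__ in TARGET_REPLACE_MODULES:
            for child_name, child in block.named_modules():
                if isinstance(child, (nn.Linear, nn.Conv2d)):
                    names.append(f"{block_name}.{child_name}")
    return names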
@@ -222,7 +221,7 @@ class HunyuanImageLoRANetwork(lora_flux.LoRANetwork):
         reg_lrs: Optional[Dict[str, float]] = None,
         verbose: Optional[bool] = False,
     ) -> None:
-        super().__init__()
+        nn.Module.__init__(self)
         self.multiplier = multiplier
         self.lora_dim = lora_dim
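
Note: replacing super().__init__() with nn.Module.__init__(self) skips the parent lora_flux.LoRANetwork constructor (which performs FLUX-specific setup) while still initializing the object as a proper nn.Module, so inherited methods keep working. A self-contained sketch of the idiom; the class names here are stand-ins, not the real ones:

import torch.nn as nn

class BaseLoRANetwork(nn.Module):  # stand-in for lora_flux.LoRANetwork
    def __init__(self):
        super().__init__()
        raise RuntimeError("FLUX-specific setup that the subclass must not run")

    def describe(self):  # an inherited method the subclass still uses
        return f"{type(self).__name__}: multiplier={self.multiplier}"

class HunyuanImageLikeNetwork(BaseLoRANetwork):
    def __init__(self, multiplier: float = 1.0):
        # Bypass BaseLoRANetwork.__init__ entirely; calling
        # nn.Module.__init__ directly still sets up parameter/buffer
        # registration, so the object remains a valid nn.Module.
        nn.Module.__init__(self)
        self.multiplier = multiplier

net = HunyuanImageLikeNetwork()
print(net.describe())  # HunyuanImageLikeNetwork: multiplier=1.0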
@@ -259,8 +258,6 @@ class HunyuanImageLoRANetwork(lora_flux.LoRANetwork):
         if self.split_qkv:
             logger.info(f"split qkv for LoRA")
-        if self.train_blocks is not None:
-            logger.info(f"train {self.train_blocks} blocks only")

         # create module instances
         def create_modules(
@@ -354,14 +351,14 @@ class HunyuanImageLoRANetwork(lora_flux.LoRANetwork):
         # create LoRA for U-Net
         target_replace_modules = (
-            HunyuanImageLoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + HunyuanImageLoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE
+            HunyuanImageLoRANetwork.TARGET_REPLACE_MODULE_DOUBLE + HunyuanImageLoRANetwork.TARGET_REPLACE_MODULE_SINGLE
         )
         self.unet_loras: List[Union[lora_flux.LoRAModule, lora_flux.LoRAInfModule]]
         self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules)
         self.text_encoder_loras = []
-        logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.")
+        logger.info(f"create LoRA for HunyuanImage-2.1: {len(self.unet_loras)} modules.")

         if verbose:
             for lora in self.unet_loras:
                 logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")