mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
feat: HunyuanImage LoRA training
This commit is contained in:
@@ -191,9 +191,8 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
|
||||
|
||||
|
||||
class HunyuanImageLoRANetwork(lora_flux.LoRANetwork):
|
||||
# FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
|
||||
FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"]
|
||||
FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"]
|
||||
TARGET_REPLACE_MODULE_DOUBLE = ["MMDoubleStreamBlock"]
|
||||
TARGET_REPLACE_MODULE_SINGLE = ["MMSingleStreamBlock"]
|
||||
LORA_PREFIX_HUNYUAN_IMAGE_DIT = "lora_unet" # make ComfyUI compatible
|
||||
|
||||
@classmethod
|
||||
@@ -222,7 +221,7 @@ class HunyuanImageLoRANetwork(lora_flux.LoRANetwork):
|
||||
reg_lrs: Optional[Dict[str, float]] = None,
|
||||
verbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
nn.Module.__init__(self)
|
||||
self.multiplier = multiplier
|
||||
|
||||
self.lora_dim = lora_dim
|
||||
@@ -259,8 +258,6 @@ class HunyuanImageLoRANetwork(lora_flux.LoRANetwork):
|
||||
|
||||
if self.split_qkv:
|
||||
logger.info(f"split qkv for LoRA")
|
||||
if self.train_blocks is not None:
|
||||
logger.info(f"train {self.train_blocks} blocks only")
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
@@ -354,14 +351,14 @@ class HunyuanImageLoRANetwork(lora_flux.LoRANetwork):
|
||||
|
||||
# create LoRA for U-Net
|
||||
target_replace_modules = (
|
||||
HunyuanImageLoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + HunyuanImageLoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE
|
||||
HunyuanImageLoRANetwork.TARGET_REPLACE_MODULE_DOUBLE + HunyuanImageLoRANetwork.TARGET_REPLACE_MODULE_SINGLE
|
||||
)
|
||||
|
||||
self.unet_loras: List[Union[lora_flux.LoRAModule, lora_flux.LoRAInfModule]]
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules)
|
||||
self.text_encoder_loras = []
|
||||
|
||||
logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.")
|
||||
logger.info(f"create LoRA for HunyuanImage-2.1: {len(self.unet_loras)} modules.")
|
||||
if verbose:
|
||||
for lora in self.unet_loras:
|
||||
logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")
|
||||
|
||||
Reference in New Issue
Block a user