feat: HunyuanImage LoRA training

This commit is contained in:
Kohya S
2025-09-12 21:40:42 +09:00
parent cbc9e1a3b1
commit 209c02dbb6
12 changed files with 352 additions and 149 deletions

View File

@@ -1,3 +1,4 @@
import gc
import importlib
import argparse
import math
@@ -10,11 +11,11 @@ import time
import json
from multiprocessing import Value
import numpy as np
import toml
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.types import Number
from library.device_utils import init_ipex, clean_memory_on_device
@@ -175,7 +176,7 @@ class NetworkTrainer:
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(64)
def load_target_model(self, args, weight_dtype, accelerator) -> tuple:
def load_target_model(self, args, weight_dtype, accelerator) -> tuple[str, nn.Module, nn.Module, Optional[nn.Module]]:
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
# モデルに xformers とか memory efficient attention を組み込む
@@ -185,6 +186,9 @@ class NetworkTrainer:
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, List[nn.Module]]:
raise NotImplementedError()
def get_tokenize_strategy(self, args):
return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
@@ -476,8 +480,11 @@ class NetworkTrainer:
return loss.mean()
def cast_text_encoder(self):
return True # default for other than HunyuanImage
return True # default for other than HunyuanImage
def cast_vae(self):
return True # default for other than HunyuanImage
def train(self, args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
@@ -586,37 +593,18 @@ class NetworkTrainer:
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
vae_dtype = (torch.float32 if args.no_half_vae else weight_dtype) if self.cast_vae() else None
# モデルを読み込む
# load target models: unet may be None for lazy loading
model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
if vae_dtype is None:
vae_dtype = vae.dtype
logger.info(f"vae_dtype is set to {vae_dtype} by the model since cast_vae() is false")
# text_encoder is List[CLIPTextModel] or CLIPTextModel
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
# 差分追加学習のためにモデルを読み込む
sys.path.append(os.path.dirname(__file__))
accelerator.print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
if args.base_weights is not None:
# base_weights が指定されている場合は、指定された重みを読み込みマージする
for i, weight_path in enumerate(args.base_weights):
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
multiplier = 1.0
else:
multiplier = args.base_weights_multiplier[i]
accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
module, weights_sd = network_module.create_network_from_weights(
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
)
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
# 学習を準備する
# prepare dataset for latents caching if needed
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
@@ -643,6 +631,32 @@ class NetworkTrainer:
if val_dataset_group is not None:
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype)
if unet is None:
# lazy load unet if needed. text encoders may be freed or replaced with dummy models for saving memory
unet, text_encoders = self.load_unet_lazily(args, weight_dtype, accelerator, text_encoders)
# 差分追加学習のためにモデルを読み込む
sys.path.append(os.path.dirname(__file__))
accelerator.print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
if args.base_weights is not None:
# base_weights が指定されている場合は、指定された重みを読み込みマージする
for i, weight_path in enumerate(args.base_weights):
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
multiplier = 1.0
else:
multiplier = args.base_weights_multiplier[i]
accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
module, weights_sd = network_module.create_network_from_weights(
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
)
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
# prepare network
net_kwargs = {}
if args.network_args is not None:
@@ -672,7 +686,7 @@ class NetworkTrainer:
return
network_has_multiplier = hasattr(network, "set_multiplier")
# TODO remove `hasattr`s by setting up methods if not defined in the network like (hacky but works):
# TODO remove `hasattr` by setting up methods if not defined in the network like below (hacky but will work):
# if not hasattr(network, "prepare_network"):
# network.prepare_network = lambda args: None
@@ -1305,6 +1319,8 @@ class NetworkTrainer:
del t_enc
text_encoders = []
text_encoder = None
gc.collect()
clean_memory_on_device(accelerator.device)
# For --sample_at_first
optimizer_eval_fn()