mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
feat: HunyuanImage LoRA training
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import gc
|
||||
import importlib
|
||||
import argparse
|
||||
import math
|
||||
@@ -10,11 +11,11 @@ import time
|
||||
import json
|
||||
from multiprocessing import Value
|
||||
import numpy as np
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.types import Number
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
@@ -175,7 +176,7 @@ class NetworkTrainer:
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(64)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator) -> tuple:
|
||||
def load_target_model(self, args, weight_dtype, accelerator) -> tuple[str, nn.Module, nn.Module, Optional[nn.Module]]:
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
@@ -185,6 +186,9 @@ class NetworkTrainer:
|
||||
|
||||
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
|
||||
|
||||
def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, List[nn.Module]]:
    """Load the U-Net late, for trainers that defer it.

    ``train`` calls this hook only when ``load_target_model`` returned ``None``
    for the U-Net, and rebinds both ``unet`` and ``text_encoders`` to the
    returned pair (per the call-site comment, the text encoders may be freed
    or replaced with dummy models to save memory).

    Args:
        args: parsed training options.
        weight_dtype: dtype prepared for mixed-precision training.
        accelerator: the accelerator object used by ``train``.
        text_encoders: list of text encoders currently held by ``train``.

    Returns:
        A ``(unet, text_encoders)`` pair.

    Raises:
        NotImplementedError: always in this base implementation; a trainer
            whose ``load_target_model`` returns ``None`` for the U-Net has to
            override this method.
    """
    raise NotImplementedError(
        "load_unet_lazily() must be overridden when load_target_model() returns None for the U-Net"
    )
|
||||
|
||||
def get_tokenize_strategy(self, args):
    """Build the tokenize strategy used by this trainer.

    Args:
        args: parsed options; ``v2``, ``max_token_length`` and
            ``tokenizer_cache_dir`` are read.

    Returns:
        A ``strategy_sd.SdTokenizeStrategy`` constructed from those three
        options, in that order.
    """
    return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
||||
|
||||
@@ -476,8 +480,11 @@ class NetworkTrainer:
|
||||
return loss.mean()
|
||||
|
||||
def cast_text_encoder(self) -> bool:
    """Report whether the text encoder should be cast by the trainer.

    NOTE(review): the call site is outside this view; presumably this is the
    text-encoder counterpart of ``cast_vae`` — confirm against ``train``.

    Returns:
        ``True``, the default for every model other than HunyuanImage.
    """
    return True  # default for other than HunyuanImage
|
||||
|
||||
def cast_vae(self) -> bool:
    """Report whether ``train`` should choose the VAE dtype itself.

    When this returns ``True``, ``train`` uses ``torch.float32`` if
    ``args.no_half_vae`` is set and ``weight_dtype`` otherwise. When it
    returns ``False``, ``train`` keeps the dtype of the loaded VAE
    (``vae.dtype``) instead.

    Returns:
        ``True``, the default for every model other than HunyuanImage.
    """
    return True  # default for other than HunyuanImage
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
training_started_at = time.time()
|
||||
@@ -586,37 +593,18 @@ class NetworkTrainer:
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||
vae_dtype = (torch.float32 if args.no_half_vae else weight_dtype) if self.cast_vae() else None
|
||||
|
||||
# モデルを読み込む
|
||||
# load target models: unet may be None for lazy loading
|
||||
model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
|
||||
if vae_dtype is None:
|
||||
vae_dtype = vae.dtype
|
||||
logger.info(f"vae_dtype is set to {vae_dtype} by the model since cast_vae() is false")
|
||||
|
||||
# text_encoder is List[CLIPTextModel] or CLIPTextModel
|
||||
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
|
||||
|
||||
# 差分追加学習のためにモデルを読み込む
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
accelerator.print("import network module:", args.network_module)
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
|
||||
if args.base_weights is not None:
|
||||
# base_weights が指定されている場合は、指定された重みを読み込みマージする
|
||||
for i, weight_path in enumerate(args.base_weights):
|
||||
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
|
||||
multiplier = 1.0
|
||||
else:
|
||||
multiplier = args.base_weights_multiplier[i]
|
||||
|
||||
accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
|
||||
|
||||
module, weights_sd = network_module.create_network_from_weights(
|
||||
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
|
||||
)
|
||||
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
|
||||
|
||||
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
|
||||
|
||||
# 学習を準備する
|
||||
# prepare dataset for latents caching if needed
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
vae.requires_grad_(False)
|
||||
@@ -643,6 +631,32 @@ class NetworkTrainer:
|
||||
if val_dataset_group is not None:
|
||||
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype)
|
||||
|
||||
if unet is None:
|
||||
# lazy load unet if needed. text encoders may be freed or replaced with dummy models for saving memory
|
||||
unet, text_encoders = self.load_unet_lazily(args, weight_dtype, accelerator, text_encoders)
|
||||
|
||||
# 差分追加学習のためにモデルを読み込む
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
accelerator.print("import network module:", args.network_module)
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
|
||||
if args.base_weights is not None:
|
||||
# base_weights が指定されている場合は、指定された重みを読み込みマージする
|
||||
for i, weight_path in enumerate(args.base_weights):
|
||||
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
|
||||
multiplier = 1.0
|
||||
else:
|
||||
multiplier = args.base_weights_multiplier[i]
|
||||
|
||||
accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
|
||||
|
||||
module, weights_sd = network_module.create_network_from_weights(
|
||||
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
|
||||
)
|
||||
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
|
||||
|
||||
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
|
||||
|
||||
# prepare network
|
||||
net_kwargs = {}
|
||||
if args.network_args is not None:
|
||||
@@ -672,7 +686,7 @@ class NetworkTrainer:
|
||||
return
|
||||
network_has_multiplier = hasattr(network, "set_multiplier")
|
||||
|
||||
# TODO remove `hasattr`s by setting up methods if not defined in the network like (hacky but works):
|
||||
# TODO remove `hasattr` by setting up methods if not defined in the network like below (hacky but will work):
|
||||
# if not hasattr(network, "prepare_network"):
|
||||
# network.prepare_network = lambda args: None
|
||||
|
||||
@@ -1305,6 +1319,8 @@ class NetworkTrainer:
|
||||
del t_enc
|
||||
text_encoders = []
|
||||
text_encoder = None
|
||||
gc.collect()
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# For --sample_at_first
|
||||
optimizer_eval_fn()
|
||||
|
||||
Reference in New Issue
Block a user