feat: HunyuanImage LoRA training

Kohya S
2025-09-12 21:40:42 +09:00
parent cbc9e1a3b1
commit 209c02dbb6
12 changed files with 352 additions and 149 deletions

View File

@@ -30,7 +30,7 @@ yos="yos"
wn="wn"
hime="hime"
OT="OT"
byt5="byt5"
byt="byt"
# [files]
# # Extend the default list of files to check

View File

@@ -66,7 +66,7 @@ def parse_args() -> argparse.Namespace:
# inference
parser.add_argument(
"--guidance_scale", type=float, default=4.0, help="Guidance scale for classifier free guidance. Default is 4.0."
"--guidance_scale", type=float, default=5.0, help="Guidance scale for classifier free guidance. Default is 5.0."
)
parser.add_argument("--prompt", type=str, default=None, help="prompt for generation")
parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt for generation, default is empty string")
@@ -508,7 +508,7 @@ def prepare_text_inputs(
prompt = args.prompt
cache_key = prompt
if cache_key in conds_cache:
embed, mask = conds_cache[cache_key]
embed, mask, embed_byt5, mask_byt5, ocr_mask = conds_cache[cache_key]
else:
move_models_to_device_if_needed()
@@ -527,7 +527,7 @@ def prepare_text_inputs(
negative_prompt = args.negative_prompt
cache_key = negative_prompt
if cache_key in conds_cache:
negative_embed, negative_mask = conds_cache[cache_key]
negative_embed, negative_mask, negative_embed_byt5, negative_mask_byt5, negative_ocr_mask = conds_cache[cache_key]
else:
move_models_to_device_if_needed()
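The conditioning cache now stores a 5-tuple per prompt, so a repeated prompt skips both the Qwen2.5-VL and byT5 encoders. A minimal sketch of the pattern, assuming a hypothetical encode_prompt helper:
conds_cache: dict[str, tuple] = {}
def get_cached_conds(prompt: str):
    if prompt in conds_cache:
        embed, mask, embed_byt5, mask_byt5, ocr_mask = conds_cache[prompt]
    else:
        # hypothetical call standing in for the VLM + byT5 encoding path
        embed, mask, embed_byt5, mask_byt5, ocr_mask = encode_prompt(prompt)
        conds_cache[prompt] = (embed, mask, embed_byt5, mask_byt5, ocr_mask)
    return embed, mask, embed_byt5, mask_byt5, ocr_mask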
@@ -614,9 +614,10 @@ def generate(
shared_models["model"] = model
else:
# use shared model
logger.info("Using shared DiT model.")
model: hunyuan_image_models.HYImageDiffusionTransformer = shared_models["model"]
# model.move_to_device_except_swap_blocks(device) # Handles block swap correctly
# model.prepare_block_swap_before_forward()
model.move_to_device_except_swap_blocks(device) # Handles block swap correctly
model.prepare_block_swap_before_forward()
return generate_body(args, model, context, context_null, device, seed)
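When the DiT is shared across generations, the two calls above are re-enabled so a model left partially offloaded by block swap is restored before the next forward pass. A rough sketch of the intent; the loader name is assumed:
def get_shared_model(shared_models, args, device):
    if "model" in shared_models:
        model = shared_models["model"]
        model.move_to_device_except_swap_blocks(device)  # everything except swappable blocks back on the device
        model.prepare_block_swap_before_forward()        # reset swap bookkeeping for a clean first step
    else:
        model = load_dit_model(args, device)  # hypothetical loader
        shared_models["model"] = model
    return model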
@@ -678,9 +679,18 @@ def generate_body(
# Denoising loop
do_cfg = args.guidance_scale != 1.0
# print(f"embed shape: {embed.shape}, mean: {embed.mean()}, std: {embed.std()}")
# print(f"embed_byt5 shape: {embed_byt5.shape}, mean: {embed_byt5.mean()}, std: {embed_byt5.std()}")
# print(f"negative_embed shape: {negative_embed.shape}, mean: {negative_embed.mean()}, std: {negative_embed.std()}")
# print(f"negative_embed_byt5 shape: {negative_embed_byt5.shape}, mean: {negative_embed_byt5.mean()}, std: {negative_embed_byt5.std()}")
# print(f"latents shape: {latents.shape}, mean: {latents.mean()}, std: {latents.std()}")
# print(f"mask shape: {mask.shape}, sum: {mask.sum()}")
# print(f"mask_byt5 shape: {mask_byt5.shape}, sum: {mask_byt5.sum()}")
# print(f"negative_mask shape: {negative_mask.shape}, sum: {negative_mask.sum()}")
# print(f"negative_mask_byt5 shape: {negative_mask_byt5.shape}, sum: {negative_mask_byt5.sum()}")
with tqdm(total=len(timesteps), desc="Denoising steps") as pbar:
for i, t in enumerate(timesteps):
t_expand = t.expand(latents.shape[0]).to(latents.dtype)
t_expand = t.expand(latents.shape[0]).to(torch.int64)
with torch.no_grad():
noise_pred = model(latents, t_expand, embed, mask, embed_byt5, mask_byt5)
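Casting t_expand to torch.int64 instead of the latents' bfloat16 dtype matters because bfloat16 keeps only 8 significand bits, so discrete timesteps near 1000 get rounded. A quick check:
import torch
print(torch.tensor(999.0).to(torch.bfloat16).item())  # 1000.0 -- 999 is not representable in bf16
print(torch.tensor(999.0).to(torch.int64).item())     # 999    -- exact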
@@ -1040,6 +1050,9 @@ def process_interactive(args: argparse.Namespace) -> None:
shared_models = load_shared_models(args)
shared_models["conds_cache"] = {} # Initialize empty cache for interactive mode
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True)
vae.eval()
print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):")
try:
@@ -1059,9 +1072,6 @@ def process_interactive(args: argparse.Namespace) -> None:
def input_line(prompt: str) -> str:
return input(prompt)
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True)
vae.eval()
try:
while True:
try:
@@ -1088,7 +1098,7 @@ def process_interactive(args: argparse.Namespace) -> None:
# Save latent and video
# returned_vae from generate will be used for decoding here.
save_output(prompt_args, vae, latent[0], device)
save_output(prompt_args, vae, latent, device)
except KeyboardInterrupt:
print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)")

View File

@@ -1,5 +1,6 @@
import argparse
import copy
import gc
from typing import Any, Optional, Union
import argparse
import os
@@ -12,7 +13,7 @@ import torch.nn as nn
from PIL import Image
from accelerate import Accelerator, PartialState
from library import hunyuan_image_models, hunyuan_image_vae, strategy_base, train_util
from library import flux_utils, hunyuan_image_models, hunyuan_image_vae, strategy_base, train_util
from library.device_utils import clean_memory_on_device, init_ipex
init_ipex()
@@ -24,7 +25,6 @@ from library import (
hunyuan_image_text_encoder,
hunyuan_image_utils,
hunyuan_image_vae,
sai_model_spec,
sd3_train_utils,
strategy_base,
strategy_hunyuan_image,
@@ -79,8 +79,6 @@ def sample_images(
dit = accelerator.unwrap_model(dit)
if text_encoders is not None:
text_encoders = [(accelerator.unwrap_model(te) if te is not None else None) for te in text_encoders]
if controlnet is not None:
controlnet = accelerator.unwrap_model(controlnet)
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = train_util.load_prompts(args.sample_prompts)
@@ -162,10 +160,10 @@ def sample_image_inference(
sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
cfg_scale = prompt_dict.get("scale", 1.0)
cfg_scale = prompt_dict.get("scale", 3.5)
seed = prompt_dict.get("seed")
prompt: str = prompt_dict.get("prompt", "")
flow_shift: float = prompt_dict.get("flow_shift", 4.0)
flow_shift: float = prompt_dict.get("flow_shift", 5.0)
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if prompt_replacement is not None:
@@ -208,11 +206,10 @@ def sample_image_inference(
text_encoder_conds = []
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prpt]
print(f"Using cached text encoder outputs for prompt: {prpt}")
# print(f"Using cached text encoder outputs for prompt: {prpt}")
if text_encoders is not None:
print(f"Encoding prompt: {prpt}")
# print(f"Encoding prompt: {prpt}")
tokens_and_masks = tokenize_strategy.tokenize(prpt)
# strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
@@ -255,16 +252,21 @@ def sample_image_inference(
from hunyuan_image_minimal_inference import generate_body # import here to avoid circular import
latents = generate_body(gen_args, dit, arg_c, arg_c_null, accelerator.device, seed)
dit_is_training = dit.training
dit.eval()
x = generate_body(gen_args, dit, arg_c, arg_c_null, accelerator.device, seed)
if dit_is_training:
dit.train()
clean_memory_on_device(accelerator.device)
# latent to image
clean_memory_on_device(accelerator.device)
org_vae_device = vae.device # will be on cpu
vae.to(accelerator.device) # distributed_state.device is same as accelerator.device
with torch.autocast(accelerator.device.type, vae.dtype, enabled=True), torch.no_grad():
x = x / hunyuan_image_vae.VAE_SCALE_FACTOR
x = vae.decode(x)
with torch.no_grad():
x = x / vae.scaling_factor
x = vae.decode(x.to(vae.device, dtype=vae.dtype))
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
x = x.clamp(-1, 1)
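The decode path now divides by the VAE's own scaling_factor (0.75289) and moves the latents to the VAE's device and dtype explicitly instead of relying on autocast. A condensed sketch of that step; the wrapper name is illustrative:
import torch
def decode_latents(vae, latents, device):
    org_device = vae.device
    vae.to(device)
    with torch.no_grad():
        images = vae.decode((latents / vae.scaling_factor).to(device, dtype=vae.dtype))
    vae.to(org_device)
    return images.clamp(-1, 1)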
@@ -299,6 +301,7 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
super().__init__()
self.sample_prompts_te_outputs = None
self.is_swapping_blocks: bool = False
self.rotary_pos_emb_cache = {}
def assert_extra_args(
self,
@@ -341,12 +344,42 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
def load_target_model(self, args, weight_dtype, accelerator):
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
# currently offload to cpu for some models
vl_dtype = torch.float8_e4m3fn if args.fp8_vl else torch.bfloat16
vl_device = "cpu"
_, text_encoder_vlm = hunyuan_image_text_encoder.load_qwen2_5_vl(
args.text_encoder, dtype=vl_dtype, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors
)
_, text_encoder_byt5 = hunyuan_image_text_encoder.load_byt5(
args.byt5, dtype=torch.float16, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors
)
vae = hunyuan_image_vae.load_vae(args.vae, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
vae.to(dtype=torch.float16) # VAE is always fp16
vae.eval()
if args.vae_enable_tiling:
vae.enable_tiling()
logger.info("VAE tiling is enabled")
model_version = hunyuan_image_utils.MODEL_VERSION_2_1
return model_version, [text_encoder_vlm, text_encoder_byt5], vae, None # unet will be loaded later
def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, list[nn.Module]]:
if args.cache_text_encoder_outputs:
logger.info("Replace text encoders with dummy models to save memory")
# This doesn't free memory, so we move text encoders to meta device in cache_text_encoder_outputs_if_needed
text_encoders = [flux_utils.dummy_clip_l() for _ in text_encoders]
clean_memory_on_device(accelerator.device)
gc.collect()
loading_dtype = None if args.fp8_scaled else weight_dtype
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
split_attn = True
attn_mode = "torch"
if args.xformers:
attn_mode = "xformers"
logger.info("xformers is enabled for attention")
model = hunyuan_image_models.load_hunyuan_image_model(
accelerator.device,
@@ -363,19 +396,7 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
vl_dtype = torch.bfloat16
vl_device = "cpu"
_, text_encoder_vlm = hunyuan_image_text_encoder.load_qwen2_5_vl(
args.text_encoder, dtype=vl_dtype, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors
)
_, text_encoder_byt5 = hunyuan_image_text_encoder.load_byt5(
args.byt5, dtype=torch.float16, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors
)
vae = hunyuan_image_vae.load_vae(args.vae, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
model_version = hunyuan_image_utils.MODEL_VERSION_2_1
return model_version, [text_encoder_vlm, text_encoder_byt5], vae, model
return model, text_encoders
def get_tokenize_strategy(self, args):
return strategy_hunyuan_image.HunyuanImageTokenizeStrategy(args.tokenizer_cache_dir)
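load_target_model now returns None for the DiT so that text encoder outputs can be cached before the transformer is materialized; load_unet_lazily then swaps the encoders for dummies (or meta-device stubs) and loads the DiT. A sketch of the call order, condensed from the surrounding trainer code; variable names are assumptions:
def load_models_lazily(trainer, args, accelerator, weight_dtype, dataset_group):
    model_version, text_encoders, vae, unet = trainer.load_target_model(args, weight_dtype, accelerator)
    # unet is None here: VLM/byT5 outputs are cached while the DiT is not yet in memory
    trainer.cache_text_encoder_outputs_if_needed(
        args, accelerator, unet, vae, text_encoders, dataset_group, weight_dtype
    )
    if unet is None:
        # encoders may be replaced with dummy models / moved to meta inside this call
        unet, text_encoders = trainer.load_unet_lazily(args, weight_dtype, accelerator, text_encoders)
    return model_version, text_encoders, vae, unet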
@@ -404,7 +425,6 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
# if the text encoders is trained, we need tokenization, so is_partial is True
return strategy_hunyuan_image.HunyuanImageTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
)
@@ -417,11 +437,9 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
if args.cache_text_encoder_outputs:
if not args.lowram:
# メモリ消費を減らす
logger.info("move vae and unet to cpu to save memory")
logger.info("move vae to cpu to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
clean_memory_on_device(accelerator.device)
logger.info("move text encoders to gpu")
@@ -457,17 +475,14 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
accelerator.wait_for_everyone()
# move back to cpu
logger.info("move VLM back to cpu")
text_encoders[0].to("cpu")
logger.info("move byT5 back to cpu")
text_encoders[1].to("cpu")
# text encoders are not needed for training, so we move to meta device
logger.info("move text encoders to meta device to save memory")
text_encoders = [te.to("meta") for te in text_encoders]
clean_memory_on_device(accelerator.device)
if not args.lowram:
logger.info("move vae and unet back to original device")
logger.info("move vae back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device)
@@ -477,21 +492,19 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
)
sample_images(accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs)
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
return noise_scheduler
def encode_images_to_latents(self, args, vae, images):
return vae.encode(images)
def encode_images_to_latents(self, args, vae: hunyuan_image_vae.HunyuanVAE2D, images):
return vae.encode(images).sample()
def shift_scale_latents(self, args, latents):
# for encoding, we need to scale the latents
return latents * hunyuan_image_vae.VAE_SCALE_FACTOR
return latents * hunyuan_image_vae.LATENT_SCALING_FACTOR
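Training latents follow the same convention as inference: the VAE posterior is sampled, then multiplied by LATENT_SCALING_FACTOR (0.75289), the inverse of the division applied before decoding. A small sketch with names from the diff; the wrapper is illustrative:
from library import hunyuan_image_vae
def images_to_scaled_latents(vae, images):
    latents = vae.encode(images).sample()                     # sample the VAE posterior
    return latents * hunyuan_image_vae.LATENT_SCALING_FACTOR  # 0.75289, undone at decode time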
def get_noise_pred_and_target(
self,
@@ -509,12 +522,16 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
noisy_model_input, _, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
)
# bfloat16 is too low precision for 0-1000 TODO fix get_noisy_model_input_and_timesteps
timesteps = (sigmas[:, 0, 0, 0] * 1000).to(torch.int64)
# print(
# f"timestep: {timesteps}, noisy_model_input shape: {noisy_model_input.shape}, mean: {noisy_model_input.mean()}, std: {noisy_model_input.std()}"
# )
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
@@ -526,31 +543,33 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
# ocr_mask is for inference only, so it is not used here
vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask = text_encoder_conds
# print(f"embed shape: {vlm_embed.shape}, mean: {vlm_embed.mean()}, std: {vlm_embed.std()}")
# print(f"embed_byt5 shape: {byt5_embed.shape}, mean: {byt5_embed.mean()}, std: {byt5_embed.std()}")
# print(f"latents shape: {latents.shape}, mean: {latents.mean()}, std: {latents.std()}")
# print(f"mask shape: {vlm_mask.shape}, sum: {vlm_mask.sum()}")
# print(f"mask_byt5 shape: {byt5_mask.shape}, sum: {byt5_mask.sum()}")
with torch.set_grad_enabled(is_train), accelerator.autocast():
model_pred = unet(noisy_model_input, timesteps / 1000, vlm_embed, vlm_mask, byt5_embed, byt5_mask)
model_pred = unet(
noisy_model_input, timesteps, vlm_embed, vlm_mask, byt5_embed, byt5_mask # , self.rotary_pos_emb_cache
)
# model prediction and weighting is omitted for HunyuanImage-2.1 currently
# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
# flow matching loss
target = noise - latents
# differential output preservation is not used for HunyuanImage-2.1 currently
return model_pred, target, timesteps, None
return model_pred, target, timesteps, weighting
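For flow matching the target is simply noise - latents (the velocity toward noise), and integer timesteps are recovered from the sigmas to avoid the bfloat16 rounding noted above. A hedged restatement; the linear interpolation form is assumed from flux_train_utils and is not shown in this diff:
import torch
def flow_matching_pieces(latents, noise, sigmas):
    noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise  # assumed interpolation form
    timesteps = (sigmas[:, 0, 0, 0] * 1000).to(torch.int64)        # exact integer timesteps
    target = noise - latents                                       # velocity the DiT should predict
    return noisy_model_input, timesteps, target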
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
def get_sai_model_spec(self, args):
# if self.model_type != "chroma":
# model_description = "schnell" if self.is_schnell else "dev"
# else:
# model_description = "chroma"
# return train_util.get_sai_model_spec(None, args, False, True, False, flux=model_description)
train_util.get_sai_model_spec_dataclass(None, args, False, True, False, hunyuan_image="2.1")
return train_util.get_sai_model_spec_dataclass(None, args, False, True, False, hunyuan_image="2.1").to_metadata_dict()
def update_metadata(self, metadata, args):
metadata["ss_model_type"] = args.model_type
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
@@ -569,6 +588,9 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
def cast_text_encoder(self):
return False # VLM is bf16, byT5 is fp16, so do not cast to other dtype
def cast_vae(self):
return False # VAE is fp16, so do not cast to other dtype
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
# fp8 text encoder for HunyuanImage-2.1 is not supported currently
pass
@@ -597,6 +619,17 @@ def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
parser.add_argument(
"--text_encoder",
type=str,
help="path to Qwen2.5-VL (*.sft or *.safetensors), should be bfloat16 / Qwen2.5-VLのパス*.sftまたは*.safetensors、bfloat16が前提",
)
parser.add_argument(
"--byt5",
type=str,
help="path to byt5 (*.sft or *.safetensors), should be float16 / byt5のパス*.sftまたは*.safetensors、float16が前提",
)
parser.add_argument(
"--timestep_sampling",
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
@@ -613,17 +646,24 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--model_prediction_type",
choices=["raw", "additive", "sigma_scaled"],
default="sigma_scaled",
default="raw",
help="How to interpret and process the model prediction: "
"raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
"raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling). Default is raw unlike FLUX.1."
" / モデル予測の解釈と処理方法:"
"rawそのまま使用、additiveイズ入力に加算、sigma_scaledシグマスケーリングを適用",
"rawそのまま使用、additiveイズ入力に加算、sigma_scaledシグマスケーリングを適用デフォルトはFLUX.1とは異なりrawです。",
)
parser.add_argument(
"--discrete_flow_shift",
type=float,
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
default=5.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 5.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは5.0。",
)
parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
parser.add_argument("--fp8_vl", action="store_true", help="use fp8 for VLM text encoder / VLMテキストエンコーダにfp8を使用する")
parser.add_argument(
"--vae_enable_tiling",
action="store_true",
help="Enable tiling for VAE decoding and encoding / VAEデコーディングとエンコーディングのタイルを有効にする",
)
return parser

View File

@@ -1,9 +1,19 @@
import torch
from typing import Optional
from typing import Optional, Union
try:
import xformers.ops as xops
except ImportError:
xops = None
def attention(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_lens: list[int], attn_mode: str = "torch", drop_rate: float = 0.0
qkv_or_q: Union[torch.Tensor, list],
k: Optional[torch.Tensor] = None,
v: Optional[torch.Tensor] = None,
seq_lens: Optional[list[int]] = None,
attn_mode: str = "torch",
drop_rate: float = 0.0,
) -> torch.Tensor:
"""
Compute scaled dot-product attention with variable sequence lengths.
@@ -12,7 +22,7 @@ def attention(
processing each sequence individually.
Args:
q: Query tensor [B, L, H, D].
qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors.
k: Key tensor [B, L, H, D].
v: Value tensor [B, L, H, D].
seq_lens: Valid sequence length for each batch element.
@@ -22,6 +32,17 @@ def attention(
Returns:
Attention output tensor [B, L, H*D].
"""
if isinstance(qkv_or_q, list):
q, k, v = qkv_or_q
qkv_or_q.clear()
del qkv_or_q
else:
q = qkv_or_q
del qkv_or_q
assert k is not None and v is not None, "k and v must be provided if qkv_or_q is a tensor"
if seq_lens is None:
seq_lens = [q.shape[1]] * q.shape[0]
# Determine tensor layout based on attention implementation
if attn_mode == "torch" or attn_mode == "sageattn":
transpose_fn = lambda x: x.transpose(1, 2) # [B, H, L, D] for SDPA
@@ -29,6 +50,7 @@ def attention(
transpose_fn = lambda x: x # [B, L, H, D] for other implementations
# Process each batch element with its valid sequence length
q_seq_len = q.shape[1]
q = [transpose_fn(q[i : i + 1, : seq_lens[i]]) for i in range(len(q))]
k = [transpose_fn(k[i : i + 1, : seq_lens[i]]) for i in range(len(k))]
v = [transpose_fn(v[i : i + 1, : seq_lens[i]]) for i in range(len(v))]
@@ -40,10 +62,24 @@ def attention(
q[i] = None
k[i] = None
v[i] = None
x.append(x_i)
x.append(torch.nn.functional.pad(x_i, (0, 0, 0, q_seq_len - x_i.shape[2]), value=0)) # Pad to max seq len, B, H, L, D
x = torch.cat(x, dim=0)
del q, k, v
# Currently only PyTorch SDPA is implemented
elif attn_mode == "xformers":
x = []
for i in range(len(q)):
x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate)
q[i] = None
k[i] = None
v[i] = None
x.append(torch.nn.functional.pad(x_i, (0, 0, 0, 0, 0, q_seq_len - x_i.shape[1]), value=0)) # B, L, H, D
x = torch.cat(x, dim=0)
del q, k, v
else:
# Currently only PyTorch SDPA and xformers are implemented
raise ValueError(f"Unsupported attention mode: {attn_mode}")
x = transpose_fn(x) # [B, L, H, D]
x = x.reshape(x.shape[0], x.shape[1], -1) # [B, L, H*D]
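attention() now also accepts a single list [q, k, v]; the list is cleared inside the function so the caller's references are released early, and each per-sample output is padded back to the full query length. Illustrative call patterns (sizes arbitrary; attention is the function defined above):
import torch
B, L, H, D = 2, 16, 4, 8
q, k, v = (torch.randn(B, L, H, D) for _ in range(3))
out = attention(q, k, v, seq_lens=[16, 12])  # original three-tensor call still works
qkv = [q, k, v]
del q, k, v                                  # the list is now the only owner of the tensors
out2 = attention(qkv, seq_lens=[16, 12])     # list form: attention() clears qkv internally
print(out.shape, out2.shape)                 # torch.Size([2, 16, 32]) for both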

View File

@@ -30,11 +30,7 @@ from library.hunyuan_image_modules import (
from library.hunyuan_image_utils import get_nd_rotary_pos_embed
FP8_OPTIMIZATION_TARGET_KEYS = ["double_blocks", "single_blocks"]
FP8_OPTIMIZATION_EXCLUDE_KEYS = [
"norm",
"_mod",
"modulation",
]
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["norm", "_mod", "modulation", "_emb"]
# region DiT Model
@@ -142,6 +138,14 @@ class HYImageDiffusionTransformer(nn.Module):
self.num_double_blocks = len(self.double_blocks)
self.num_single_blocks = len(self.single_blocks)
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
def enable_gradient_checkpointing(self, cpu_offload: bool = False):
self.gradient_checkpointing = True
self.cpu_offload_checkpointing = cpu_offload
@@ -273,6 +277,7 @@ class HYImageDiffusionTransformer(nn.Module):
encoder_attention_mask: torch.Tensor,
byt5_text_states: Optional[torch.Tensor] = None,
byt5_text_mask: Optional[torch.Tensor] = None,
rotary_pos_emb_cache: Optional[Dict[Tuple[int, int], Tuple[torch.Tensor, torch.Tensor]]] = None,
) -> torch.Tensor:
"""
Forward pass through the HunyuanImage diffusion transformer.
@@ -296,7 +301,15 @@ class HYImageDiffusionTransformer(nn.Module):
# Calculate spatial dimensions for rotary position embeddings
_, _, oh, ow = x.shape
th, tw = oh, ow # Height and width (patch_size=[1,1] means no spatial downsampling)
freqs_cis = self.get_rotary_pos_embed((th, tw))
if rotary_pos_emb_cache is not None:
if (th, tw) in rotary_pos_emb_cache:
freqs_cis = rotary_pos_emb_cache[(th, tw)]
freqs_cis = (freqs_cis[0].to(img.device), freqs_cis[1].to(img.device))
else:
freqs_cis = self.get_rotary_pos_embed((th, tw))
rotary_pos_emb_cache[(th, tw)] = (freqs_cis[0].cpu(), freqs_cis[1].cpu())
else:
freqs_cis = self.get_rotary_pos_embed((th, tw))
# Reshape image latents to sequence format: [B, C, H, W] -> [B, H*W, C]
img = self.img_in(img)
@@ -349,9 +362,11 @@ class HYImageDiffusionTransformer(nn.Module):
vec = vec.to(input_device)
img = x[:, :img_seq_len, ...]
del x
# Apply final projection to output space
img = self.final_layer(img, vec)
del vec
# Reshape from sequence to spatial format: [B, L, C] -> [B, C, H, W]
img = self.unpatchify_2d(img, th, tw)
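The optional rotary_pos_emb_cache above avoids recomputing the RoPE tables every step: entries are keyed by the latent (height, width), kept on CPU, and moved to the image device on a hit. The same pattern in isolation, with compute_rope standing in for self.get_rotary_pos_embed:
def cached_rope(cache, th, tw, device, compute_rope):
    if cache is not None and (th, tw) in cache:
        f0, f1 = cache[(th, tw)]                # the two tensors of the frequency pair
        return f0.to(device), f1.to(device)
    f0, f1 = compute_rope((th, tw))
    if cache is not None:
        cache[(th, tw)] = (f0.cpu(), f1.cpu())  # store on CPU to keep VRAM free
    return f0, f1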

View File

@@ -50,7 +50,7 @@ class ByT5Mapper(nn.Module):
Returns:
Transformed embeddings [..., out_dim1].
"""
residual = x
residual = x if self.use_residual else None
x = self.layernorm(x)
x = self.fc1(x)
x = self.act_fn(x)
@@ -411,6 +411,7 @@ class SingleTokenRefiner(nn.Module):
context_aware_representations = self.c_embedder(context_aware_representations)
c = timestep_aware_representations + context_aware_representations
del timestep_aware_representations, context_aware_representations
x = self.input_embedder(x)
x = self.individual_token_refiner(x, c, txt_lens)
return x
@@ -447,6 +448,7 @@ class FinalLayer(nn.Module):
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift=shift, scale=scale)
del shift, scale, c
x = self.linear(x)
return x
@@ -494,6 +496,7 @@ class RMSNorm(nn.Module):
Normalized and scaled tensor.
"""
output = self._norm(x.float()).type_as(x)
del x
output = output * self.weight
return output
@@ -634,8 +637,10 @@ class MMDoubleStreamBlock(nn.Module):
# Process image stream for attention
img_modulated = self.img_norm1(img)
img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
del img_mod1_shift, img_mod1_scale
img_qkv = self.img_attn_qkv(img_modulated)
del img_modulated
img_q, img_k, img_v = img_qkv.chunk(3, dim=-1)
del img_qkv
@@ -649,17 +654,15 @@ class MMDoubleStreamBlock(nn.Module):
# Apply rotary position embeddings to image tokens
if freqs_cis is not None:
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"RoPE output shape mismatch: got {img_qq.shape}, {img_kk.shape}, expected {img_q.shape}, {img_k.shape}"
img_q, img_k = img_qq, img_kk
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
del freqs_cis
# Process text stream for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
del txt_modulated
txt_q, txt_k, txt_v = txt_qkv.chunk(3, dim=-1)
del txt_qkv
@@ -672,31 +675,44 @@ class MMDoubleStreamBlock(nn.Module):
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
# Concatenate image and text tokens for joint attention
img_seq_len = img.shape[1]
q = torch.cat([img_q, txt_q], dim=1)
del img_q, txt_q
k = torch.cat([img_k, txt_k], dim=1)
del img_k, txt_k
v = torch.cat([img_v, txt_v], dim=1)
attn = attention(q, k, v, seq_lens=seq_lens, attn_mode=self.attn_mode)
del img_v, txt_v
qkv = [q, k, v]
del q, k, v
attn = attention(qkv, seq_lens=seq_lens, attn_mode=self.attn_mode)
del qkv
# Split attention outputs back to separate streams
img_attn, txt_attn = (attn[:, : img_q.shape[1]].contiguous(), attn[:, img_q.shape[1] :].contiguous())
img_attn, txt_attn = (attn[:, : img_seq_len].contiguous(), attn[:, img_seq_len :].contiguous())
del attn
# Apply attention projection and residual connection for image stream
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
del img_attn, img_mod1_gate
# Apply MLP and residual connection for image stream
img = img + apply_gate(
self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
gate=img_mod2_gate,
)
del img_mod2_shift, img_mod2_scale, img_mod2_gate
# Apply attention projection and residual connection for text stream
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
del txt_attn, txt_mod1_gate
# Apply MLP and residual connection for text stream
txt = txt + apply_gate(
self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
gate=txt_mod2_gate,
)
del txt_mod2_shift, txt_mod2_scale, txt_mod2_gate
return img, txt
@@ -797,6 +813,7 @@ class MMSingleStreamBlock(nn.Module):
# Compute Q, K, V, and MLP input
qkv_mlp = self.linear1(x_mod)
del x_mod
q, k, v, mlp = qkv_mlp.split([self.hidden_size, self.hidden_size, self.hidden_size, self.mlp_hidden_dim], dim=-1)
del qkv_mlp
@@ -810,27 +827,34 @@ class MMSingleStreamBlock(nn.Module):
# Separate image and text tokens
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
del q
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :]
del k
# img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :]
# del v
# Apply rotary position embeddings only to image tokens
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"RoPE output shape mismatch: got {img_qq.shape}, {img_kk.shape}, expected {img_q.shape}, {img_k.shape}"
img_q, img_k = img_qq, img_kk
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
del freqs_cis
# Recombine and compute joint attention
q = torch.cat([img_q, txt_q], dim=1)
del img_q, txt_q
k = torch.cat([img_k, txt_k], dim=1)
v = torch.cat([img_v, txt_v], dim=1)
attn = attention(q, k, v, seq_lens=seq_lens, attn_mode=self.attn_mode)
del img_k, txt_k
# v = torch.cat([img_v, txt_v], dim=1)
# del img_v, txt_v
qkv = [q, k, v]
del q, k, v
attn = attention(qkv, seq_lens=seq_lens, attn_mode=self.attn_mode)
del qkv
# Combine attention and MLP outputs, apply gating
# output = self.linear2(attn, self.mlp_act(mlp))
mlp = self.mlp_act(mlp)
output = torch.cat([attn, mlp], dim=2).contiguous()
del attn, mlp
output = self.linear2(output)
return x + apply_gate(output, gate=mod_gate)

View File

@@ -598,7 +598,7 @@ def get_byt5_prompt_embeds_from_tokens(
) -> Tuple[list[bool], torch.Tensor, torch.Tensor]:
byt5_max_length = BYT5_MAX_LENGTH
if byt5_text_ids is None or byt5_text_mask is None:
if byt5_text_ids is None or byt5_text_mask is None or byt5_text_mask.sum() == 0:
return (
[False],
torch.zeros((1, byt5_max_length, 1472), device=text_encoder.device),

View File

@@ -17,6 +17,8 @@ logger = logging.getLogger(__name__)
VAE_SCALE_FACTOR = 32 # 32x spatial compression
LATENT_SCALING_FACTOR = 0.75289 # Latent scaling factor for Hunyuan Image-2.1
def swish(x: Tensor) -> Tensor:
"""Swish activation function: x * sigmoid(x)."""
@@ -378,7 +380,7 @@ class HunyuanVAE2D(nn.Module):
layers_per_block = 2
ffactor_spatial = 32 # 32x spatial compression
sample_size = 384 # Minimum sample size for tiling
scaling_factor = 0.75289 # Latent scaling factor
scaling_factor = LATENT_SCALING_FACTOR # 0.75289 # Latent scaling factor
self.ffactor_spatial = ffactor_spatial
self.scaling_factor = scaling_factor

View File

@@ -21,14 +21,27 @@ class HunyuanImageTokenizeStrategy(TokenizeStrategy):
Qwen2Tokenizer, hunyuan_image_text_encoder.QWEN_2_5_VL_IMAGE_ID, tokenizer_cache_dir=tokenizer_cache_dir
)
self.byt5_tokenizer = self._load_tokenizer(
AutoTokenizer, hunyuan_image_text_encoder.BYT5_TOKENIZER_PATH, tokenizer_cache_dir=tokenizer_cache_dir
AutoTokenizer, hunyuan_image_text_encoder.BYT5_TOKENIZER_PATH, subfolder="", tokenizer_cache_dir=tokenizer_cache_dir
)
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
vlm_tokens, vlm_mask = hunyuan_image_text_encoder.get_qwen_tokens(self.vlm_tokenizer, text)
byt5_tokens, byt5_mask = hunyuan_image_text_encoder.get_byt5_text_tokens(self.byt5_tokenizer, text)
# byt5_tokens, byt5_mask = hunyuan_image_text_encoder.get_byt5_text_tokens(self.byt5_tokenizer, text)
byt5_tokens = []
byt5_mask = []
for t in text:
tokens, mask = hunyuan_image_text_encoder.get_byt5_text_tokens(self.byt5_tokenizer, t)
if tokens is None:
tokens = torch.zeros((1, 1), dtype=torch.long)
mask = torch.zeros((1, 1), dtype=torch.long)
byt5_tokens.append(tokens)
byt5_mask.append(mask)
max_len = max([m.shape[1] for m in byt5_mask])
byt5_tokens = torch.cat([torch.nn.functional.pad(t, (0, max_len - t.shape[1]), value=0) for t in byt5_tokens], dim=0)
byt5_mask = torch.cat([torch.nn.functional.pad(m, (0, max_len - m.shape[1]), value=0) for m in byt5_mask], dim=0)
return [vlm_tokens, vlm_mask, byt5_tokens, byt5_mask]
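Illustrative use of the per-sample byT5 tokenization above (the cache-dir argument and the prompts are assumptions): prompts for which get_byt5_text_tokens finds nothing to render receive a single zero token and mask, and every byT5 tensor is right-padded to the longest sample in the batch.
from library.strategy_hunyuan_image import HunyuanImageTokenizeStrategy
strategy = HunyuanImageTokenizeStrategy(None)  # None: default tokenizer cache location
vlm_tokens, vlm_mask, byt5_tokens, byt5_mask = strategy.tokenize(
    ["a photo of a cat", 'a storefront sign that reads "OPEN"']
)
# if no renderable text is detected for the first prompt, byt5_mask[0].sum() == 0,
# and byt5_tokens.shape[1] matches the longer (padded) sample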
@@ -46,11 +59,24 @@ class HunyuanImageTextEncodingStrategy(TextEncodingStrategy):
# autocast and no_grad are handled in hunyuan_image_text_encoder
vlm_embed, vlm_mask = hunyuan_image_text_encoder.get_qwen_prompt_embeds_from_tokens(qwen2vlm, vlm_tokens, vlm_mask)
ocr_mask, byt5_embed, byt5_mask = hunyuan_image_text_encoder.get_byt5_prompt_embeds_from_tokens(
byt5, byt5_tokens, byt5_mask
)
return [vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask]
# ocr_mask, byt5_embed, byt5_mask = hunyuan_image_text_encoder.get_byt5_prompt_embeds_from_tokens(
# byt5, byt5_tokens, byt5_mask
# )
ocr_mask, byt5_embed, byt5_updated_mask = [], [], []
for i in range(byt5_tokens.shape[0]):
ocr_m, byt5_e, byt5_m = hunyuan_image_text_encoder.get_byt5_prompt_embeds_from_tokens(
byt5, byt5_tokens[i : i + 1], byt5_mask[i : i + 1]
)
ocr_mask.append(torch.zeros((1,), dtype=torch.long) + (1 if ocr_m[0] else 0)) # 1 or 0
byt5_embed.append(byt5_e)
byt5_updated_mask.append(byt5_m)
ocr_mask = torch.cat(ocr_mask, dim=0).to(torch.bool) # [B]
byt5_embed = torch.cat(byt5_embed, dim=0)
byt5_updated_mask = torch.cat(byt5_updated_mask, dim=0)
return [vlm_embed, vlm_mask, byt5_embed, byt5_updated_mask, ocr_mask]
class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
@@ -110,7 +136,6 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask = huyuan_image_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks
)
@@ -124,7 +149,7 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr
vlm_mask = vlm_mask.cpu().numpy()
byt5_embed = byt5_embed.cpu().numpy()
byt5_mask = byt5_mask.cpu().numpy()
ocr_mask = np.array(ocr_mask, dtype=bool)
ocr_mask = ocr_mask.cpu().numpy()
for i, info in enumerate(infos):
vlm_embed_i = vlm_embed[i]
@@ -175,7 +200,13 @@ class HunyuanImageLatentsCachingStrategy(LatentsCachingStrategy):
def cache_batch_latents(
self, vae: hunyuan_image_vae.HunyuanVAE2D, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).sample()
# encode_by_vae = lambda img_tensor: vae.encode(img_tensor).sample()
def encode_by_vae(img_tensor):
# no_grad is handled in _default_cache_batch_latents
nonlocal vae
with torch.autocast(device_type=vae.device.type, dtype=vae.dtype):
return vae.encode(img_tensor).sample()
vae_device = vae.device
vae_dtype = vae.dtype

View File

@@ -1744,7 +1744,39 @@ class BaseDataset(torch.utils.data.Dataset):
# [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)]
if len(tensors_list) == 0 or tensors_list[0] == None or len(tensors_list[0]) == 0 or tensors_list[0][0] is None:
return None
return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))]
# old implementation without padding: all elements must have same length
# return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))]
# new implementation with padding support
result = []
for i in range(len(tensors_list[0])):
tensors = [x[i] for x in tensors_list]
if tensors[0].ndim == 0:
# scalar value: e.g. ocr mask
result.append(torch.stack([converter(x[i]) for x in tensors_list]))
continue
min_len = min([len(x) for x in tensors])
max_len = max([len(x) for x in tensors])
if min_len == max_len:
# no padding
result.append(torch.stack([converter(x) for x in tensors]))
else:
# padding
tensors = [converter(x) for x in tensors]
if tensors[0].ndim == 1:
# input_ids or mask
result.append(
torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors])
)
else:
# text encoder outputs
result.append(
torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors])
)
return result
# set example
example = {}
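A tiny illustration of the padding rule above: two cached text encoder outputs of different lengths ([L, C] each) are right-padded along L to the batch maximum before stacking, while scalar entries such as the OCR flag are still just stacked.
import torch
a, b = torch.ones(3, 4), torch.ones(5, 4)
max_len = max(a.shape[0], b.shape[0])
padded = torch.stack([torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0])) for x in (a, b)])
print(padded.shape)         # torch.Size([2, 5, 4])
print(padded[0, 3:].sum())  # tensor(0.) -- the shorter sample is zero-padded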

View File

@@ -191,9 +191,8 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
class HunyuanImageLoRANetwork(lora_flux.LoRANetwork):
# FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"]
FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"]
FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"]
TARGET_REPLACE_MODULE_DOUBLE = ["MMDoubleStreamBlock"]
TARGET_REPLACE_MODULE_SINGLE = ["MMSingleStreamBlock"]
LORA_PREFIX_HUNYUAN_IMAGE_DIT = "lora_unet" # make ComfyUI compatible
@classmethod
@@ -222,7 +221,7 @@ class HunyuanImageLoRANetwork(lora_flux.LoRANetwork):
reg_lrs: Optional[Dict[str, float]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
nn.Module.__init__(self)
self.multiplier = multiplier
self.lora_dim = lora_dim
@@ -259,8 +258,6 @@ class HunyuanImageLoRANetwork(lora_flux.LoRANetwork):
if self.split_qkv:
logger.info(f"split qkv for LoRA")
if self.train_blocks is not None:
logger.info(f"train {self.train_blocks} blocks only")
# create module instances
def create_modules(
@@ -354,14 +351,14 @@ class HunyuanImageLoRANetwork(lora_flux.LoRANetwork):
# create LoRA for U-Net
target_replace_modules = (
HunyuanImageLoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + HunyuanImageLoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE
HunyuanImageLoRANetwork.TARGET_REPLACE_MODULE_DOUBLE + HunyuanImageLoRANetwork.TARGET_REPLACE_MODULE_SINGLE
)
self.unet_loras: List[Union[lora_flux.LoRAModule, lora_flux.LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules)
self.text_encoder_loras = []
logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.")
logger.info(f"create LoRA for HunyuanImage-2.1: {len(self.unet_loras)} modules.")
if verbose:
for lora in self.unet_loras:
logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")

View File

@@ -1,3 +1,4 @@
import gc
import importlib
import argparse
import math
@@ -10,11 +11,11 @@ import time
import json
from multiprocessing import Value
import numpy as np
import toml
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.types import Number
from library.device_utils import init_ipex, clean_memory_on_device
@@ -175,7 +176,7 @@ class NetworkTrainer:
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(64)
def load_target_model(self, args, weight_dtype, accelerator) -> tuple:
def load_target_model(self, args, weight_dtype, accelerator) -> tuple[str, nn.Module, nn.Module, Optional[nn.Module]]:
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
# モデルに xformers とか memory efficient attention を組み込む
@@ -185,6 +186,9 @@ class NetworkTrainer:
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, List[nn.Module]]:
raise NotImplementedError()
def get_tokenize_strategy(self, args):
return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
@@ -476,8 +480,11 @@ class NetworkTrainer:
return loss.mean()
def cast_text_encoder(self):
return True # default for other than HunyuanImage
return True # default for other than HunyuanImage
def cast_vae(self):
return True # default for other than HunyuanImage
def train(self, args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
@@ -586,37 +593,18 @@ class NetworkTrainer:
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
vae_dtype = (torch.float32 if args.no_half_vae else weight_dtype) if self.cast_vae() else None
# モデルを読み込む
# load target models: unet may be None for lazy loading
model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
if vae_dtype is None:
vae_dtype = vae.dtype
logger.info(f"vae_dtype is set to {vae_dtype} by the model since cast_vae() is false")
# text_encoder is List[CLIPTextModel] or CLIPTextModel
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
# 差分追加学習のためにモデルを読み込む
sys.path.append(os.path.dirname(__file__))
accelerator.print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
if args.base_weights is not None:
# base_weights が指定されている場合は、指定された重みを読み込みマージする
for i, weight_path in enumerate(args.base_weights):
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
multiplier = 1.0
else:
multiplier = args.base_weights_multiplier[i]
accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
module, weights_sd = network_module.create_network_from_weights(
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
)
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
# 学習を準備する
# prepare dataset for latents caching if needed
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
@@ -643,6 +631,32 @@ class NetworkTrainer:
if val_dataset_group is not None:
self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype)
if unet is None:
# lazy load unet if needed. text encoders may be freed or replaced with dummy models for saving memory
unet, text_encoders = self.load_unet_lazily(args, weight_dtype, accelerator, text_encoders)
# 差分追加学習のためにモデルを読み込む
sys.path.append(os.path.dirname(__file__))
accelerator.print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
if args.base_weights is not None:
# base_weights が指定されている場合は、指定された重みを読み込みマージする
for i, weight_path in enumerate(args.base_weights):
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
multiplier = 1.0
else:
multiplier = args.base_weights_multiplier[i]
accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
module, weights_sd = network_module.create_network_from_weights(
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
)
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
# prepare network
net_kwargs = {}
if args.network_args is not None:
@@ -672,7 +686,7 @@ class NetworkTrainer:
return
network_has_multiplier = hasattr(network, "set_multiplier")
# TODO remove `hasattr`s by setting up methods if not defined in the network like (hacky but works):
# TODO remove `hasattr` by setting up methods if not defined in the network like below (hacky but will work):
# if not hasattr(network, "prepare_network"):
# network.prepare_network = lambda args: None
@@ -1305,6 +1319,8 @@ class NetworkTrainer:
del t_enc
text_encoders = []
text_encoder = None
gc.collect()
clean_memory_on_device(accelerator.device)
# For --sample_at_first
optimizer_eval_fn()