mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-10 15:00:23 +00:00)
feat: HunyuanImage LoRA training
@@ -66,7 +66,7 @@ def parse_args() -> argparse.Namespace:

     # inference
     parser.add_argument(
-        "--guidance_scale", type=float, default=4.0, help="Guidance scale for classifier free guidance. Default is 4.0."
+        "--guidance_scale", type=float, default=5.0, help="Guidance scale for classifier free guidance. Default is 5.0."
     )
     parser.add_argument("--prompt", type=str, default=None, help="prompt for generation")
     parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt for generation, default is empty string")
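Note: the default classifier-free guidance scale moves from 4.0 to 5.0. As a reminder of what the flag controls, the conventional CFG combination looks like the sketch below; this is the textbook formula, not necessarily the exact code path in this script.

    import torch

    def apply_cfg(noise_cond: torch.Tensor, noise_uncond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
        # Classifier-free guidance: start from the unconditional prediction and push
        # toward the conditional one; guidance_scale == 1.0 disables the extra push.
        return noise_uncond + guidance_scale * (noise_cond - noise_uncond)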
@@ -508,7 +508,7 @@ def prepare_text_inputs(
     prompt = args.prompt
     cache_key = prompt
     if cache_key in conds_cache:
-        embed, mask = conds_cache[cache_key]
+        embed, mask, embed_byt5, mask_byt5, ocr_mask = conds_cache[cache_key]
     else:
         move_models_to_device_if_needed()

@@ -527,7 +527,7 @@ def prepare_text_inputs(
     negative_prompt = args.negative_prompt
     cache_key = negative_prompt
     if cache_key in conds_cache:
-        negative_embed, negative_mask = conds_cache[cache_key]
+        negative_embed, negative_mask, negative_embed_byt5, negative_mask_byt5, negative_ocr_mask = conds_cache[cache_key]
     else:
         move_models_to_device_if_needed()

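Both hunks above widen the per-prompt cache entry from (embed, mask) to five items, adding the byT5 embedding/mask and the OCR mask. A minimal sketch of the lookup pattern, with encode_prompt as a hypothetical stand-in for the real text-encoder call:

    def get_conds(prompt, conds_cache, encode_prompt):
        # Cache is keyed by the raw prompt string; after this change each entry is a 5-tuple.
        if prompt in conds_cache:
            embed, mask, embed_byt5, mask_byt5, ocr_mask = conds_cache[prompt]
        else:
            embed, mask, embed_byt5, mask_byt5, ocr_mask = encode_prompt(prompt)
            conds_cache[prompt] = (embed, mask, embed_byt5, mask_byt5, ocr_mask)
        return embed, mask, embed_byt5, mask_byt5, ocr_mask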
@@ -614,9 +614,10 @@ def generate(
         shared_models["model"] = model
     else:
         # use shared model
         logger.info("Using shared DiT model.")
         model: hunyuan_image_models.HYImageDiffusionTransformer = shared_models["model"]
-        # model.move_to_device_except_swap_blocks(device)  # Handles block swap correctly
-        # model.prepare_block_swap_before_forward()
+        model.move_to_device_except_swap_blocks(device)  # Handles block swap correctly
+        model.prepare_block_swap_before_forward()

     return generate_body(args, model, context, context_null, device, seed)

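When generate is called repeatedly (batch or interactive mode), the DiT is stored in shared_models on the first call and reused afterwards; the re-enabled lines make sure a reused model is moved back to the device (except swapped blocks) and its block-swap state is reset before the next forward pass. A rough sketch of that reuse pattern, assuming a model object exposing the two methods shown in the diff and a hypothetical load_model callable:

    def get_model(shared_models, load_model, device):
        if "model" not in shared_models:
            model = load_model()                # first call: load and share
            shared_models["model"] = model
        else:
            model = shared_models["model"]      # later calls: reuse the shared model
            # A reused model may still have weights offloaded from the previous run,
            # so bring it back (except the blocks managed by block swap) and reset
            # the swap state before the next forward pass.
            model.move_to_device_except_swap_blocks(device)
            model.prepare_block_swap_before_forward()
        return model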
@@ -678,9 +679,18 @@ def generate_body(

     # Denoising loop
     do_cfg = args.guidance_scale != 1.0
+    # print(f"embed shape: {embed.shape}, mean: {embed.mean()}, std: {embed.std()}")
+    # print(f"embed_byt5 shape: {embed_byt5.shape}, mean: {embed_byt5.mean()}, std: {embed_byt5.std()}")
+    # print(f"negative_embed shape: {negative_embed.shape}, mean: {negative_embed.mean()}, std: {negative_embed.std()}")
+    # print(f"negative_embed_byt5 shape: {negative_embed_byt5.shape}, mean: {negative_embed_byt5.mean()}, std: {negative_embed_byt5.std()}")
+    # print(f"latents shape: {latents.shape}, mean: {latents.mean()}, std: {latents.std()}")
+    # print(f"mask shape: {mask.shape}, sum: {mask.sum()}")
+    # print(f"mask_byt5 shape: {mask_byt5.shape}, sum: {mask_byt5.sum()}")
+    # print(f"negative_mask shape: {negative_mask.shape}, sum: {negative_mask.sum()}")
+    # print(f"negative_mask_byt5 shape: {negative_mask_byt5.shape}, sum: {negative_mask_byt5.sum()}")
     with tqdm(total=len(timesteps), desc="Denoising steps") as pbar:
         for i, t in enumerate(timesteps):
-            t_expand = t.expand(latents.shape[0]).to(latents.dtype)
+            t_expand = t.expand(latents.shape[0]).to(torch.int64)

             with torch.no_grad():
                 noise_pred = model(latents, t_expand, embed, mask, embed_byt5, mask_byt5)
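Aside from the added (commented-out) debug prints, the behavioral change in this hunk is the cast of the expanded timestep: it now goes to torch.int64 rather than the latent dtype. A standalone sketch of that expansion, using dummy tensors:

    import torch

    latents = torch.randn(2, 4, 8, 8)   # dummy latent batch: (batch, channels, h, w)
    t = torch.tensor(981.0)             # scalar timestep from the scheduler
    # One timestep per batch element, cast to int64 instead of latents.dtype.
    t_expand = t.expand(latents.shape[0]).to(torch.int64)
    print(t_expand)                     # tensor([981, 981])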
@@ -1040,6 +1050,9 @@ def process_interactive(args: argparse.Namespace) -> None:
     shared_models = load_shared_models(args)
     shared_models["conds_cache"] = {}  # Initialize empty cache for interactive mode

+    vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True)
+    vae.eval()
+
     print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):")

     try:
@@ -1059,9 +1072,6 @@ def process_interactive(args: argparse.Namespace) -> None:
     def input_line(prompt: str) -> str:
         return input(prompt)

-    vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True)
-    vae.eval()
-
     try:
         while True:
             try:
@@ -1088,7 +1098,7 @@ def process_interactive(args: argparse.Namespace) -> None:

                 # Save latent and video
                 # returned_vae from generate will be used for decoding here.
-                save_output(prompt_args, vae, latent[0], device)
+                save_output(prompt_args, vae, latent, device)

             except KeyboardInterrupt:
                 print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)")