feat: block swap for inference and initial impl for HunyuanImage LoRA (not working)

This commit is contained in:
Kohya S
2025-09-11 22:15:22 +09:00
parent 5149be5a87
commit 7f983c558d
16 changed files with 1363 additions and 1303 deletions

View File

@@ -5,6 +5,18 @@ import math
from typing import Tuple, Union, Optional
import torch
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
MODEL_VERSION_2_1 = "hunyuan-image-2.1"
# region model
def _to_tuple(x, dim=2):
"""
@@ -206,7 +218,7 @@ def reshape_for_broadcast(
x.shape[1],
x.shape[-1],
), f"Frequency tensor shape {freqs_cis[0].shape} incompatible with target shape {x.shape}"
shape = [d if i == 1 or i == x.ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
@@ -248,7 +260,7 @@ def apply_rotary_emb(
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
cos, sin = cos.to(device), sin.to(device)
# Apply rotation: x' = x * cos + rotate_half(x) * sin
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).to(dtype)
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).to(dtype)
@@ -256,6 +268,11 @@ def apply_rotary_emb(
return xq_out, xk_out
# endregion
# region inference
def get_timesteps_sigmas(sampling_steps: int, shift: float, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generate timesteps and sigmas for diffusion sampling.
@@ -291,6 +308,9 @@ def step(latents, noise_pred, sigmas, step_i):
return latents.float() - (sigmas[step_i] - sigmas[step_i + 1]) * noise_pred.float()
# endregion
# region AdaptiveProjectedGuidance
@@ -298,6 +318,7 @@ class MomentumBuffer:
"""
Exponential moving average buffer for APG momentum.
"""
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
@@ -318,10 +339,10 @@ def normalized_guidance_apg(
):
"""
Apply normalized adaptive projected guidance.
Projects the guidance vector to reduce over-saturation while maintaining
directional control by decomposing into parallel and orthogonal components.
Args:
pred_cond: Conditional prediction.
pred_uncond: Unconditional prediction.
@@ -330,7 +351,7 @@ def normalized_guidance_apg(
eta: Scaling factor for parallel component.
norm_threshold: Maximum norm for guidance vector clipping.
use_original_formulation: Whether to use original APG formulation.
Returns:
Guided prediction tensor.
"""
@@ -366,10 +387,11 @@ def normalized_guidance_apg(
class AdaptiveProjectedGuidance:
"""
Adaptive Projected Guidance for classifier-free guidance.
Implements APG which projects the guidance vector to reduce over-saturation
while maintaining directional control.
"""
def __init__(
self,
guidance_scale: float = 7.5,
@@ -406,9 +428,6 @@ class AdaptiveProjectedGuidance:
return pred
# endregion
def apply_classifier_free_guidance(
noise_pred_text: torch.Tensor,
noise_pred_uncond: torch.Tensor,
@@ -459,3 +478,6 @@ def apply_classifier_free_guidance(
noise_pred = cfg_guider(noise_pred_text, noise_pred_uncond, step=step)
return noise_pred
# endregion