mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
feat: block swap for inference and initial impl for HunyuanImage LoRA (not working)
This commit is contained in:
@@ -5,6 +5,18 @@ import math
|
||||
from typing import Tuple, Union, Optional
|
||||
import torch
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MODEL_VERSION_2_1 = "hunyuan-image-2.1"
|
||||
|
||||
# region model
|
||||
|
||||
|
||||
def _to_tuple(x, dim=2):
|
||||
"""
|
||||
@@ -206,7 +218,7 @@ def reshape_for_broadcast(
|
||||
x.shape[1],
|
||||
x.shape[-1],
|
||||
), f"Frequency tensor shape {freqs_cis[0].shape} incompatible with target shape {x.shape}"
|
||||
|
||||
|
||||
shape = [d if i == 1 or i == x.ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
||||
|
||||
@@ -248,7 +260,7 @@ def apply_rotary_emb(
|
||||
|
||||
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
|
||||
cos, sin = cos.to(device), sin.to(device)
|
||||
|
||||
|
||||
# Apply rotation: x' = x * cos + rotate_half(x) * sin
|
||||
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).to(dtype)
|
||||
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).to(dtype)
|
||||
@@ -256,6 +268,11 @@ def apply_rotary_emb(
|
||||
return xq_out, xk_out
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region inference
|
||||
|
||||
|
||||
def get_timesteps_sigmas(sampling_steps: int, shift: float, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Generate timesteps and sigmas for diffusion sampling.
|
||||
@@ -291,6 +308,9 @@ def step(latents, noise_pred, sigmas, step_i):
|
||||
return latents.float() - (sigmas[step_i] - sigmas[step_i + 1]) * noise_pred.float()
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region AdaptiveProjectedGuidance
|
||||
|
||||
|
||||
@@ -298,6 +318,7 @@ class MomentumBuffer:
|
||||
"""
|
||||
Exponential moving average buffer for APG momentum.
|
||||
"""
|
||||
|
||||
def __init__(self, momentum: float):
|
||||
self.momentum = momentum
|
||||
self.running_average = 0
|
||||
@@ -318,10 +339,10 @@ def normalized_guidance_apg(
|
||||
):
|
||||
"""
|
||||
Apply normalized adaptive projected guidance.
|
||||
|
||||
|
||||
Projects the guidance vector to reduce over-saturation while maintaining
|
||||
directional control by decomposing into parallel and orthogonal components.
|
||||
|
||||
|
||||
Args:
|
||||
pred_cond: Conditional prediction.
|
||||
pred_uncond: Unconditional prediction.
|
||||
@@ -330,7 +351,7 @@ def normalized_guidance_apg(
|
||||
eta: Scaling factor for parallel component.
|
||||
norm_threshold: Maximum norm for guidance vector clipping.
|
||||
use_original_formulation: Whether to use original APG formulation.
|
||||
|
||||
|
||||
Returns:
|
||||
Guided prediction tensor.
|
||||
"""
|
||||
@@ -366,10 +387,11 @@ def normalized_guidance_apg(
|
||||
class AdaptiveProjectedGuidance:
|
||||
"""
|
||||
Adaptive Projected Guidance for classifier-free guidance.
|
||||
|
||||
|
||||
Implements APG which projects the guidance vector to reduce over-saturation
|
||||
while maintaining directional control.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
@@ -406,9 +428,6 @@ class AdaptiveProjectedGuidance:
|
||||
return pred
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def apply_classifier_free_guidance(
|
||||
noise_pred_text: torch.Tensor,
|
||||
noise_pred_uncond: torch.Tensor,
|
||||
@@ -459,3 +478,6 @@ def apply_classifier_free_guidance(
|
||||
noise_pred = cfg_guider(noise_pred_text, noise_pred_uncond, step=step)
|
||||
|
||||
return noise_pred
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user