feat: add Adaptive Projected Guidance parameters and noise rescaling

This commit is contained in:
Kohya S
2025-09-21 12:34:40 +09:00
parent e7b8e9a778
commit 9621d9d637
2 changed files with 64 additions and 2 deletions

View File

@@ -69,6 +69,24 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--guidance_scale", type=float, default=3.5, help="Guidance scale for classifier free guidance. Default is 3.5."
)
parser.add_argument(
"--apg_start_step_ocr",
type=int,
default=38,
help="Starting step for Adaptive Projected Guidance (APG) for image with text. Default is 38. Should be less than infer_steps, usually near the end.",
)
parser.add_argument(
"--apg_start_step_general",
type=int,
default=5,
help="Starting step for Adaptive Projected Guidance (APG) for general image. Default is 5. Should be less than infer_steps, usually near the beginning.",
)
parser.add_argument(
"--guidance_rescale",
type=float,
default=0.0,
help="Guidance rescale factor for steps without APG, 0.0 to 1.0. Default is 0.0 (no rescale)."
)
parser.add_argument("--prompt", type=str, default=None, help="prompt for generation")
parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt for generation, default is empty string")
parser.add_argument("--image_size", type=int, nargs=2, default=[2048, 2048], help="image size, height and width")
@@ -715,8 +733,11 @@ def generate_body(
ocr_mask[0],
args.guidance_scale,
i,
apg_start_step_ocr=args.apg_start_step_ocr,
apg_start_step_general=args.apg_start_step_general,
cfg_guider_ocr=cfg_guider_ocr,
cfg_guider_general=cfg_guider_general,
guidance_rescale=args.guidance_rescale,
)
# ensure latents dtype is consistent

View File

@@ -428,16 +428,52 @@ class AdaptiveProjectedGuidance:
return pred
def rescale_noise_cfg(guided_noise, conditional_noise, rescale_factor=0.0):
    """
    Rescale a guided noise prediction so its spread matches the conditional prediction.

    Implements the guidance rescaling from Section 3.4 of "Common Diffusion Noise
    Schedules and Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf),
    which is intended to counter the overexposure caused by classifier-free guidance.
    The guided prediction is multiplied by ``std(conditional) / std(guided)`` and the
    result is blended with the unscaled guided prediction.

    Args:
        guided_noise (torch.Tensor): Noise prediction after classifier-free guidance.
        conditional_noise (torch.Tensor): Noise prediction of the conditional branch.
        rescale_factor (float): Blend weight between the original (0.0) and the
            fully rescaled (1.0) prediction.

    Returns:
        torch.Tensor: The blended prediction. When ``rescale_factor`` is 0.0 the input
        ``guided_noise`` object itself is returned without any computation.
    """
    if rescale_factor == 0.0:
        return guided_noise

    # Standard deviation over every dimension except dim 0, which is kept
    # (presumably the batch dimension - confirm against the caller).
    reduce_dims = list(range(1, conditional_noise.ndim))
    conditional_std = conditional_noise.std(dim=reduce_dims, keepdim=True)
    guided_std = guided_noise.std(dim=reduce_dims, keepdim=True)

    # A constant (zero-std) guided prediction makes the ratio inf or NaN, which
    # would poison the whole output. Fall back to a ratio of 1.0 in that case so
    # the guided prediction passes through unchanged for that sample.
    std_ratio = (conditional_std / guided_std).nan_to_num(nan=1.0, posinf=1.0, neginf=1.0)
    rescaled_prediction = guided_noise * std_ratio

    # Linear interpolation between the unscaled and the rescaled prediction.
    return rescale_factor * rescaled_prediction + (1.0 - rescale_factor) * guided_noise
def apply_classifier_free_guidance(
noise_pred_text: torch.Tensor,
noise_pred_uncond: torch.Tensor,
is_ocr: bool,
guidance_scale: float,
step: int,
apg_start_step_ocr: int = 75,
apg_start_step_general: int = 10,
apg_start_step_ocr: int = 38,
apg_start_step_general: int = 5,
cfg_guider_ocr: AdaptiveProjectedGuidance = None,
cfg_guider_general: AdaptiveProjectedGuidance = None,
guidance_rescale: float = 0.0,
):
"""
Apply classifier-free guidance with OCR-aware APG for batch_size=1.
@@ -471,6 +507,11 @@ def apply_classifier_free_guidance(
if step <= apg_start_step:
# Standard classifier-free guidance
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale)
# Initialize APG guider state
_ = cfg_guider(noise_pred_text, noise_pred_uncond, step=step)
else: