mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 01:12:41 +00:00
feat: add Adaptive Projected Guidance parameters and noise rescaling
This commit is contained in:
@@ -69,6 +69,24 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument(
|
||||
"--guidance_scale", type=float, default=3.5, help="Guidance scale for classifier free guidance. Default is 3.5."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--apg_start_step_ocr",
|
||||
type=int,
|
||||
default=38,
|
||||
help="Starting step for Adaptive Projected Guidance (APG) for image with text. Default is 38. Should be less than infer_steps, usually near the end.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--apg_start_step_general",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Starting step for Adaptive Projected Guidance (APG) for general image. Default is 5. Should be less than infer_steps, usually near the beginning.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guidance_rescale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Guidance rescale factor for steps without APG, 0.0 to 1.0. Default is 0.0 (no rescale)."
|
||||
)
|
||||
parser.add_argument("--prompt", type=str, default=None, help="prompt for generation")
|
||||
parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt for generation, default is empty string")
|
||||
parser.add_argument("--image_size", type=int, nargs=2, default=[2048, 2048], help="image size, height and width")
|
||||
@@ -715,8 +733,11 @@ def generate_body(
|
||||
ocr_mask[0],
|
||||
args.guidance_scale,
|
||||
i,
|
||||
apg_start_step_ocr=args.apg_start_step_ocr,
|
||||
apg_start_step_general=args.apg_start_step_general,
|
||||
cfg_guider_ocr=cfg_guider_ocr,
|
||||
cfg_guider_general=cfg_guider_general,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
)
|
||||
|
||||
# ensure latents dtype is consistent
|
||||
|
||||
@@ -428,16 +428,52 @@ class AdaptiveProjectedGuidance:
|
||||
return pred
|
||||
|
||||
|
||||
def rescale_noise_cfg(guided_noise, conditional_noise, rescale_factor=0.0):
    """
    Rescale a guided noise prediction so its spread matches the conditional prediction.

    Implements the guidance rescaling described in "Common Diffusion Noise Schedules and
    Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf, Section 3.4), which
    targets the overexposure introduced by classifier-free guidance.

    The standard deviation is taken over every dimension except the first, so each
    entry along dim 0 is rescaled independently.

    Args:
        guided_noise (torch.Tensor): Noise prediction from classifier-free guidance.
        conditional_noise (torch.Tensor): Noise prediction from conditional model.
        rescale_factor (float): Interpolation factor between original and rescaled predictions.
            0.0 = no rescaling (the input tensor is returned as-is), 1.0 = full rescaling.

    Returns:
        torch.Tensor: ``rescale_factor * rescaled + (1 - rescale_factor) * guided_noise``,
        where ``rescaled`` is ``guided_noise`` scaled by ``conditional_std / guided_std``.
        Where ``guided_std`` is zero or undefined, the ratio falls back to 1 so the
        guided prediction passes through unchanged instead of becoming inf/NaN.
    """
    if rescale_factor == 0.0:
        return guided_noise

    # Statistics over all non-leading dimensions, kept broadcastable against the inputs.
    reduce_dims = list(range(1, conditional_noise.ndim))
    conditional_std = conditional_noise.std(dim=reduce_dims, keepdim=True)
    guided_std = guided_noise.std(dim=reduce_dims, keepdim=True)

    # A zero (or NaN) guided std would turn the ratio into inf/NaN and poison the
    # whole prediction; treat those entries as "nothing to rescale" (ratio == 1).
    degenerate = ~(guided_std > 0)
    safe_guided_std = guided_std.masked_fill(degenerate, 1.0)
    std_ratio = (conditional_std / safe_guided_std).masked_fill(degenerate, 1.0)

    # Rescale guided noise to match conditional noise statistics.
    rescaled_prediction = guided_noise * std_ratio

    # Interpolate between original and rescaled predictions.
    return rescale_factor * rescaled_prediction + (1.0 - rescale_factor) * guided_noise
|
||||
|
||||
|
||||
def apply_classifier_free_guidance(
|
||||
noise_pred_text: torch.Tensor,
|
||||
noise_pred_uncond: torch.Tensor,
|
||||
is_ocr: bool,
|
||||
guidance_scale: float,
|
||||
step: int,
|
||||
apg_start_step_ocr: int = 75,
|
||||
apg_start_step_general: int = 10,
|
||||
apg_start_step_ocr: int = 38,
|
||||
apg_start_step_general: int = 5,
|
||||
cfg_guider_ocr: AdaptiveProjectedGuidance = None,
|
||||
cfg_guider_general: AdaptiveProjectedGuidance = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Apply classifier-free guidance with OCR-aware APG for batch_size=1.
|
||||
@@ -471,6 +507,11 @@ def apply_classifier_free_guidance(
|
||||
if step <= apg_start_step:
|
||||
# Standard classifier-free guidance
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale)
|
||||
|
||||
# Initialize APG guider state
|
||||
_ = cfg_guider(noise_pred_text, noise_pred_uncond, step=step)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user