1. Implement the cfg_trunc calculation directly from timesteps, without intermediate steps.

2. Deprecate and remove the guidance_scale training argument, because it is only used during inference, not training.

3. Add inference command-line arguments --ct for cfg_trunc_ratio and --rc for renorm_cfg, to control CFG truncation and renormalization during inference.
This commit is contained in:
sdbds
2025-02-24 14:10:24 +08:00
parent 653621de57
commit fc772affbe
4 changed files with 33 additions and 26 deletions

View File

@@ -1081,7 +1081,7 @@ class NextDiT(nn.Module):
cap_feats: Tensor,
cap_mask: Tensor,
cfg_scale: float,
cfg_trunc: int = 100,
cfg_trunc: float = 0.25,
renorm_cfg: float = 1.0,
):
"""

View File

@@ -58,11 +58,13 @@ def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], N
width = max(64, width - width % 8) # round to divisible by 8
guidance_scale = float(prompt_dict.get("scale", 3.5))
sample_steps = int(prompt_dict.get("sample_steps", 38))
cfg_trunc_ratio = float(prompt_dict.get("cfg_trunc_ratio", 0.25))
renorm_cfg = float(prompt_dict.get("renorm_cfg", 1.0))
seed = prompt_dict.get("seed", None)
seed = int(seed) if seed is not None else None
# Create a key based on the parameters
key = (width, height, guidance_scale, seed, sample_steps)
key = (width, height, guidance_scale, seed, sample_steps, cfg_trunc_ratio, renorm_cfg)
# Add the prompt_dict to the corresponding batch
if key not in batches:
@@ -268,6 +270,8 @@ def sample_image_inference(
width = max(64, width - width % 8) # round to divisible by 8
guidance_scale = float(prompt_dicts[0].get("scale", 3.5))
cfg_trunc_ratio = float(prompt_dicts[0].get("cfg_trunc_ratio", 0.25))
renorm_cfg = float(prompt_dicts[0].get("renorm_cfg", 1.0))
sample_steps = int(prompt_dicts[0].get("sample_steps", 36))
seed = prompt_dicts[0].get("seed", None)
seed = int(seed) if seed is not None else None
@@ -295,6 +299,8 @@ def sample_image_inference(
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {guidance_scale}")
logger.info(f"trunc: {cfg_trunc_ratio}")
logger.info(f"renorm: {renorm_cfg}")
# logger.info(f"sample_sampler: {sampler_name}")
system_prompt = args.system_prompt or ""
@@ -375,8 +381,9 @@ def sample_image_inference(
uncond_hidden_states,
uncond_attn_masks,
timesteps=timesteps,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
cfg_trunc_ratio=cfg_trunc_ratio,
renorm_cfg=renorm_cfg,
)
# Latent to image
@@ -550,10 +557,9 @@ def denoise(
neg_txt: Tensor,
neg_txt_mask: Tensor,
timesteps: Union[List[float], torch.Tensor],
num_inference_steps: int = 38,
guidance_scale: float = 4.0,
cfg_trunc_ratio: float = 1.0,
cfg_normalization: bool = True,
cfg_trunc_ratio: float = 0.25,
renorm_cfg: float = 1.0,
):
"""
Denoise an image using the NextDiT model.
@@ -578,21 +584,17 @@ def denoise(
The guidance scale for the denoising process. Defaults to 4.0.
cfg_trunc_ratio (float, optional):
The ratio of the timestep interval to apply normalization-based guidance scale.
cfg_normalization (bool, optional):
Whether to apply normalization-based guidance scale.
renorm_cfg (float, optional):
The factor to limit the maximum norm after guidance. Default: 1.0
Returns:
img (Tensor): Denoised latent tensor
"""
for i, t in enumerate(tqdm(timesteps)):
# compute whether apply classifier-free truncation on this timestep
do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
current_timestep = 1 - t / scheduler.config.num_train_timesteps
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(img.shape[0]).to(model.device)
current_timestep = current_timestep * torch.ones(img.shape[0], device=img.device)
noise_pred_cond = model(
img,
@@ -601,7 +603,8 @@ def denoise(
cap_mask=txt_mask.to(dtype=torch.int32), # Gemma2的attention mask
)
if not do_classifier_free_truncation:
# compute whether to apply classifier-free guidance based on current timestep
if current_timestep[0] < cfg_trunc_ratio:
noise_pred_uncond = model(
img,
current_timestep,
@@ -610,10 +613,12 @@ def denoise(
)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
# apply normalization after classifier-free guidance
if cfg_normalization:
cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True)
noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_pred = noise_pred * (cond_norm / noise_norm)
if float(renorm_cfg) > 0.0:
cond_norm = torch.linalg.vector_norm(noise_pred_cond, dim=tuple(range(1, len(noise_pred_cond.shape))), keepdim=True)
max_new_norm = cond_norm * float(renorm_cfg)
noise_norm = torch.linalg.vector_norm(noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True)
if noise_norm >= max_new_norm:
noise_pred = noise_pred * (max_new_norm / noise_norm)
else:
noise_pred = noise_pred_cond
@@ -932,13 +937,6 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
" / Gemma2の最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
)
parser.add_argument(
"--guidance_scale",
type=float,
default=3.5,
help="the NextDIT.1 dev variant is a guidance distilled model",
)
parser.add_argument(
"--timestep_sampling",
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"],

View File

@@ -6188,6 +6188,16 @@ def line_to_prompt_dict(line: str) -> dict:
prompt_dict["controlnet_image"] = m.group(1)
continue
m = re.match(r"ct (.+)", parg, re.IGNORECASE)
if m:
prompt_dict["cfg_trunc_ratio"] = float(m.group(1))
continue
m = re.match(r"rc (.+)", parg, re.IGNORECASE)
if m:
prompt_dict["renorm_cfg"] = float(m.group(1))
continue
except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(ex)

View File

@@ -357,7 +357,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
metadata["ss_guidance_scale"] = args.guidance_scale
metadata["ss_timestep_sampling"] = args.timestep_sampling
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
metadata["ss_model_prediction_type"] = args.model_prediction_type