From fc772affbe4345c8e0d14eb53ebc883f8c5a576f Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Mon, 24 Feb 2025 14:10:24 +0800 Subject: [PATCH] =?UTF-8?q?1=E3=80=81Implement=20cfg=5Ftrunc=20calculation?= =?UTF-8?q?=20directly=20using=20timesteps,=20without=20intermediate=20ste?= =?UTF-8?q?ps.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 2、Deprecate and remove the guidance_scale parameter because it is used in inference, not training. 3、Add inference command-line arguments --ct for cfg_trunc_ratio and --rc for renorm_cfg to control CFG truncation and renormalization during inference. --- library/lumina_models.py | 2 +- library/lumina_train_util.py | 46 +++++++++++++++++------------------- library/train_util.py | 10 ++++++++ lumina_train_network.py | 1 - 4 files changed, 33 insertions(+), 26 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index d86a9cb2..1a441a69 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -1081,7 +1081,7 @@ class NextDiT(nn.Module): cap_feats: Tensor, cap_mask: Tensor, cfg_scale: float, - cfg_trunc: int = 100, + cfg_trunc: float = 0.25, renorm_cfg: float = 1.0, ): """ diff --git a/library/lumina_train_util.py b/library/lumina_train_util.py index 87f7ba36..f54b202d 100644 --- a/library/lumina_train_util.py +++ b/library/lumina_train_util.py @@ -58,11 +58,13 @@ def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], N width = max(64, width - width % 8) # round to divisible by 8 guidance_scale = float(prompt_dict.get("scale", 3.5)) sample_steps = int(prompt_dict.get("sample_steps", 38)) + cfg_trunc_ratio = float(prompt_dict.get("cfg_trunc_ratio", 0.25)) + renorm_cfg = float(prompt_dict.get("renorm_cfg", 1.0)) seed = prompt_dict.get("seed", None) seed = int(seed) if seed is not None else None # Create a key based on the parameters - key = (width, height, guidance_scale, seed, sample_steps) + key = (width, height, 
guidance_scale, seed, sample_steps, cfg_trunc_ratio, renorm_cfg) # Add the prompt_dict to the corresponding batch if key not in batches: @@ -268,6 +270,8 @@ def sample_image_inference( width = max(64, width - width % 8) # round to divisible by 8 guidance_scale = float(prompt_dicts[0].get("scale", 3.5)) + cfg_trunc_ratio = float(prompt_dicts[0].get("cfg_trunc_ratio", 0.25)) + renorm_cfg = float(prompt_dicts[0].get("renorm_cfg", 1.0)) sample_steps = int(prompt_dicts[0].get("sample_steps", 36)) seed = prompt_dicts[0].get("seed", None) seed = int(seed) if seed is not None else None @@ -295,6 +299,8 @@ def sample_image_inference( logger.info(f"width: {width}") logger.info(f"sample_steps: {sample_steps}") logger.info(f"scale: {guidance_scale}") + logger.info(f"trunc: {cfg_trunc_ratio}") + logger.info(f"renorm: {renorm_cfg}") # logger.info(f"sample_sampler: {sampler_name}") system_prompt = args.system_prompt or "" @@ -375,8 +381,9 @@ def sample_image_inference( uncond_hidden_states, uncond_attn_masks, timesteps=timesteps, - num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, + cfg_trunc_ratio=cfg_trunc_ratio, + renorm_cfg=renorm_cfg, ) # Latent to image @@ -550,10 +557,9 @@ def denoise( neg_txt: Tensor, neg_txt_mask: Tensor, timesteps: Union[List[float], torch.Tensor], - num_inference_steps: int = 38, guidance_scale: float = 4.0, - cfg_trunc_ratio: float = 1.0, - cfg_normalization: bool = True, + cfg_trunc_ratio: float = 0.25, + renorm_cfg: float = 1.0, ): """ Denoise an image using the NextDiT model. @@ -578,21 +584,17 @@ def denoise( The guidance scale for the denoising process. Defaults to 4.0. cfg_trunc_ratio (float, optional): The ratio of the timestep interval to apply normalization-based guidance scale. - cfg_normalization (bool, optional): - Whether to apply normalization-based guidance scale. - + renorm_cfg (float, optional): + The factor to limit the maximum norm after guidance. 
Default: 1.0 Returns: img (Tensor): Denoised latent tensor """ for i, t in enumerate(tqdm(timesteps)): - # compute whether apply classifier-free truncation on this timestep - do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio - # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image current_timestep = 1 - t / scheduler.config.num_train_timesteps # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - current_timestep = current_timestep.expand(img.shape[0]).to(model.device) + current_timestep = current_timestep * torch.ones(img.shape[0], device=img.device) noise_pred_cond = model( img, @@ -601,7 +603,8 @@ def denoise( cap_mask=txt_mask.to(dtype=torch.int32), # Gemma2的attention mask ) - if not do_classifier_free_truncation: + # compute whether to apply classifier-free guidance based on current timestep + if current_timestep[0] < cfg_trunc_ratio: noise_pred_uncond = model( img, current_timestep, @@ -610,10 +613,12 @@ def denoise( ) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) # apply normalization after classifier-free guidance - if cfg_normalization: - cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True) - noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_pred = noise_pred * (cond_norm / noise_norm) + if float(renorm_cfg) > 0.0: + cond_norm = torch.linalg.vector_norm(noise_pred_cond, dim=tuple(range(1, len(noise_pred_cond.shape))), keepdim=True) + max_new_norm = cond_norm * float(renorm_cfg) + noise_norm = torch.linalg.vector_norm(noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True) + if noise_norm >= max_new_norm: + noise_pred = noise_pred * (max_new_norm / noise_norm) else: noise_pred = noise_pred_cond @@ -932,13 +937,6 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser): " / Gemma2の最大トークン長。省略された場合、schnellの場合は256、devの場合は512", ) - parser.add_argument( - "--guidance_scale", - type=float, - 
default=3.5, - help="the NextDIT.1 dev variant is a guidance distilled model", - ) - parser.add_argument( "--timestep_sampling", choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"], diff --git a/library/train_util.py b/library/train_util.py index ded23f41..18aceaf7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6188,6 +6188,16 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict["controlnet_image"] = m.group(1) continue + m = re.match(r"ct (.+)", parg, re.IGNORECASE) + if m: + prompt_dict["cfg_trunc_ratio"] = float(m.group(1)) + continue + + m = re.match(r"rc (.+)", parg, re.IGNORECASE) + if m: + prompt_dict["renorm_cfg"] = float(m.group(1)) + continue + except ValueError as ex: logger.error(f"Exception in parsing / 解析エラー: {parg}") logger.error(ex) diff --git a/lumina_train_network.py b/lumina_train_network.py index adbf834c..0fd4da6b 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -357,7 +357,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): metadata["ss_logit_mean"] = args.logit_mean metadata["ss_logit_std"] = args.logit_std metadata["ss_mode_scale"] = args.mode_scale - metadata["ss_guidance_scale"] = args.guidance_scale metadata["ss_timestep_sampling"] = args.timestep_sampling metadata["ss_sigmoid_scale"] = args.sigmoid_scale metadata["ss_model_prediction_type"] = args.model_prediction_type