mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 08:21:46 +00:00
1、Implement cfg_trunc calculation directly using timesteps, without intermediate steps.
2、Deprecate and remove the guidance_scale parameter, because it is used in inference, not in training.
3、Add inference command-line arguments --ct for cfg_trunc_ratio and --rc for renorm_cfg to control CFG truncation and renormalization during inference.
This commit is contained in:
@@ -1081,7 +1081,7 @@ class NextDiT(nn.Module):
|
||||
cap_feats: Tensor,
|
||||
cap_mask: Tensor,
|
||||
cfg_scale: float,
|
||||
cfg_trunc: int = 100,
|
||||
cfg_trunc: float = 0.25,
|
||||
renorm_cfg: float = 1.0,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -58,11 +58,13 @@ def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], N
|
||||
width = max(64, width - width % 8) # round to divisible by 8
|
||||
guidance_scale = float(prompt_dict.get("scale", 3.5))
|
||||
sample_steps = int(prompt_dict.get("sample_steps", 38))
|
||||
cfg_trunc_ratio = float(prompt_dict.get("cfg_trunc_ratio", 0.25))
|
||||
renorm_cfg = float(prompt_dict.get("renorm_cfg", 1.0))
|
||||
seed = prompt_dict.get("seed", None)
|
||||
seed = int(seed) if seed is not None else None
|
||||
|
||||
# Create a key based on the parameters
|
||||
key = (width, height, guidance_scale, seed, sample_steps)
|
||||
key = (width, height, guidance_scale, seed, sample_steps, cfg_trunc_ratio, renorm_cfg)
|
||||
|
||||
# Add the prompt_dict to the corresponding batch
|
||||
if key not in batches:
|
||||
@@ -268,6 +270,8 @@ def sample_image_inference(
|
||||
width = max(64, width - width % 8) # round to divisible by 8
|
||||
|
||||
guidance_scale = float(prompt_dicts[0].get("scale", 3.5))
|
||||
cfg_trunc_ratio = float(prompt_dicts[0].get("cfg_trunc_ratio", 0.25))
|
||||
renorm_cfg = float(prompt_dicts[0].get("renorm_cfg", 1.0))
|
||||
sample_steps = int(prompt_dicts[0].get("sample_steps", 36))
|
||||
seed = prompt_dicts[0].get("seed", None)
|
||||
seed = int(seed) if seed is not None else None
|
||||
@@ -295,6 +299,8 @@ def sample_image_inference(
|
||||
logger.info(f"width: {width}")
|
||||
logger.info(f"sample_steps: {sample_steps}")
|
||||
logger.info(f"scale: {guidance_scale}")
|
||||
logger.info(f"trunc: {cfg_trunc_ratio}")
|
||||
logger.info(f"renorm: {renorm_cfg}")
|
||||
# logger.info(f"sample_sampler: {sampler_name}")
|
||||
|
||||
system_prompt = args.system_prompt or ""
|
||||
@@ -375,8 +381,9 @@ def sample_image_inference(
|
||||
uncond_hidden_states,
|
||||
uncond_attn_masks,
|
||||
timesteps=timesteps,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
cfg_trunc_ratio=cfg_trunc_ratio,
|
||||
renorm_cfg=renorm_cfg,
|
||||
)
|
||||
|
||||
# Latent to image
|
||||
@@ -550,10 +557,9 @@ def denoise(
|
||||
neg_txt: Tensor,
|
||||
neg_txt_mask: Tensor,
|
||||
timesteps: Union[List[float], torch.Tensor],
|
||||
num_inference_steps: int = 38,
|
||||
guidance_scale: float = 4.0,
|
||||
cfg_trunc_ratio: float = 1.0,
|
||||
cfg_normalization: bool = True,
|
||||
cfg_trunc_ratio: float = 0.25,
|
||||
renorm_cfg: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Denoise an image using the NextDiT model.
|
||||
@@ -578,21 +584,17 @@ def denoise(
|
||||
The guidance scale for the denoising process. Defaults to 4.0.
|
||||
cfg_trunc_ratio (float, optional):
|
||||
The ratio of the timestep interval to apply normalization-based guidance scale.
|
||||
cfg_normalization (bool, optional):
|
||||
Whether to apply normalization-based guidance scale.
|
||||
|
||||
renorm_cfg (float, optional):
|
||||
The factor to limit the maximum norm after guidance. Default: 1.0
|
||||
Returns:
|
||||
img (Tensor): Denoised latent tensor
|
||||
"""
|
||||
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
# compute whether apply classifier-free truncation on this timestep
|
||||
do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio
|
||||
|
||||
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
|
||||
current_timestep = 1 - t / scheduler.config.num_train_timesteps
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
current_timestep = current_timestep.expand(img.shape[0]).to(model.device)
|
||||
current_timestep = current_timestep * torch.ones(img.shape[0], device=img.device)
|
||||
|
||||
noise_pred_cond = model(
|
||||
img,
|
||||
@@ -601,7 +603,8 @@ def denoise(
|
||||
cap_mask=txt_mask.to(dtype=torch.int32), # Gemma2的attention mask
|
||||
)
|
||||
|
||||
if not do_classifier_free_truncation:
|
||||
# compute whether to apply classifier-free guidance based on current timestep
|
||||
if current_timestep[0] < cfg_trunc_ratio:
|
||||
noise_pred_uncond = model(
|
||||
img,
|
||||
current_timestep,
|
||||
@@ -610,10 +613,12 @@ def denoise(
|
||||
)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
# apply normalization after classifier-free guidance
|
||||
if cfg_normalization:
|
||||
cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True)
|
||||
noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
|
||||
noise_pred = noise_pred * (cond_norm / noise_norm)
|
||||
if float(renorm_cfg) > 0.0:
|
||||
cond_norm = torch.linalg.vector_norm(noise_pred_cond, dim=tuple(range(1, len(noise_pred_cond.shape))), keepdim=True)
|
||||
max_new_norm = cond_norm * float(renorm_cfg)
|
||||
noise_norm = torch.linalg.vector_norm(noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True)
|
||||
if noise_norm >= max_new_norm:
|
||||
noise_pred = noise_pred * (max_new_norm / noise_norm)
|
||||
else:
|
||||
noise_pred = noise_pred_cond
|
||||
|
||||
@@ -932,13 +937,6 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
|
||||
" / Gemma2の最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=3.5,
|
||||
help="the NextDIT.1 dev variant is a guidance distilled model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--timestep_sampling",
|
||||
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"],
|
||||
|
||||
@@ -6188,6 +6188,16 @@ def line_to_prompt_dict(line: str) -> dict:
|
||||
prompt_dict["controlnet_image"] = m.group(1)
|
||||
continue
|
||||
|
||||
m = re.match(r"ct (.+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
prompt_dict["cfg_trunc_ratio"] = float(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"rc (.+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
prompt_dict["renorm_cfg"] = float(m.group(1))
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
logger.error(f"Exception in parsing / 解析エラー: {parg}")
|
||||
logger.error(ex)
|
||||
|
||||
@@ -357,7 +357,6 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
||||
metadata["ss_logit_mean"] = args.logit_mean
|
||||
metadata["ss_logit_std"] = args.logit_std
|
||||
metadata["ss_mode_scale"] = args.mode_scale
|
||||
metadata["ss_guidance_scale"] = args.guidance_scale
|
||||
metadata["ss_timestep_sampling"] = args.timestep_sampling
|
||||
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
|
||||
metadata["ss_model_prediction_type"] = args.model_prediction_type
|
||||
|
||||
Reference in New Issue
Block a user