fix: revert default emb guidance scale and CFG scale for FLUX.1 sampling

This commit is contained in:
Kohya S
2025-04-27 22:50:27 +09:00
parent 13296ae93b
commit fd3a445769

View File

@@ -154,8 +154,9 @@ def sample_image_inference(
sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
guidance_scale = prompt_dict.get("guidance_scale", args.guidance_scale)
scale = prompt_dict.get("scale", 1.0) # 1.0 means no guidance
# TODO: refactor variable names -- the prompt key "guidance_scale" feeds cfg_scale, while "scale" feeds emb_guidance_scale
cfg_scale = prompt_dict.get("guidance_scale", 1.0)
emb_guidance_scale = prompt_dict.get("scale", 3.5)
seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
@@ -179,16 +180,16 @@ def sample_image_inference(
height = max(64, height - height % 16) # round to divisible by 16
width = max(64, width - width % 16) # round to divisible by 16
logger.info(f"prompt: {prompt}")
if scale != 1.0:
if cfg_scale != 1.0:
logger.info(f"negative_prompt: {negative_prompt}")
elif negative_prompt != "":
logger.info(f"negative prompt is ignored because scale is 1.0")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"guidance_scale: {guidance_scale}")
if scale != 1.0:
logger.info(f"scale: {scale}")
logger.info(f"embedded guidance scale: {emb_guidance_scale}")
if cfg_scale != 1.0:
logger.info(f"CFG scale: {cfg_scale}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
@@ -220,12 +221,12 @@ def sample_image_inference(
l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt)
# encode negative prompts
if scale != 1.0:
if cfg_scale != 1.0:
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt)
neg_t5_attn_mask = (
neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None
)
neg_cond = (scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask)
neg_cond = (cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask)
else:
neg_cond = None
@@ -260,7 +261,7 @@ def sample_image_inference(
txt_ids,
l_pooled,
timesteps=timesteps,
guidance=guidance_scale,
guidance=emb_guidance_scale,
t5_attn_mask=t5_attn_mask,
controlnet=controlnet,
controlnet_img=controlnet_image,