add CFG to FLUX.1 sample image

This commit is contained in:
Kohya S
2025-01-05 22:28:51 +09:00
parent e89653975d
commit a9c5aa1f93

View File

@@ -40,7 +40,7 @@ def sample_images(
text_encoders, text_encoders,
sample_prompts_te_outputs, sample_prompts_te_outputs,
prompt_replacement=None, prompt_replacement=None,
controlnet=None controlnet=None,
): ):
if steps == 0: if steps == 0:
if not args.sample_at_first: if not args.sample_at_first:
@@ -101,7 +101,7 @@ def sample_images(
steps, steps,
sample_prompts_te_outputs, sample_prompts_te_outputs,
prompt_replacement, prompt_replacement,
controlnet controlnet,
) )
else: else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
@@ -125,7 +125,7 @@ def sample_images(
steps, steps,
sample_prompts_te_outputs, sample_prompts_te_outputs,
prompt_replacement, prompt_replacement,
controlnet controlnet,
) )
torch.set_rng_state(rng_state) torch.set_rng_state(rng_state)
@@ -147,14 +147,14 @@ def sample_image_inference(
steps, steps,
sample_prompts_te_outputs, sample_prompts_te_outputs,
prompt_replacement, prompt_replacement,
controlnet controlnet,
): ):
assert isinstance(prompt_dict, dict) assert isinstance(prompt_dict, dict)
# negative_prompt = prompt_dict.get("negative_prompt") negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 20) sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 512) width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512) height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 3.5) scale = prompt_dict.get("scale", 1.0) # 1.0 means no guidance
seed = prompt_dict.get("seed") seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image") controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "") prompt: str = prompt_dict.get("prompt", "")
@@ -162,8 +162,8 @@ def sample_image_inference(
if prompt_replacement is not None: if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
# if negative_prompt is not None: if negative_prompt is not None:
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None: if seed is not None:
torch.manual_seed(seed) torch.manual_seed(seed)
@@ -173,15 +173,17 @@ def sample_image_inference(
torch.seed() torch.seed()
torch.cuda.seed() torch.cuda.seed()
# if negative_prompt is None: if negative_prompt is None:
# negative_prompt = "" negative_prompt = ""
height = max(64, height - height % 16) # round to divisible by 16 height = max(64, height - height % 16) # round to divisible by 16
width = max(64, width - width % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16
logger.info(f"prompt: {prompt}") logger.info(f"prompt: {prompt}")
# logger.info(f"negative_prompt: {negative_prompt}") if scale != 1.0:
logger.info(f"negative_prompt: {negative_prompt}")
logger.info(f"height: {height}") logger.info(f"height: {height}")
logger.info(f"width: {width}") logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}") logger.info(f"sample_steps: {sample_steps}")
if scale != 1.0:
logger.info(f"scale: {scale}") logger.info(f"scale: {scale}")
# logger.info(f"sample_sampler: {sampler_name}") # logger.info(f"sample_sampler: {sampler_name}")
if seed is not None: if seed is not None:
@@ -191,13 +193,14 @@ def sample_image_inference(
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
def encode_prompt(prpt):
text_encoder_conds = [] text_encoder_conds = []
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prompt] text_encoder_conds = sample_prompts_te_outputs[prpt]
print(f"Using cached text encoder outputs for prompt: {prompt}") print(f"Using cached text encoder outputs for prompt: {prpt}")
if text_encoders is not None: if text_encoders is not None:
print(f"Encoding prompt: {prompt}") print(f"Encoding prompt: {prpt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt) tokens_and_masks = tokenize_strategy.tokenize(prpt)
# strategy has apply_t5_attn_mask option # strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
@@ -209,8 +212,18 @@ def sample_image_inference(
for i in range(len(encoded_text_encoder_conds)): for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None: if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i] text_encoder_conds[i] = encoded_text_encoder_conds[i]
return text_encoder_conds
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt)
# encode negative prompts
if scale != 1.0:
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt)
neg_t5_attn_mask = (
neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None
)
neg_cond = (scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask)
else:
neg_cond = None
# sample image # sample image
weight_dtype = ae.dtype # TOFO give dtype as argument weight_dtype = ae.dtype # TOFO give dtype as argument
@@ -235,7 +248,20 @@ def sample_image_inference(
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
with accelerator.autocast(), torch.no_grad(): with accelerator.autocast(), torch.no_grad():
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) x = denoise(
flux,
noise,
img_ids,
t5_out,
txt_ids,
l_pooled,
timesteps=timesteps,
guidance=scale,
t5_attn_mask=t5_attn_mask,
controlnet=controlnet,
controlnet_img=controlnet_image,
neg_cond=neg_cond,
)
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
@@ -305,22 +331,24 @@ def denoise(
model: flux_models.Flux, model: flux_models.Flux,
img: torch.Tensor, img: torch.Tensor,
img_ids: torch.Tensor, img_ids: torch.Tensor,
txt: torch.Tensor, txt: torch.Tensor, # t5_out
txt_ids: torch.Tensor, txt_ids: torch.Tensor,
vec: torch.Tensor, vec: torch.Tensor, # l_pooled
timesteps: list[float], timesteps: list[float],
guidance: float = 4.0, guidance: float = 4.0,
t5_attn_mask: Optional[torch.Tensor] = None, t5_attn_mask: Optional[torch.Tensor] = None,
controlnet: Optional[flux_models.ControlNetFlux] = None, controlnet: Optional[flux_models.ControlNetFlux] = None,
controlnet_img: Optional[torch.Tensor] = None, controlnet_img: Optional[torch.Tensor] = None,
neg_cond: Optional[Tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]] = None,
): ):
# this is ignored for schnell # this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
do_cfg = neg_cond is not None
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
model.prepare_block_swap_before_forward() model.prepare_block_swap_before_forward()
if controlnet is not None: if controlnet is not None:
block_samples, block_single_samples = controlnet( block_samples, block_single_samples = controlnet(
img=img, img=img,
@@ -336,6 +364,8 @@ def denoise(
else: else:
block_samples = None block_samples = None
block_single_samples = None block_single_samples = None
if not do_cfg:
pred = model( pred = model(
img=img, img=img,
img_ids=img_ids, img_ids=img_ids,
@@ -349,6 +379,32 @@ def denoise(
txt_attention_mask=t5_attn_mask, txt_attention_mask=t5_attn_mask,
) )
img = img + (t_prev - t_curr) * pred
else:
cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask = neg_cond
nc_c_t5_attn_mask = None if t5_attn_mask is None else torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
# TODO is it ok to use the same block samples for both cond and uncond?
block_samples = None if block_samples is None else torch.cat([block_samples, block_samples], dim=0)
block_single_samples = (
None if block_single_samples is None else torch.cat([block_single_samples, block_single_samples], dim=0)
)
nc_c_pred = model(
img=torch.cat([img, img], dim=0),
img_ids=torch.cat([img_ids, img_ids], dim=0),
txt=torch.cat([neg_t5_out, txt], dim=0),
txt_ids=torch.cat([txt_ids, txt_ids], dim=0),
y=torch.cat([neg_l_pooled, vec], dim=0),
block_controlnet_hidden_states=block_samples,
block_controlnet_single_hidden_states=block_single_samples,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=nc_c_t5_attn_mask,
)
neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0)
pred = neg_pred + (pred - neg_pred) * cfg_scale
img = img + (t_prev - t_curr) * pred img = img + (t_prev - t_curr) * pred
model.prepare_block_swap_before_forward() model.prepare_block_swap_before_forward()
@@ -567,7 +623,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
"--controlnet_model_name_or_path", "--controlnet_model_name_or_path",
type=str, type=str,
default=None, default=None,
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス*.sftまたは*.safetensors" help="path to controlnet (*.sft or *.safetensors) / controlnetのパス*.sftまたは*.safetensors",
) )
parser.add_argument( parser.add_argument(
"--t5xxl_max_token_length", "--t5xxl_max_token_length",