add CFG to FLUX.1 sample image
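This commit (kohya-ss/sd-scripts, https://github.com/kohya-ss/sd-scripts.git) enables classifier-free guidance for the sample images generated during FLUX.1 training. The negative prompt of a sample prompt is no longer ignored, scale now defaults to 1.0 (meaning no guidance), and denoise() gains a neg_cond argument. When scale != 1.0, the negative prompt is encoded with the same text encoders, the conditional and unconditional inputs are run through the model as a single batch, and the two predictions are blended with the usual CFG formula.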
@@ -40,7 +40,7 @@ def sample_images(
     text_encoders,
     sample_prompts_te_outputs,
     prompt_replacement=None,
-    controlnet=None
+    controlnet=None,
 ):
     if steps == 0:
         if not args.sample_at_first:
@@ -101,7 +101,7 @@ def sample_images(
                 steps,
                 sample_prompts_te_outputs,
                 prompt_replacement,
-                controlnet
+                controlnet,
             )
     else:
         # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
@@ -125,7 +125,7 @@ def sample_images(
                     steps,
                     sample_prompts_te_outputs,
                     prompt_replacement,
-                    controlnet
+                    controlnet,
                 )
 
     torch.set_rng_state(rng_state)
@@ -147,14 +147,14 @@ def sample_image_inference(
     steps,
     sample_prompts_te_outputs,
     prompt_replacement,
-    controlnet
+    controlnet,
 ):
     assert isinstance(prompt_dict, dict)
-    # negative_prompt = prompt_dict.get("negative_prompt")
+    negative_prompt = prompt_dict.get("negative_prompt")
     sample_steps = prompt_dict.get("sample_steps", 20)
     width = prompt_dict.get("width", 512)
     height = prompt_dict.get("height", 512)
-    scale = prompt_dict.get("scale", 3.5)
+    scale = prompt_dict.get("scale", 1.0)  # 1.0 means no guidance
     seed = prompt_dict.get("seed")
     controlnet_image = prompt_dict.get("controlnet_image")
     prompt: str = prompt_dict.get("prompt", "")
@@ -162,8 +162,8 @@ def sample_image_inference(
 
     if prompt_replacement is not None:
         prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
-        # if negative_prompt is not None:
-        #     negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
+        if negative_prompt is not None:
+            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
 
     if seed is not None:
         torch.manual_seed(seed)
@@ -173,16 +173,18 @@ def sample_image_inference(
         torch.seed()
         torch.cuda.seed()
 
-    # if negative_prompt is None:
-    #     negative_prompt = ""
+    if negative_prompt is None:
+        negative_prompt = ""
     height = max(64, height - height % 16)  # round to divisible by 16
     width = max(64, width - width % 16)  # round to divisible by 16
     logger.info(f"prompt: {prompt}")
-    # logger.info(f"negative_prompt: {negative_prompt}")
+    if scale != 1.0:
+        logger.info(f"negative_prompt: {negative_prompt}")
     logger.info(f"height: {height}")
     logger.info(f"width: {width}")
     logger.info(f"sample_steps: {sample_steps}")
-    logger.info(f"scale: {scale}")
+    if scale != 1.0:
+        logger.info(f"scale: {scale}")
     # logger.info(f"sample_sampler: {sampler_name}")
     if seed is not None:
         logger.info(f"seed: {seed}")
@@ -191,26 +193,37 @@ def sample_image_inference(
     tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
     encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
 
-    text_encoder_conds = []
-    if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
-        text_encoder_conds = sample_prompts_te_outputs[prompt]
-        print(f"Using cached text encoder outputs for prompt: {prompt}")
-    if text_encoders is not None:
-        print(f"Encoding prompt: {prompt}")
-        tokens_and_masks = tokenize_strategy.tokenize(prompt)
-        # strategy has apply_t5_attn_mask option
-        encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
+    def encode_prompt(prpt):
+        text_encoder_conds = []
+        if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
+            text_encoder_conds = sample_prompts_te_outputs[prpt]
+            print(f"Using cached text encoder outputs for prompt: {prpt}")
+        if text_encoders is not None:
+            print(f"Encoding prompt: {prpt}")
+            tokens_and_masks = tokenize_strategy.tokenize(prpt)
+            # strategy has apply_t5_attn_mask option
+            encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
 
-        # if text_encoder_conds is not cached, use encoded_text_encoder_conds
-        if len(text_encoder_conds) == 0:
-            text_encoder_conds = encoded_text_encoder_conds
-        else:
-            # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
-            for i in range(len(encoded_text_encoder_conds)):
-                if encoded_text_encoder_conds[i] is not None:
-                    text_encoder_conds[i] = encoded_text_encoder_conds[i]
+            # if text_encoder_conds is not cached, use encoded_text_encoder_conds
+            if len(text_encoder_conds) == 0:
+                text_encoder_conds = encoded_text_encoder_conds
+            else:
+                # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
+                for i in range(len(encoded_text_encoder_conds)):
+                    if encoded_text_encoder_conds[i] is not None:
+                        text_encoder_conds[i] = encoded_text_encoder_conds[i]
+        return text_encoder_conds
 
-    l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
+    l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt)
+    # encode negative prompts
+    if scale != 1.0:
+        neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt)
+        neg_t5_attn_mask = (
+            neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None
+        )
+        neg_cond = (scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask)
+    else:
+        neg_cond = None
 
     # sample image
     weight_dtype = ae.dtype  # TOFO give dtype as argument
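Note: the neg_cond tuple is built only when scale != 1.0; with the new default of 1.0 the negative prompt is never encoded, so sampling behaves, and costs, exactly as before this change.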
@@ -235,7 +248,20 @@ def sample_image_inference(
         controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
 
     with accelerator.autocast(), torch.no_grad():
-        x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
+        x = denoise(
+            flux,
+            noise,
+            img_ids,
+            t5_out,
+            txt_ids,
+            l_pooled,
+            timesteps=timesteps,
+            guidance=scale,
+            t5_attn_mask=t5_attn_mask,
+            controlnet=controlnet,
+            controlnet_img=controlnet_image,
+            neg_cond=neg_cond,
+        )
 
     x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
 
@@ -305,22 +331,24 @@ def denoise(
     model: flux_models.Flux,
     img: torch.Tensor,
     img_ids: torch.Tensor,
-    txt: torch.Tensor,
+    txt: torch.Tensor,  # t5_out
     txt_ids: torch.Tensor,
-    vec: torch.Tensor,
+    vec: torch.Tensor,  # l_pooled
     timesteps: list[float],
     guidance: float = 4.0,
     t5_attn_mask: Optional[torch.Tensor] = None,
     controlnet: Optional[flux_models.ControlNetFlux] = None,
     controlnet_img: Optional[torch.Tensor] = None,
+    neg_cond: Optional[Tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]] = None,
 ):
     # this is ignored for schnell
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+    do_cfg = neg_cond is not None
 
     for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
         model.prepare_block_swap_before_forward()
 
         if controlnet is not None:
             block_samples, block_single_samples = controlnet(
                 img=img,
@@ -336,20 +364,48 @@ def denoise(
         else:
             block_samples = None
             block_single_samples = None
-        pred = model(
-            img=img,
-            img_ids=img_ids,
-            txt=txt,
-            txt_ids=txt_ids,
-            y=vec,
-            block_controlnet_hidden_states=block_samples,
-            block_controlnet_single_hidden_states=block_single_samples,
-            timesteps=t_vec,
-            guidance=guidance_vec,
-            txt_attention_mask=t5_attn_mask,
-        )
 
-        img = img + (t_prev - t_curr) * pred
+        if not do_cfg:
+            pred = model(
+                img=img,
+                img_ids=img_ids,
+                txt=txt,
+                txt_ids=txt_ids,
+                y=vec,
+                block_controlnet_hidden_states=block_samples,
+                block_controlnet_single_hidden_states=block_single_samples,
+                timesteps=t_vec,
+                guidance=guidance_vec,
+                txt_attention_mask=t5_attn_mask,
+            )
+
+            img = img + (t_prev - t_curr) * pred
+        else:
+            cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask = neg_cond
+            nc_c_t5_attn_mask = None if t5_attn_mask is None else torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
+
+            # TODO is it ok to use the same block samples for both cond and uncond?
+            block_samples = None if block_samples is None else torch.cat([block_samples, block_samples], dim=0)
+            block_single_samples = (
+                None if block_single_samples is None else torch.cat([block_single_samples, block_single_samples], dim=0)
+            )
+
+            nc_c_pred = model(
+                img=torch.cat([img, img], dim=0),
+                img_ids=torch.cat([img_ids, img_ids], dim=0),
+                txt=torch.cat([neg_t5_out, txt], dim=0),
+                txt_ids=torch.cat([txt_ids, txt_ids], dim=0),
+                y=torch.cat([neg_l_pooled, vec], dim=0),
+                block_controlnet_hidden_states=block_samples,
+                block_controlnet_single_hidden_states=block_single_samples,
+                timesteps=t_vec,
+                guidance=guidance_vec,
+                txt_attention_mask=nc_c_t5_attn_mask,
+            )
+            neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0)
+            pred = neg_pred + (pred - neg_pred) * cfg_scale
+
+            img = img + (t_prev - t_curr) * pred
 
     model.prepare_block_swap_before_forward()
     return img
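For reference, the blending in the new else: branch of denoise() is the standard classifier-free guidance update. A minimal standalone sketch of that arithmetic (illustrative only; cfg_combine is a made-up name and not part of the commit):

import torch

def cfg_combine(neg_pred: torch.Tensor, pred: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # same expression as in denoise(): uncond + scale * (cond - uncond)
    return neg_pred + (pred - neg_pred) * cfg_scale

# with cfg_scale == 1.0 the result collapses to the conditional prediction,
# which is why scale = 1.0 is treated as "no guidance" in sample_image_inference
neg_pred, pred = torch.randn(2, 16), torch.randn(2, 16)
assert torch.allclose(cfg_combine(neg_pred, pred, 1.0), pred)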
@@ -567,7 +623,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
         "--controlnet_model_name_or_path",
         type=str,
         default=None,
-        help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)"
+        help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)",
     )
     parser.add_argument(
         "--t5xxl_max_token_length",
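With this change a sample prompt opts into CFG simply by carrying a negative prompt and a scale other than 1.0. A hypothetical prompt_dict that would exercise the new path (the keys mirror the get() calls in sample_image_inference; the concrete values are made up):

prompt_dict = {
    "prompt": "a photo of a cat",
    "negative_prompt": "blurry, low quality",  # now encoded via encode_prompt()
    "scale": 4.0,  # != 1.0, so neg_cond is built and the CFG branch of denoise() runs
    "sample_steps": 20,
    "width": 768,
    "height": 768,
    "seed": 1,
}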