mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add negative prompt for flux inference script
This commit is contained in:
@@ -11,6 +11,9 @@ The command to install PyTorch is as follows:
|
|||||||
|
|
||||||
### Recent Updates
|
### Recent Updates
|
||||||
|
|
||||||
|
Sep 9, 2024:
|
||||||
|
Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used.
|
||||||
|
|
||||||
Sep 5, 2024 (update 1):
|
Sep 5, 2024 (update 1):
|
||||||
|
|
||||||
Added `--cpu_offload_checkpointing` option to LoRA training script. Offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`.
|
Added `--cpu_offload_checkpointing` option to LoRA training script. Offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`.
|
||||||
|
|||||||
@@ -71,22 +71,57 @@ def denoise(
|
|||||||
timesteps: list[float],
|
timesteps: list[float],
|
||||||
guidance: float = 4.0,
|
guidance: float = 4.0,
|
||||||
t5_attn_mask: Optional[torch.Tensor] = None,
|
t5_attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
neg_txt: Optional[torch.Tensor] = None,
|
||||||
|
neg_vec: Optional[torch.Tensor] = None,
|
||||||
|
neg_t5_attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
cfg_scale: Optional[float] = None,
|
||||||
):
|
):
|
||||||
# this is ignored for schnell
|
# this is ignored for schnell
|
||||||
|
logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
|
||||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||||
|
|
||||||
|
# prepare classifier free guidance
|
||||||
|
if neg_txt is not None and neg_vec is not None:
|
||||||
|
b_img_ids = torch.cat([img_ids, img_ids], dim=0)
|
||||||
|
b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
|
||||||
|
b_txt = torch.cat([neg_txt, txt], dim=0)
|
||||||
|
b_vec = torch.cat([neg_vec, vec], dim=0)
|
||||||
|
if t5_attn_mask is not None and neg_t5_attn_mask is not None:
|
||||||
|
b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
|
||||||
|
else:
|
||||||
|
b_t5_attn_mask = None
|
||||||
|
else:
|
||||||
|
b_img_ids = img_ids
|
||||||
|
b_txt_ids = txt_ids
|
||||||
|
b_txt = txt
|
||||||
|
b_vec = vec
|
||||||
|
b_t5_attn_mask = t5_attn_mask
|
||||||
|
|
||||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||||
|
|
||||||
|
# classifier free guidance
|
||||||
|
if neg_txt is not None and neg_vec is not None:
|
||||||
|
b_img = torch.cat([img, img], dim=0)
|
||||||
|
else:
|
||||||
|
b_img = img
|
||||||
|
|
||||||
pred = model(
|
pred = model(
|
||||||
img=img,
|
img=b_img,
|
||||||
img_ids=img_ids,
|
img_ids=b_img_ids,
|
||||||
txt=txt,
|
txt=b_txt,
|
||||||
txt_ids=txt_ids,
|
txt_ids=b_txt_ids,
|
||||||
y=vec,
|
y=b_vec,
|
||||||
timesteps=t_vec,
|
timesteps=t_vec,
|
||||||
guidance=guidance_vec,
|
guidance=guidance_vec,
|
||||||
txt_attention_mask=t5_attn_mask,
|
txt_attention_mask=b_t5_attn_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# classifier free guidance
|
||||||
|
if neg_txt is not None and neg_vec is not None:
|
||||||
|
pred_uncond, pred = torch.chunk(pred, 2, dim=0)
|
||||||
|
pred = pred_uncond + cfg_scale * (pred - pred_uncond)
|
||||||
|
|
||||||
img = img + (t_prev - t_curr) * pred
|
img = img + (t_prev - t_curr) * pred
|
||||||
|
|
||||||
return img
|
return img
|
||||||
@@ -106,19 +141,48 @@ def do_sample(
|
|||||||
is_schnell: bool,
|
is_schnell: bool,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
flux_dtype: torch.dtype,
|
flux_dtype: torch.dtype,
|
||||||
|
neg_l_pooled: Optional[torch.Tensor] = None,
|
||||||
|
neg_t5_out: Optional[torch.Tensor] = None,
|
||||||
|
neg_t5_attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
cfg_scale: Optional[float] = None,
|
||||||
):
|
):
|
||||||
|
logger.info(f"num_steps: {num_steps}")
|
||||||
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
|
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
|
||||||
|
|
||||||
# denoise initial noise
|
# denoise initial noise
|
||||||
if accelerator:
|
if accelerator:
|
||||||
with accelerator.autocast(), torch.no_grad():
|
with accelerator.autocast(), torch.no_grad():
|
||||||
x = denoise(
|
x = denoise(
|
||||||
model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask
|
model,
|
||||||
|
img,
|
||||||
|
img_ids,
|
||||||
|
t5_out,
|
||||||
|
txt_ids,
|
||||||
|
l_pooled,
|
||||||
|
timesteps,
|
||||||
|
guidance,
|
||||||
|
t5_attn_mask,
|
||||||
|
neg_t5_out,
|
||||||
|
neg_l_pooled,
|
||||||
|
neg_t5_attn_mask,
|
||||||
|
cfg_scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
|
with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
|
||||||
x = denoise(
|
x = denoise(
|
||||||
model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask
|
model,
|
||||||
|
img,
|
||||||
|
img_ids,
|
||||||
|
t5_out,
|
||||||
|
txt_ids,
|
||||||
|
l_pooled,
|
||||||
|
timesteps,
|
||||||
|
guidance,
|
||||||
|
t5_attn_mask,
|
||||||
|
neg_t5_out,
|
||||||
|
neg_l_pooled,
|
||||||
|
neg_t5_attn_mask,
|
||||||
|
cfg_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
@@ -135,6 +199,8 @@ def generate_image(
|
|||||||
image_height: int,
|
image_height: int,
|
||||||
steps: Optional[int],
|
steps: Optional[int],
|
||||||
guidance: float,
|
guidance: float,
|
||||||
|
negative_prompt: Optional[str],
|
||||||
|
cfg_scale: float,
|
||||||
):
|
):
|
||||||
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
|
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
|
||||||
logger.info(f"Seed: {seed}")
|
logger.info(f"Seed: {seed}")
|
||||||
@@ -162,65 +228,73 @@ def generate_image(
|
|||||||
# txt2img only needs img_ids
|
# txt2img only needs img_ids
|
||||||
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
|
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
|
||||||
|
|
||||||
|
# prepare fp8 models
|
||||||
|
if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
|
||||||
|
logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
|
||||||
|
clip_l.to(clip_l_dtype) # fp8
|
||||||
|
clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
|
||||||
|
clip_l.fp8_prepared = True
|
||||||
|
|
||||||
|
if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared):
|
||||||
|
logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
|
||||||
|
|
||||||
|
def prepare_fp8(text_encoder, target_dtype):
|
||||||
|
def forward_hook(module):
|
||||||
|
def forward(hidden_states):
|
||||||
|
hidden_gelu = module.act(module.wi_0(hidden_states))
|
||||||
|
hidden_linear = module.wi_1(hidden_states)
|
||||||
|
hidden_states = hidden_gelu * hidden_linear
|
||||||
|
hidden_states = module.dropout(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = module.wo(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
return forward
|
||||||
|
|
||||||
|
for module in text_encoder.modules():
|
||||||
|
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
||||||
|
# print("set", module.__class__.__name__, "to", target_dtype)
|
||||||
|
module.to(target_dtype)
|
||||||
|
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
||||||
|
# print("set", module.__class__.__name__, "hooks")
|
||||||
|
module.forward = forward_hook(module)
|
||||||
|
|
||||||
|
t5xxl.to(t5xxl_dtype)
|
||||||
|
prepare_fp8(t5xxl.encoder, torch.bfloat16)
|
||||||
|
t5xxl.fp8_prepared = True
|
||||||
|
|
||||||
# prepare embeddings
|
# prepare embeddings
|
||||||
logger.info("Encoding prompts...")
|
logger.info("Encoding prompts...")
|
||||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
|
||||||
clip_l = clip_l.to(device)
|
clip_l = clip_l.to(device)
|
||||||
t5xxl = t5xxl.to(device)
|
t5xxl = t5xxl.to(device)
|
||||||
with torch.no_grad():
|
|
||||||
if is_fp8(clip_l_dtype):
|
|
||||||
param_itr = clip_l.parameters()
|
|
||||||
param_itr.__next__() # skip first
|
|
||||||
param_2nd = param_itr.__next__()
|
|
||||||
if param_2nd.dtype != clip_l_dtype:
|
|
||||||
logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
|
|
||||||
clip_l.to(clip_l_dtype) # fp8
|
|
||||||
clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
with accelerator.autocast():
|
def encode(prpt: str):
|
||||||
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
tokens_and_masks = tokenize_strategy.tokenize(prpt)
|
||||||
|
with torch.no_grad():
|
||||||
|
if is_fp8(clip_l_dtype):
|
||||||
|
with accelerator.autocast():
|
||||||
|
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
||||||
|
else:
|
||||||
|
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
|
||||||
|
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
||||||
|
|
||||||
if is_fp8(t5xxl_dtype):
|
if is_fp8(t5xxl_dtype):
|
||||||
if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"):
|
with accelerator.autocast():
|
||||||
logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
|
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
||||||
|
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
|
||||||
|
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
||||||
|
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
||||||
|
)
|
||||||
|
return l_pooled, t5_out, txt_ids, t5_attn_mask
|
||||||
|
|
||||||
def prepare_fp8(text_encoder, target_dtype):
|
l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt)
|
||||||
def forward_hook(module):
|
if negative_prompt:
|
||||||
def forward(hidden_states):
|
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt)
|
||||||
hidden_gelu = module.act(module.wi_0(hidden_states))
|
else:
|
||||||
hidden_linear = module.wi_1(hidden_states)
|
neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None
|
||||||
hidden_states = hidden_gelu * hidden_linear
|
|
||||||
hidden_states = module.dropout(hidden_states)
|
|
||||||
|
|
||||||
hidden_states = module.wo(hidden_states)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
return forward
|
|
||||||
|
|
||||||
for module in text_encoder.modules():
|
|
||||||
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
|
||||||
# print("set", module.__class__.__name__, "to", target_dtype)
|
|
||||||
module.to(target_dtype)
|
|
||||||
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
|
||||||
# print("set", module.__class__.__name__, "hooks")
|
|
||||||
module.forward = forward_hook(module)
|
|
||||||
|
|
||||||
text_encoder.fp8_prepared = True
|
|
||||||
|
|
||||||
t5xxl.to(t5xxl_dtype)
|
|
||||||
prepare_fp8(t5xxl.encoder, torch.bfloat16)
|
|
||||||
|
|
||||||
with accelerator.autocast():
|
|
||||||
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
|
||||||
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
|
|
||||||
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
|
||||||
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
|
|
||||||
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
|
||||||
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
# NaN check
|
# NaN check
|
||||||
if torch.isnan(l_pooled).any():
|
if torch.isnan(l_pooled).any():
|
||||||
@@ -244,7 +318,23 @@ def generate_image(
|
|||||||
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
|
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
|
||||||
|
|
||||||
x = do_sample(
|
x = do_sample(
|
||||||
accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, t5_attn_mask, is_schnell, device, flux_dtype
|
accelerator,
|
||||||
|
model,
|
||||||
|
noise,
|
||||||
|
img_ids,
|
||||||
|
l_pooled,
|
||||||
|
t5_out,
|
||||||
|
txt_ids,
|
||||||
|
steps,
|
||||||
|
guidance,
|
||||||
|
t5_attn_mask,
|
||||||
|
is_schnell,
|
||||||
|
device,
|
||||||
|
flux_dtype,
|
||||||
|
neg_l_pooled,
|
||||||
|
neg_t5_out,
|
||||||
|
neg_t5_attn_mask,
|
||||||
|
cfg_scale,
|
||||||
)
|
)
|
||||||
if args.offload:
|
if args.offload:
|
||||||
model = model.cpu()
|
model = model.cpu()
|
||||||
@@ -307,6 +397,8 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--seed", type=int, default=None)
|
parser.add_argument("--seed", type=int, default=None)
|
||||||
parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
|
parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
|
||||||
parser.add_argument("--guidance", type=float, default=3.5)
|
parser.add_argument("--guidance", type=float, default=3.5)
|
||||||
|
parser.add_argument("--negative_prompt", type=str, default=None)
|
||||||
|
parser.add_argument("--cfg_scale", type=float, default=1.0)
|
||||||
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
|
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora_weights",
|
"--lora_weights",
|
||||||
@@ -405,17 +497,32 @@ if __name__ == "__main__":
|
|||||||
lora_models.append(lora_model)
|
lora_models.append(lora_model)
|
||||||
|
|
||||||
if not args.interactive:
|
if not args.interactive:
|
||||||
generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance)
|
generate_image(
|
||||||
|
model,
|
||||||
|
clip_l,
|
||||||
|
t5xxl,
|
||||||
|
ae,
|
||||||
|
args.prompt,
|
||||||
|
args.seed,
|
||||||
|
args.width,
|
||||||
|
args.height,
|
||||||
|
args.steps,
|
||||||
|
args.guidance,
|
||||||
|
args.negative_prompt,
|
||||||
|
args.cfg_scale,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# loop for interactive
|
# loop for interactive
|
||||||
width = target_width
|
width = target_width
|
||||||
height = target_height
|
height = target_height
|
||||||
steps = None
|
steps = None
|
||||||
guidance = args.guidance
|
guidance = args.guidance
|
||||||
|
cfg_scale = args.cfg_scale
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
print(
|
print(
|
||||||
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
|
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
|
||||||
|
" --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
|
||||||
)
|
)
|
||||||
prompt = input()
|
prompt = input()
|
||||||
if prompt == "":
|
if prompt == "":
|
||||||
@@ -425,26 +532,36 @@ if __name__ == "__main__":
|
|||||||
options = prompt.split("--")
|
options = prompt.split("--")
|
||||||
prompt = options[0].strip()
|
prompt = options[0].strip()
|
||||||
seed = None
|
seed = None
|
||||||
|
negative_prompt = None
|
||||||
for opt in options[1:]:
|
for opt in options[1:]:
|
||||||
opt = opt.strip()
|
try:
|
||||||
if opt.startswith("w"):
|
opt = opt.strip()
|
||||||
width = int(opt[1:].strip())
|
if opt.startswith("w"):
|
||||||
elif opt.startswith("h"):
|
width = int(opt[1:].strip())
|
||||||
height = int(opt[1:].strip())
|
elif opt.startswith("h"):
|
||||||
elif opt.startswith("s"):
|
height = int(opt[1:].strip())
|
||||||
steps = int(opt[1:].strip())
|
elif opt.startswith("s"):
|
||||||
elif opt.startswith("d"):
|
steps = int(opt[1:].strip())
|
||||||
seed = int(opt[1:].strip())
|
elif opt.startswith("d"):
|
||||||
elif opt.startswith("g"):
|
seed = int(opt[1:].strip())
|
||||||
guidance = float(opt[1:].strip())
|
elif opt.startswith("g"):
|
||||||
elif opt.startswith("m"):
|
guidance = float(opt[1:].strip())
|
||||||
mutipliers = opt[1:].strip().split(",")
|
elif opt.startswith("m"):
|
||||||
if len(mutipliers) != len(lora_models):
|
mutipliers = opt[1:].strip().split(",")
|
||||||
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
|
if len(mutipliers) != len(lora_models):
|
||||||
continue
|
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
|
||||||
for i, lora_model in enumerate(lora_models):
|
continue
|
||||||
lora_model.set_multiplier(float(mutipliers[i]))
|
for i, lora_model in enumerate(lora_models):
|
||||||
|
lora_model.set_multiplier(float(mutipliers[i]))
|
||||||
|
elif opt.startswith("n"):
|
||||||
|
negative_prompt = opt[1:].strip()
|
||||||
|
if negative_prompt == "-":
|
||||||
|
negative_prompt = ""
|
||||||
|
elif opt.startswith("c"):
|
||||||
|
cfg_scale = float(opt[1:].strip())
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Invalid option: {opt}, {e}")
|
||||||
|
|
||||||
generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance)
|
generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)
|
||||||
|
|
||||||
logger.info("Done!")
|
logger.info("Done!")
|
||||||
|
|||||||
Reference in New Issue
Block a user