diff --git a/README.md b/README.md
index 2f010f49..126516f9 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,9 @@ The command to install PyTorch is as follows:
 
 ### Recent Updates
 
+Sep 9, 2024:
+Added `--negative_prompt` and `--cfg_scale` to `flux_minimal_inference.py`. Negative prompts can be used.
+
 Sep 5, 2024 (update 1):
 Added `--cpu_offload_checkpointing` option to LoRA training script. Offloads gradient checkpointing to CPU. This reduces up to 1GB of VRAM usage but slows down the training by about 15%. Cannot be used with `--split_mode`.
 
diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py
index 1c194e7c..de607c52 100644
--- a/flux_minimal_inference.py
+++ b/flux_minimal_inference.py
@@ -71,22 +71,57 @@ def denoise(
     timesteps: list[float],
     guidance: float = 4.0,
     t5_attn_mask: Optional[torch.Tensor] = None,
+    neg_txt: Optional[torch.Tensor] = None,
+    neg_vec: Optional[torch.Tensor] = None,
+    neg_t5_attn_mask: Optional[torch.Tensor] = None,
+    cfg_scale: Optional[float] = None,
 ):
     # this is ignored for schnell
+    logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
+
+    # prepare classifier free guidance
+    if neg_txt is not None and neg_vec is not None:
+        b_img_ids = torch.cat([img_ids, img_ids], dim=0)
+        b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
+        b_txt = torch.cat([neg_txt, txt], dim=0)
+        b_vec = torch.cat([neg_vec, vec], dim=0)
+        if t5_attn_mask is not None and neg_t5_attn_mask is not None:
+            b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
+        else:
+            b_t5_attn_mask = None
+    else:
+        b_img_ids = img_ids
+        b_txt_ids = txt_ids
+        b_txt = txt
+        b_vec = vec
+        b_t5_attn_mask = t5_attn_mask
+
     for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
-        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+        t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+        # classifier free guidance
+        if neg_txt is not None and neg_vec is not None:
+            b_img = torch.cat([img, img], dim=0)
+        else:
+            b_img = img
+
         pred = model(
-            img=img,
-            img_ids=img_ids,
-            txt=txt,
-            txt_ids=txt_ids,
-            y=vec,
+            img=b_img,
+            img_ids=b_img_ids,
+            txt=b_txt,
+            txt_ids=b_txt_ids,
+            y=b_vec,
             timesteps=t_vec,
             guidance=guidance_vec,
-            txt_attention_mask=t5_attn_mask,
+            txt_attention_mask=b_t5_attn_mask,
         )
 
+        # classifier free guidance
+        if neg_txt is not None and neg_vec is not None:
+            pred_uncond, pred = torch.chunk(pred, 2, dim=0)
+            pred = pred_uncond + cfg_scale * (pred - pred_uncond)
+
         img = img + (t_prev - t_curr) * pred
 
     return img
@@ -106,19 +141,48 @@ def do_sample(
     is_schnell: bool,
     device: torch.device,
     flux_dtype: torch.dtype,
+    neg_l_pooled: Optional[torch.Tensor] = None,
+    neg_t5_out: Optional[torch.Tensor] = None,
+    neg_t5_attn_mask: Optional[torch.Tensor] = None,
+    cfg_scale: Optional[float] = None,
 ):
+    logger.info(f"num_steps: {num_steps}")
     timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
 
     # denoise initial noise
     if accelerator:
         with accelerator.autocast(), torch.no_grad():
             x = denoise(
-                model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask
+                model,
+                img,
+                img_ids,
+                t5_out,
+                txt_ids,
+                l_pooled,
+                timesteps,
+                guidance,
+                t5_attn_mask,
+                neg_t5_out,
+                neg_l_pooled,
+                neg_t5_attn_mask,
+                cfg_scale,
             )
     else:
         with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
             x = denoise(
-                model, img, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=guidance, t5_attn_mask=t5_attn_mask
+                model,
+                img,
+                img_ids,
+                t5_out,
+                txt_ids,
+                l_pooled,
+                timesteps,
+                guidance,
+                t5_attn_mask,
+                neg_t5_out,
+                neg_l_pooled,
+                neg_t5_attn_mask,
+                cfg_scale,
             )
 
     return x
@@ -135,6 +199,8 @@ def generate_image(
     image_height: int,
     steps: Optional[int],
     guidance: float,
+    negative_prompt: Optional[str],
+    cfg_scale: float,
 ):
     seed = seed if seed is not None else random.randint(0, 2**32 - 1)
     logger.info(f"Seed: {seed}")
@@ -162,65 +228,73 @@ def generate_image(
     # txt2img only needs img_ids
     img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
 
+    # prepare fp8 models
+    if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
+        logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
+        clip_l.to(clip_l_dtype)  # fp8
+        clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
+        clip_l.fp8_prepared = True
+
+    if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared):
+        logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
+
+        def prepare_fp8(text_encoder, target_dtype):
+            def forward_hook(module):
+                def forward(hidden_states):
+                    hidden_gelu = module.act(module.wi_0(hidden_states))
+                    hidden_linear = module.wi_1(hidden_states)
+                    hidden_states = hidden_gelu * hidden_linear
+                    hidden_states = module.dropout(hidden_states)
+
+                    hidden_states = module.wo(hidden_states)
+                    return hidden_states
+
+                return forward
+
+            for module in text_encoder.modules():
+                if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
+                    # print("set", module.__class__.__name__, "to", target_dtype)
+                    module.to(target_dtype)
+                if module.__class__.__name__ in ["T5DenseGatedActDense"]:
+                    # print("set", module.__class__.__name__, "hooks")
+                    module.forward = forward_hook(module)
+
+        t5xxl.to(t5xxl_dtype)
+        prepare_fp8(t5xxl.encoder, torch.bfloat16)
+        t5xxl.fp8_prepared = True
+
     # prepare embeddings
     logger.info("Encoding prompts...")
-    tokens_and_masks = tokenize_strategy.tokenize(prompt)
     clip_l = clip_l.to(device)
     t5xxl = t5xxl.to(device)
-    with torch.no_grad():
-        if is_fp8(clip_l_dtype):
-            param_itr = clip_l.parameters()
-            param_itr.__next__()  # skip first
-            param_2nd = param_itr.__next__()
-            if param_2nd.dtype != clip_l_dtype:
-                logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
-                clip_l.to(clip_l_dtype)  # fp8
-                clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
-            with accelerator.autocast():
-                l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
 
+    def encode(prpt: str):
+        tokens_and_masks = tokenize_strategy.tokenize(prpt)
+        with torch.no_grad():
+            if is_fp8(clip_l_dtype):
+                with accelerator.autocast():
+                    l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
+            else:
+                with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
+                    l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
 
-        if is_fp8(t5xxl_dtype):
-            if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"):
-                logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
+            if is_fp8(t5xxl_dtype):
+                with accelerator.autocast():
+                    _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
+                        tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
+                    )
+            else:
+                with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
+                    _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
+                        tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
+                    )
+        return l_pooled, t5_out, txt_ids, t5_attn_mask
 
-                def prepare_fp8(text_encoder, target_dtype):
-                    def forward_hook(module):
-                        def forward(hidden_states):
-                            hidden_gelu = module.act(module.wi_0(hidden_states))
-                            hidden_linear = module.wi_1(hidden_states)
-                            hidden_states = hidden_gelu * hidden_linear
-                            hidden_states = module.dropout(hidden_states)
-
-                            hidden_states = module.wo(hidden_states)
-                            return hidden_states
-
-                        return forward
-
-                    for module in text_encoder.modules():
-                        if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
-                            # print("set", module.__class__.__name__, "to", target_dtype)
-                            module.to(target_dtype)
-                        if module.__class__.__name__ in ["T5DenseGatedActDense"]:
-                            # print("set", module.__class__.__name__, "hooks")
-                            module.forward = forward_hook(module)
-
-                    text_encoder.fp8_prepared = True
-
-                t5xxl.to(t5xxl_dtype)
-                prepare_fp8(t5xxl.encoder, torch.bfloat16)
-
-            with accelerator.autocast():
-                _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
-                    tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
-                )
-        else:
-            with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
-                l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
-            with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
-                _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
-                    tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
-                )
+    l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt)
+    if negative_prompt:
+        neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt)
+    else:
+        neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None
 
     # NaN check
     if torch.isnan(l_pooled).any():
@@ -244,7 +318,23 @@ def generate_image(
     t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
 
     x = do_sample(
-        accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, t5_attn_mask, is_schnell, device, flux_dtype
+        accelerator,
+        model,
+        noise,
+        img_ids,
+        l_pooled,
+        t5_out,
+        txt_ids,
+        steps,
+        guidance,
+        t5_attn_mask,
+        is_schnell,
+        device,
+        flux_dtype,
+        neg_l_pooled,
+        neg_t5_out,
+        neg_t5_attn_mask,
+        cfg_scale,
     )
     if args.offload:
         model = model.cpu()
@@ -307,6 +397,8 @@ if __name__ == "__main__":
     parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
     parser.add_argument("--guidance", type=float, default=3.5)
+    parser.add_argument("--negative_prompt", type=str, default=None)
+    parser.add_argument("--cfg_scale", type=float, default=1.0)
     parser.add_argument("--offload", action="store_true", help="Offload to CPU")
     parser.add_argument(
         "--lora_weights",
@@ -403,19 +495,34 @@
             lora_model.to(device)
 
         lora_models.append(lora_model)
-    
+
     if not args.interactive:
-        generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance)
+        generate_image(
+            model,
+            clip_l,
+            t5xxl,
+            ae,
+            args.prompt,
+            args.seed,
+            args.width,
+            args.height,
+            args.steps,
+            args.guidance,
+            args.negative_prompt,
+            args.cfg_scale,
+        )
     else:
        # loop for interactive
         width = target_width
         height = target_height
        steps = None
         guidance = args.guidance
+        cfg_scale = args.cfg_scale
 
         while True:
             print(
                 "Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
+                " --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
             )
             prompt = input()
             if prompt == "":
@@ -425,26 +532,36 @@
             options = prompt.split("--")
             prompt = options[0].strip()
             seed = None
+            negative_prompt = None
             for opt in options[1:]:
-                opt = opt.strip()
-                if opt.startswith("w"):
-                    width = int(opt[1:].strip())
-                elif opt.startswith("h"):
-                    height = int(opt[1:].strip())
-                elif opt.startswith("s"):
-                    steps = int(opt[1:].strip())
-                elif opt.startswith("d"):
-                    seed = int(opt[1:].strip())
-                elif opt.startswith("g"):
-                    guidance = float(opt[1:].strip())
-                elif opt.startswith("m"):
-                    mutipliers = opt[1:].strip().split(",")
-                    if len(mutipliers) != len(lora_models):
-                        logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
-                        continue
-                    for i, lora_model in enumerate(lora_models):
-                        lora_model.set_multiplier(float(mutipliers[i]))
+                try:
+                    opt = opt.strip()
+                    if opt.startswith("w"):
+                        width = int(opt[1:].strip())
+                    elif opt.startswith("h"):
+                        height = int(opt[1:].strip())
+                    elif opt.startswith("s"):
+                        steps = int(opt[1:].strip())
+                    elif opt.startswith("d"):
+                        seed = int(opt[1:].strip())
+                    elif opt.startswith("g"):
+                        guidance = float(opt[1:].strip())
+                    elif opt.startswith("m"):
+                        mutipliers = opt[1:].strip().split(",")
+                        if len(mutipliers) != len(lora_models):
+                            logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
+                            continue
+                        for i, lora_model in enumerate(lora_models):
+                            lora_model.set_multiplier(float(mutipliers[i]))
+                    elif opt.startswith("n"):
+                        negative_prompt = opt[1:].strip()
+                        if negative_prompt == "-":
+                            negative_prompt = ""
+                    elif opt.startswith("c"):
+                        cfg_scale = float(opt[1:].strip())
+                except ValueError as e:
+                    logger.error(f"Invalid option: {opt}, {e}")
 
-        generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance)
+        generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)
 
     logger.info("Done!")
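Usage note (not part of the diff): with this change a negative prompt only influences the output when `--cfg_scale` is raised above its default of `1.0`, because the patched `denoise()` blends the two predictions as `pred_uncond + cfg_scale * (pred - pred_uncond)`, which collapses to the conditional prediction at `cfg_scale = 1.0`. For example, passing `--negative_prompt "blurry, low quality" --cfg_scale 2.0` alongside the usual model, prompt, and size arguments of `flux_minimal_inference.py` enables the new path; in interactive mode the same values can be set per prompt with `--n` and `--c`. Note that supplying a negative prompt roughly doubles the transformer batch per step, since the negative and positive conditioning are run through one batched forward pass.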
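The batched classifier-free guidance step introduced in `denoise()` can be illustrated with a small standalone sketch. This is not code from the repository: the `combine_cfg` helper and the random tensor are hypothetical stand-ins for the model output and the chunk-and-blend logic shown in the diff.

```python
import torch


def combine_cfg(pred: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # pred stacks [negative, positive] predictions along the batch axis,
    # mirroring the torch.cat([neg_txt, txt], dim=0) ordering used in denoise()
    pred_uncond, pred_cond = torch.chunk(pred, 2, dim=0)
    # standard CFG blend: move from the unconditional prediction toward the
    # conditional one, scaled by cfg_scale
    return pred_uncond + cfg_scale * (pred_cond - pred_uncond)


# hypothetical model output for one image: batch of 2 = negative + positive
pred = torch.randn(2, 16, 64)
guided = combine_cfg(pred, cfg_scale=2.0)
print(guided.shape)  # torch.Size([1, 16, 64])
```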