From a9c5aa1f9336cedf1e294fd3c8c22bb649d51015 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 5 Jan 2025 22:28:51 +0900 Subject: [PATCH 1/7] add CFG to FLUX.1 sample image --- library/flux_train_utils.py | 152 ++++++++++++++++++++++++------------ 1 file changed, 104 insertions(+), 48 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f7f06c5c..9f954f58 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -40,7 +40,7 @@ def sample_images( text_encoders, sample_prompts_te_outputs, prompt_replacement=None, - controlnet=None + controlnet=None, ): if steps == 0: if not args.sample_at_first: @@ -101,7 +101,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, - controlnet + controlnet, ) else: # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) @@ -125,7 +125,7 @@ def sample_images( steps, sample_prompts_te_outputs, prompt_replacement, - controlnet + controlnet, ) torch.set_rng_state(rng_state) @@ -147,14 +147,14 @@ def sample_image_inference( steps, sample_prompts_te_outputs, prompt_replacement, - controlnet + controlnet, ): assert isinstance(prompt_dict, dict) - # negative_prompt = prompt_dict.get("negative_prompt") + negative_prompt = prompt_dict.get("negative_prompt") sample_steps = prompt_dict.get("sample_steps", 20) width = prompt_dict.get("width", 512) height = prompt_dict.get("height", 512) - scale = prompt_dict.get("scale", 3.5) + scale = prompt_dict.get("scale", 1.0) # 1.0 means no guidance seed = prompt_dict.get("seed") controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") @@ -162,8 +162,8 @@ def sample_image_inference( if prompt_replacement is not None: prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - # if negative_prompt is not None: - # negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) if seed is not None: torch.manual_seed(seed) @@ -173,16 +173,18 @@ def sample_image_inference( torch.seed() torch.cuda.seed() - # if negative_prompt is None: - # negative_prompt = "" + if negative_prompt is None: + negative_prompt = "" height = max(64, height - height % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16 logger.info(f"prompt: {prompt}") - # logger.info(f"negative_prompt: {negative_prompt}") + if scale != 1.0: + logger.info(f"negative_prompt: {negative_prompt}") logger.info(f"height: {height}") logger.info(f"width: {width}") logger.info(f"sample_steps: {sample_steps}") - logger.info(f"scale: {scale}") + if scale != 1.0: + logger.info(f"scale: {scale}") # logger.info(f"sample_sampler: {sampler_name}") if seed is not None: logger.info(f"seed: {seed}") @@ -191,26 +193,37 @@ def sample_image_inference( tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy() - text_encoder_conds = [] - if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs: - text_encoder_conds = sample_prompts_te_outputs[prompt] - print(f"Using cached text encoder outputs for prompt: {prompt}") - if text_encoders is not None: - print(f"Encoding prompt: {prompt}") - tokens_and_masks = tokenize_strategy.tokenize(prompt) - # strategy has apply_t5_attn_mask option - encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) + def encode_prompt(prpt): + text_encoder_conds = [] + if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs: + text_encoder_conds = sample_prompts_te_outputs[prpt] + print(f"Using cached text encoder outputs for prompt: {prpt}") + if text_encoders is not None: + print(f"Encoding prompt: {prpt}") + tokens_and_masks = tokenize_strategy.tokenize(prpt) + # strategy has apply_t5_attn_mask option + encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) - # if text_encoder_conds is not cached, use encoded_text_encoder_conds - if len(text_encoder_conds) == 0: - text_encoder_conds = encoded_text_encoder_conds - else: - # if encoded_text_encoder_conds is not None, update cached text_encoder_conds - for i in range(len(encoded_text_encoder_conds)): - if encoded_text_encoder_conds[i] is not None: - text_encoder_conds[i] = encoded_text_encoder_conds[i] + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] + return text_encoder_conds - l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt) + # encode negative prompts + if scale != 1.0: + neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt) + neg_t5_attn_mask = ( + neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None + ) + neg_cond = (scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask) + else: + neg_cond = None # sample image weight_dtype = ae.dtype # TOFO give dtype as argument @@ -235,7 +248,20 @@ def sample_image_inference( controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) with accelerator.autocast(), torch.no_grad(): - x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) + x = denoise( + flux, + noise, + img_ids, + t5_out, + txt_ids, + l_pooled, + timesteps=timesteps, + guidance=scale, + t5_attn_mask=t5_attn_mask, + controlnet=controlnet, + controlnet_img=controlnet_image, + neg_cond=neg_cond, + ) x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) @@ -305,22 +331,24 @@ def denoise( model: flux_models.Flux, img: torch.Tensor, img_ids: torch.Tensor, - txt: torch.Tensor, + txt: torch.Tensor, # t5_out txt_ids: torch.Tensor, - vec: torch.Tensor, + vec: torch.Tensor, # l_pooled timesteps: list[float], guidance: float = 4.0, t5_attn_mask: Optional[torch.Tensor] = None, controlnet: Optional[flux_models.ControlNetFlux] = None, controlnet_img: Optional[torch.Tensor] = None, + neg_cond: Optional[Tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]] = None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) - + do_cfg = neg_cond is not None for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) model.prepare_block_swap_before_forward() + if controlnet is not None: block_samples, block_single_samples = controlnet( img=img, @@ -336,20 +364,48 @@ def denoise( else: block_samples = None block_single_samples = None - pred = model( - img=img, - img_ids=img_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, - block_controlnet_hidden_states=block_samples, - block_controlnet_single_hidden_states=block_single_samples, - timesteps=t_vec, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - img = img + (t_prev - t_curr) * pred + if not do_cfg: + pred = model( + img=img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + img = img + (t_prev - t_curr) * pred + else: + cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask = neg_cond + nc_c_t5_attn_mask = None if t5_attn_mask is None else torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0) + + # TODO is it ok to use the same block samples for both cond and uncond? + block_samples = None if block_samples is None else torch.cat([block_samples, block_samples], dim=0) + block_single_samples = ( + None if block_single_samples is None else torch.cat([block_single_samples, block_single_samples], dim=0) + ) + + nc_c_pred = model( + img=torch.cat([img, img], dim=0), + img_ids=torch.cat([img_ids, img_ids], dim=0), + txt=torch.cat([neg_t5_out, txt], dim=0), + txt_ids=torch.cat([txt_ids, txt_ids], dim=0), + y=torch.cat([neg_l_pooled, vec], dim=0), + block_controlnet_hidden_states=block_samples, + block_controlnet_single_hidden_states=block_single_samples, + timesteps=t_vec, + guidance=guidance_vec, + txt_attention_mask=nc_c_t5_attn_mask, + ) + neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0) + pred = neg_pred + (pred - neg_pred) * cfg_scale + + img = img + (t_prev - t_curr) * pred model.prepare_block_swap_before_forward() return img @@ -567,7 +623,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): "--controlnet_model_name_or_path", type=str, default=None, - help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)" + help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)", ) parser.add_argument( "--t5xxl_max_token_length", From 629073cd9dd21296ca8aa97a5267d4dc7f6e5fdb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 16 Apr 2025 21:50:36 +0900 Subject: [PATCH 2/7] Add guidance scale for prompt param and flux sampling --- library/flux_train_utils.py | 10 +++++++--- library/train_util.py | 5 +++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index ce381829..d2ff347d 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -154,6 +154,7 @@ def sample_image_inference( sample_steps = prompt_dict.get("sample_steps", 20) width = prompt_dict.get("width", 512) height = prompt_dict.get("height", 512) + guidance_scale = prompt_dict.get("guidance_scale", args.guidance_scale) scale = prompt_dict.get("scale", 1.0) # 1.0 means no guidance seed = prompt_dict.get("seed") controlnet_image = prompt_dict.get("controlnet_image") @@ -180,9 +181,12 @@ def sample_image_inference( logger.info(f"prompt: {prompt}") if scale != 1.0: logger.info(f"negative_prompt: {negative_prompt}") + elif negative_prompt != "": + logger.info(f"negative prompt is ignored because scale is 1.0") logger.info(f"height: {height}") logger.info(f"width: {width}") logger.info(f"sample_steps: {sample_steps}") + logger.info(f"guidance_scale: {guidance_scale}") if scale != 1.0: logger.info(f"scale: {scale}") # logger.info(f"sample_sampler: {sampler_name}") @@ -256,7 +260,7 @@ def sample_image_inference( txt_ids, l_pooled, timesteps=timesteps, - guidance=scale, + guidance=guidance_scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image, @@ -489,7 +493,7 @@ def get_noisy_model_input_and_timesteps( sigmas = torch.randn(bsz, device=device) sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling sigmas = sigmas.sigmoid() - mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size sigmas = time_shift(mu, 1.0, sigmas) timesteps = sigmas * num_timesteps else: @@ -514,7 +518,7 @@ def get_noisy_model_input_and_timesteps( if args.ip_noise_gamma: xi = torch.randn_like(latents, device=latents.device, dtype=dtype) if args.ip_noise_gamma_random_strength: - ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma) + ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma else: ip_noise_gamma = args.ip_noise_gamma noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi) diff --git a/library/train_util.py b/library/train_util.py index 6c39f8d9..e152f30f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6178,6 +6178,11 @@ def line_to_prompt_dict(line: str) -> dict: prompt_dict["scale"] = float(m.group(1)) continue + m = re.match(r"g ([\d\.]+)", parg, re.IGNORECASE) + if m: # guidance scale + prompt_dict["guidance_scale"] = float(m.group(1)) + continue + m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt prompt_dict["negative_prompt"] = m.group(1) From 26db64be17835de5bff22bfc6d671ae1a2ffb4a4 Mon Sep 17 00:00:00 2001 From: Glen Date: Sat, 19 Apr 2025 11:54:12 -0600 Subject: [PATCH 3/7] fix: update hf_hub_download parameters to fix wd14 tagger regression --- finetune/tag_images_by_wd14_tagger.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index cbc3d2d6..f8f6ddd9 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -100,15 +100,19 @@ def main(args): else: for file in SUB_DIR_FILES: hf_hub_download( - args.repo_id, - file, + repo_id=args.repo_id, + filename=file, subfolder=SUB_DIR, - cache_dir=os.path.join(model_location, SUB_DIR), + local_dir=os.path.join(model_location, SUB_DIR), force_download=True, - force_filename=file, ) for file in files: - hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file) + hf_hub_download( + repo_id=args.repo_id, + filename=file, + local_dir=model_location, + force_download=True, + ) else: logger.info("using existing wd14 tagger model") From 8387e0b95c1067e919f91a2abec11ddcd5ed15cb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 27 Apr 2025 18:25:59 +0900 Subject: [PATCH 4/7] docs: update README to include CFG scale support in FLUX.1 training --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 2e80a697..f9831aee 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,10 @@ The command to install PyTorch is as follows: ### Recent Updates +Apr 27, 2025: +- FLUX.1 training now supports CFG scale in the sample generation during training. Please use `--g` option, to specify the CFG scale (note that `--l` is used as the embedded guidance scale.) PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064). + - See [here](#sample-image-generation-during-training) for details. + Apr 6, 2025: - IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details. - `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available. @@ -1344,11 +1348,13 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used. - * `--n` Negative prompt up to the next option. + * `--n` Negative prompt up to the next option. Ignored when CFG scale is `1.0`. * `--w` Specifies the width of the generated image. * `--h` Specifies the height of the generated image. * `--d` Specifies the seed of the generated image. * `--l` Specifies the CFG scale of the generated image. + * In guidance distillation models like FLUX.1, this value is used as the embedded guidance scale for backward compatibility. + * `--g` Specifies the CFG scale for the models with embedded guidance scale. The default is `1.0`, `1.0` means no CFG. In general, should not be changed unless you train the un-distilled FLUX.1 models. * `--s` Specifies the number of steps in the generation. The prompt weighting such as `( )` and `[ ]` are working. From f0b07c52abaf4ab33d619b427afabe17b69b7d05 Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Sun, 27 Apr 2025 21:28:38 +0900 Subject: [PATCH 5/7] Create FUNDING.yml --- .github/FUNDING.yml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .github/FUNDING.yml diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000..3b8943c3 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,3 @@ +# These are supported funding model platforms + +github: kohya-ss From fd3a445769910ddc0c8c02d13e535cac37b85d2e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 27 Apr 2025 22:50:27 +0900 Subject: [PATCH 6/7] fix: revert default emb guidance scale and CFG scale for FLUX.1 sampling --- library/flux_train_utils.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index d2ff347d..5f6867a8 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -154,8 +154,9 @@ def sample_image_inference( sample_steps = prompt_dict.get("sample_steps", 20) width = prompt_dict.get("width", 512) height = prompt_dict.get("height", 512) - guidance_scale = prompt_dict.get("guidance_scale", args.guidance_scale) - scale = prompt_dict.get("scale", 1.0) # 1.0 means no guidance + # TODO refactor variable names + cfg_scale = prompt_dict.get("guidance_scale", 1.0) + emb_guidance_scale = prompt_dict.get("scale", 3.5) seed = prompt_dict.get("seed") controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") @@ -179,16 +180,16 @@ def sample_image_inference( height = max(64, height - height % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16 logger.info(f"prompt: {prompt}") - if scale != 1.0: + if cfg_scale != 1.0: logger.info(f"negative_prompt: {negative_prompt}") elif negative_prompt != "": logger.info(f"negative prompt is ignored because scale is 1.0") logger.info(f"height: {height}") logger.info(f"width: {width}") logger.info(f"sample_steps: {sample_steps}") - logger.info(f"guidance_scale: {guidance_scale}") - if scale != 1.0: - logger.info(f"scale: {scale}") + logger.info(f"embedded guidance scale: {emb_guidance_scale}") + if cfg_scale != 1.0: + logger.info(f"CFG scale: {cfg_scale}") # logger.info(f"sample_sampler: {sampler_name}") if seed is not None: logger.info(f"seed: {seed}") @@ -220,12 +221,12 @@ def sample_image_inference( l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt) # encode negative prompts - if scale != 1.0: + if cfg_scale != 1.0: neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt) neg_t5_attn_mask = ( neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None ) - neg_cond = (scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask) + neg_cond = (cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask) else: neg_cond = None @@ -260,7 +261,7 @@ def sample_image_inference( txt_ids, l_pooled, timesteps=timesteps, - guidance=guidance_scale, + guidance=emb_guidance_scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image, From 29523c9b68bd56cdb1cce3f4985f2e45cefb1f2b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 27 Apr 2025 23:34:37 +0900 Subject: [PATCH 7/7] docs: add note for user feedback on CFG scale in FLUX.1 training --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index f9831aee..18e8e659 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ The command to install PyTorch is as follows: Apr 27, 2025: - FLUX.1 training now supports CFG scale in the sample generation during training. Please use `--g` option, to specify the CFG scale (note that `--l` is used as the embedded guidance scale.) PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064). - See [here](#sample-image-generation-during-training) for details. + - If you have any issues with this, please let us know. Apr 6, 2025: - IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details.