From ef70aa7b42b5c923cc1a8594b2f30487a2b4f700 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Fri, 18 Oct 2024 23:39:48 +0900 Subject: [PATCH] add FLUX.1 support --- README.md | 19 +++++++ flux_train_network.py | 119 +++++++++++++++++++++++++++++------------- 2 files changed, 101 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 7fae50d1..59f70ebc 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,25 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 19, 2024: + +- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. + - A method to make the output of LoRA closer to the output when LoRA is not applied, with captions that do not contain trigger words. + - Define a Dataset subset for the regularization image (`is_reg = true`) with `.toml`. Add `custom_attributes.diff_output_preservation = true`. + - See [dataset configuration](docs/config_README-en.md) for the regularization dataset. + - Specify "number of training images x number of repeats >= number of regularization images x number of repeats". + - Specify a large value for the `--prior_loss_weight` option (not dataset config). We recommend 10-1000. + - Aim for the loss when training with DOP to be close to the loss when training without the regularization images. +``` +[[datasets.subsets]] +image_dir = "path/to/image/dir" +num_repeats = 1 +is_reg = true +custom_attributes.diff_output_preservation = true # Add this +``` + + + Oct 13, 2024: - Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large. 
diff --git a/flux_train_network.py b/flux_train_network.py
index aa92fe3a..8431a6dc 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -373,33 +373,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         if not args.apply_t5_attn_mask:
             t5_attn_mask = None
 
-        if not args.split_mode:
-            # normal forward
-            with accelerator.autocast():
-                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
-                model_pred = unet(
-                    img=packed_noisy_model_input,
-                    img_ids=img_ids,
-                    txt=t5_out,
-                    txt_ids=txt_ids,
-                    y=l_pooled,
-                    timesteps=timesteps / 1000,
-                    guidance=guidance_vec,
-                    txt_attention_mask=t5_attn_mask,
-                )
-        else:
-            # split forward to reduce memory usage
-            assert network.train_blocks == "single", "train_blocks must be single for split mode"
-            with accelerator.autocast():
-                # move flux lower to cpu, and then move flux upper to gpu
-                unet.to("cpu")
-                clean_memory_on_device(accelerator.device)
-                self.flux_upper.to(accelerator.device)
-
-                # upper model does not require grad
-                with torch.no_grad():
-                    intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
-                        img=packed_noisy_model_input,
+        def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
+            if not args.split_mode:
+                # normal forward
+                with accelerator.autocast():
+                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+                    model_pred = unet(
+                        img=img,
                         img_ids=img_ids,
                         txt=t5_out,
                         txt_ids=txt_ids,
@@ -408,18 +388,52 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
                         guidance=guidance_vec,
                         txt_attention_mask=t5_attn_mask,
                     )
+            else:
+                # split forward to reduce memory usage
+                assert network.train_blocks == "single", "train_blocks must be single for split mode"
+                with accelerator.autocast():
+                    # move flux lower to cpu, and then move flux upper to gpu
+                    unet.to("cpu")
+                    clean_memory_on_device(accelerator.device)
+                    self.flux_upper.to(accelerator.device)
 
-                # move flux upper back to cpu, and then move flux lower to gpu
-                self.flux_upper.to("cpu")
-                clean_memory_on_device(accelerator.device)
-                unet.to(accelerator.device)
+                    # upper model does not require grad
+                    with torch.no_grad():
+                        intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
+                            img=img,  # use the argument, not the enclosing full batch: DOP calls this helper with a sub-batch
+                            img_ids=img_ids,
+                            txt=t5_out,
+                            txt_ids=txt_ids,
+                            y=l_pooled,
+                            timesteps=timesteps / 1000,
+                            guidance=guidance_vec,
+                            txt_attention_mask=t5_attn_mask,
+                        )
 
-                # lower model requires grad
-                intermediate_img.requires_grad_(True)
-                intermediate_txt.requires_grad_(True)
-                vec.requires_grad_(True)
-                pe.requires_grad_(True)
-                model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
+                    # move flux upper back to cpu, and then move flux lower to gpu
+                    self.flux_upper.to("cpu")
+                    clean_memory_on_device(accelerator.device)
+                    unet.to(accelerator.device)
+
+                    # lower model requires grad
+                    intermediate_img.requires_grad_(True)
+                    intermediate_txt.requires_grad_(True)
+                    vec.requires_grad_(True)
+                    pe.requires_grad_(True)
+                    model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
+
+            return model_pred
+
+        model_pred = call_dit(
+            img=packed_noisy_model_input,
+            img_ids=img_ids,
+            t5_out=t5_out,
+            txt_ids=txt_ids,
+            l_pooled=l_pooled,
+            timesteps=timesteps,
+            guidance_vec=guidance_vec,
+            t5_attn_mask=t5_attn_mask,
+        )
 
         # unpack latents
         model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
@@ -430,6 +444,37 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         # flow matching loss: this is different from SD3
         target = noise - latents
 
+        # differential output preservation
+        if "custom_attributes" in batch:
+            diff_output_pr_indices = []
+            for i, custom_attributes in enumerate(batch["custom_attributes"]):
+                if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
+                    diff_output_pr_indices.append(i)
+
+            if len(diff_output_pr_indices) > 0:
+                network.set_multiplier(0.0)
+                with torch.no_grad(), accelerator.autocast():
+                    model_pred_prior = call_dit(
+                        img=packed_noisy_model_input[diff_output_pr_indices],
+                        img_ids=img_ids[diff_output_pr_indices],
+                        t5_out=t5_out[diff_output_pr_indices],
+                        txt_ids=txt_ids[diff_output_pr_indices],
+                        l_pooled=l_pooled[diff_output_pr_indices],
+                        timesteps=timesteps[diff_output_pr_indices],
+                        guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
+                        t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
+                    )
+                network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step
+
+                model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
+                model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
+                    args,
+                    model_pred_prior,
+                    noisy_model_input[diff_output_pr_indices],
+                    sigmas[diff_output_pr_indices] if sigmas is not None else None,
+                )
+                target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
+
         return model_pred, target, timesteps, None, weighting
 
     def post_process_loss(self, loss, args, timesteps, noise_scheduler):