mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add FLUX.1 support
This commit is contained in:
19
README.md
19
README.md
@@ -11,6 +11,25 @@ The command to install PyTorch is as follows:
|
|||||||
|
|
||||||
### Recent Updates
|
### Recent Updates
|
||||||
|
|
||||||
|
Oct 19, 2024:
|
||||||
|
|
||||||
|
- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training.
|
||||||
|
- This method keeps the output of the LoRA close to the output without the LoRA applied for captions that do not contain trigger words.
|
||||||
|
- Define a dataset subset for the regularization images (`is_reg = true`) in the `.toml` dataset config, and add `custom_attributes.diff_output_preservation = true` to that subset.
|
||||||
|
- See [dataset configuration](docs/config_README-en.md) for the regularization dataset.
|
||||||
|
- Set the values so that "number of training images x number of repeats >= number of regularization images x number of repeats".
|
||||||
|
- Specify a large value for the `--prior_loss_weight` command-line option (this is not a dataset config setting). We recommend 10-1000.
|
||||||
|
- Tune `--prior_loss_weight` so that the loss when training with DOP is close to the loss when training without the regularization images.
|
||||||
|
```
|
||||||
|
[[datasets.subsets]]
|
||||||
|
image_dir = "path/to/image/dir"
|
||||||
|
num_repeats = 1
|
||||||
|
is_reg = true
|
||||||
|
custom_attributes.diff_output_preservation = true # Add this
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Oct 13, 2024:
|
Oct 13, 2024:
|
||||||
|
|
||||||
- Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large.
|
- Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large.
|
||||||
|
|||||||
@@ -373,33 +373,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
if not args.apply_t5_attn_mask:
|
if not args.apply_t5_attn_mask:
|
||||||
t5_attn_mask = None
|
t5_attn_mask = None
|
||||||
|
|
||||||
if not args.split_mode:
|
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
||||||
# normal forward
|
if not args.split_mode:
|
||||||
with accelerator.autocast():
|
# normal forward
|
||||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
with accelerator.autocast():
|
||||||
model_pred = unet(
|
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||||
img=packed_noisy_model_input,
|
model_pred = unet(
|
||||||
img_ids=img_ids,
|
img=img,
|
||||||
txt=t5_out,
|
|
||||||
txt_ids=txt_ids,
|
|
||||||
y=l_pooled,
|
|
||||||
timesteps=timesteps / 1000,
|
|
||||||
guidance=guidance_vec,
|
|
||||||
txt_attention_mask=t5_attn_mask,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# split forward to reduce memory usage
|
|
||||||
assert network.train_blocks == "single", "train_blocks must be single for split mode"
|
|
||||||
with accelerator.autocast():
|
|
||||||
# move flux lower to cpu, and then move flux upper to gpu
|
|
||||||
unet.to("cpu")
|
|
||||||
clean_memory_on_device(accelerator.device)
|
|
||||||
self.flux_upper.to(accelerator.device)
|
|
||||||
|
|
||||||
# upper model does not require grad
|
|
||||||
with torch.no_grad():
|
|
||||||
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
|
|
||||||
img=packed_noisy_model_input,
|
|
||||||
img_ids=img_ids,
|
img_ids=img_ids,
|
||||||
txt=t5_out,
|
txt=t5_out,
|
||||||
txt_ids=txt_ids,
|
txt_ids=txt_ids,
|
||||||
@@ -408,18 +388,52 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
guidance=guidance_vec,
|
guidance=guidance_vec,
|
||||||
txt_attention_mask=t5_attn_mask,
|
txt_attention_mask=t5_attn_mask,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# split forward to reduce memory usage
|
||||||
|
assert network.train_blocks == "single", "train_blocks must be single for split mode"
|
||||||
|
with accelerator.autocast():
|
||||||
|
# move flux lower to cpu, and then move flux upper to gpu
|
||||||
|
unet.to("cpu")
|
||||||
|
clean_memory_on_device(accelerator.device)
|
||||||
|
self.flux_upper.to(accelerator.device)
|
||||||
|
|
||||||
# move flux upper back to cpu, and then move flux lower to gpu
|
# upper model does not require grad
|
||||||
self.flux_upper.to("cpu")
|
with torch.no_grad():
|
||||||
clean_memory_on_device(accelerator.device)
|
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
|
||||||
unet.to(accelerator.device)
|
img=packed_noisy_model_input,
|
||||||
|
img_ids=img_ids,
|
||||||
|
txt=t5_out,
|
||||||
|
txt_ids=txt_ids,
|
||||||
|
y=l_pooled,
|
||||||
|
timesteps=timesteps / 1000,
|
||||||
|
guidance=guidance_vec,
|
||||||
|
txt_attention_mask=t5_attn_mask,
|
||||||
|
)
|
||||||
|
|
||||||
# lower model requires grad
|
# move flux upper back to cpu, and then move flux lower to gpu
|
||||||
intermediate_img.requires_grad_(True)
|
self.flux_upper.to("cpu")
|
||||||
intermediate_txt.requires_grad_(True)
|
clean_memory_on_device(accelerator.device)
|
||||||
vec.requires_grad_(True)
|
unet.to(accelerator.device)
|
||||||
pe.requires_grad_(True)
|
|
||||||
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
# lower model requires grad
|
||||||
|
intermediate_img.requires_grad_(True)
|
||||||
|
intermediate_txt.requires_grad_(True)
|
||||||
|
vec.requires_grad_(True)
|
||||||
|
pe.requires_grad_(True)
|
||||||
|
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
||||||
|
|
||||||
|
return model_pred
|
||||||
|
|
||||||
|
model_pred = call_dit(
|
||||||
|
img=packed_noisy_model_input,
|
||||||
|
img_ids=img_ids,
|
||||||
|
t5_out=t5_out,
|
||||||
|
txt_ids=txt_ids,
|
||||||
|
l_pooled=l_pooled,
|
||||||
|
timesteps=timesteps,
|
||||||
|
guidance_vec=guidance_vec,
|
||||||
|
t5_attn_mask=t5_attn_mask,
|
||||||
|
)
|
||||||
|
|
||||||
# unpack latents
|
# unpack latents
|
||||||
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
||||||
@@ -430,6 +444,37 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
# flow matching loss: this is different from SD3
|
# flow matching loss: this is different from SD3
|
||||||
target = noise - latents
|
target = noise - latents
|
||||||
|
|
||||||
|
# differential output preservation
|
||||||
|
if "custom_attributes" in batch:
|
||||||
|
diff_output_pr_indices = []
|
||||||
|
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
||||||
|
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
||||||
|
diff_output_pr_indices.append(i)
|
||||||
|
|
||||||
|
if len(diff_output_pr_indices) > 0:
|
||||||
|
network.set_multiplier(0.0)
|
||||||
|
with torch.no_grad(), accelerator.autocast():
|
||||||
|
model_pred_prior = call_dit(
|
||||||
|
img=packed_noisy_model_input[diff_output_pr_indices],
|
||||||
|
img_ids=img_ids[diff_output_pr_indices],
|
||||||
|
t5_out=t5_out[diff_output_pr_indices],
|
||||||
|
txt_ids=txt_ids[diff_output_pr_indices],
|
||||||
|
l_pooled=l_pooled[diff_output_pr_indices],
|
||||||
|
timesteps=timesteps[diff_output_pr_indices],
|
||||||
|
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
|
||||||
|
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
|
||||||
|
)
|
||||||
|
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||||
|
|
||||||
|
model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
|
||||||
|
model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
|
||||||
|
args,
|
||||||
|
model_pred_prior,
|
||||||
|
noisy_model_input[diff_output_pr_indices],
|
||||||
|
sigmas[diff_output_pr_indices] if sigmas is not None else None,
|
||||||
|
)
|
||||||
|
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||||
|
|
||||||
return model_pred, target, timesteps, None, weighting
|
return model_pred, target, timesteps, None, weighting
|
||||||
|
|
||||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||||
|
|||||||
Reference in New Issue
Block a user