mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
Merge pull request #1710 from kohya-ss/diff_output_prsv
Differential Output Preservation loss for LoRA
This commit is contained in:
17
README.md
17
README.md
@@ -11,6 +11,23 @@ The command to install PyTorch is as follows:
|
||||
|
||||
### Recent Updates
|
||||
|
||||
Oct 19, 2024:
|
||||
|
||||
- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. SD1/2 is not tested yet. This is an experimental feature.
|
||||
- A method to make the output of LoRA closer to the output when LoRA is not applied, with captions that do not contain trigger words.
|
||||
- Define a Dataset subset for the regularization image (`is_reg = true`) with `.toml`. Add `custom_attributes.diff_output_preservation = true`.
|
||||
- See [dataset configuration](docs/config_README-en.md) for the regularization dataset.
|
||||
- Specify "number of training images x number of repeats >= number of regularization images x number of repeats".
|
||||
- Specify a large value for `--prior_loss_weight` option (not dataset config). The appropriate value is unknown, but try around 10-100. Note that the default is 1.0.
|
||||
- You may want to start with 2/3 to 3/4 of the loss value when DOP is not applied. If it is 1/2, DOP may not be working.
|
||||
```
|
||||
[[datasets.subsets]]
|
||||
image_dir = "path/to/image/dir"
|
||||
num_repeats = 1
|
||||
is_reg = true
|
||||
custom_attributes.diff_output_preservation = true # Add this
|
||||
```
|
||||
|
||||
Oct 13, 2024:
|
||||
|
||||
- Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large.
|
||||
|
||||
@@ -373,33 +373,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
if not args.apply_t5_attn_mask:
|
||||
t5_attn_mask = None
|
||||
|
||||
if not args.split_mode:
|
||||
# normal forward
|
||||
with accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = unet(
|
||||
img=packed_noisy_model_input,
|
||||
img_ids=img_ids,
|
||||
txt=t5_out,
|
||||
txt_ids=txt_ids,
|
||||
y=l_pooled,
|
||||
timesteps=timesteps / 1000,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
else:
|
||||
# split forward to reduce memory usage
|
||||
assert network.train_blocks == "single", "train_blocks must be single for split mode"
|
||||
with accelerator.autocast():
|
||||
# move flux lower to cpu, and then move flux upper to gpu
|
||||
unet.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
self.flux_upper.to(accelerator.device)
|
||||
|
||||
# upper model does not require grad
|
||||
with torch.no_grad():
|
||||
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
|
||||
img=packed_noisy_model_input,
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
||||
if not args.split_mode:
|
||||
# normal forward
|
||||
with accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = unet(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=t5_out,
|
||||
txt_ids=txt_ids,
|
||||
@@ -408,18 +388,52 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
else:
|
||||
# split forward to reduce memory usage
|
||||
assert network.train_blocks == "single", "train_blocks must be single for split mode"
|
||||
with accelerator.autocast():
|
||||
# move flux lower to cpu, and then move flux upper to gpu
|
||||
unet.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
self.flux_upper.to(accelerator.device)
|
||||
|
||||
# move flux upper back to cpu, and then move flux lower to gpu
|
||||
self.flux_upper.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
# upper model does not require grad
|
||||
with torch.no_grad():
|
||||
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
|
||||
img=packed_noisy_model_input,
|
||||
img_ids=img_ids,
|
||||
txt=t5_out,
|
||||
txt_ids=txt_ids,
|
||||
y=l_pooled,
|
||||
timesteps=timesteps / 1000,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
# lower model requires grad
|
||||
intermediate_img.requires_grad_(True)
|
||||
intermediate_txt.requires_grad_(True)
|
||||
vec.requires_grad_(True)
|
||||
pe.requires_grad_(True)
|
||||
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
||||
# move flux upper back to cpu, and then move flux lower to gpu
|
||||
self.flux_upper.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
|
||||
# lower model requires grad
|
||||
intermediate_img.requires_grad_(True)
|
||||
intermediate_txt.requires_grad_(True)
|
||||
vec.requires_grad_(True)
|
||||
pe.requires_grad_(True)
|
||||
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
||||
|
||||
return model_pred
|
||||
|
||||
model_pred = call_dit(
|
||||
img=packed_noisy_model_input,
|
||||
img_ids=img_ids,
|
||||
t5_out=t5_out,
|
||||
txt_ids=txt_ids,
|
||||
l_pooled=l_pooled,
|
||||
timesteps=timesteps,
|
||||
guidance_vec=guidance_vec,
|
||||
t5_attn_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
# unpack latents
|
||||
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
||||
@@ -430,6 +444,37 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
# flow matching loss: this is different from SD3
|
||||
target = noise - latents
|
||||
|
||||
# differential output preservation
|
||||
if "custom_attributes" in batch:
|
||||
diff_output_pr_indices = []
|
||||
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
||||
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
||||
diff_output_pr_indices.append(i)
|
||||
|
||||
if len(diff_output_pr_indices) > 0:
|
||||
network.set_multiplier(0.0)
|
||||
with torch.no_grad():
|
||||
model_pred_prior = call_dit(
|
||||
img=packed_noisy_model_input[diff_output_pr_indices],
|
||||
img_ids=img_ids[diff_output_pr_indices],
|
||||
t5_out=t5_out[diff_output_pr_indices],
|
||||
txt_ids=txt_ids[diff_output_pr_indices],
|
||||
l_pooled=l_pooled[diff_output_pr_indices],
|
||||
timesteps=timesteps[diff_output_pr_indices],
|
||||
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
|
||||
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
|
||||
)
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
|
||||
model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
|
||||
model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
|
||||
args,
|
||||
model_pred_prior,
|
||||
noisy_model_input[diff_output_pr_indices],
|
||||
sigmas[diff_output_pr_indices] if sigmas is not None else None,
|
||||
)
|
||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||
|
||||
return model_pred, target, timesteps, None, weighting
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
|
||||
@@ -10,13 +10,7 @@ import json
|
||||
from pathlib import Path
|
||||
|
||||
# from toolz import curry
|
||||
from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import toml
|
||||
import voluptuous
|
||||
@@ -78,6 +72,7 @@ class BaseSubsetParams:
|
||||
caption_tag_dropout_rate: float = 0.0
|
||||
token_warmup_min: int = 1
|
||||
token_warmup_step: float = 0
|
||||
custom_attributes: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -197,6 +192,7 @@ class ConfigSanitizer:
|
||||
"token_warmup_step": Any(float, int),
|
||||
"caption_prefix": str,
|
||||
"caption_suffix": str,
|
||||
"custom_attributes": dict,
|
||||
}
|
||||
# DO means DropOut
|
||||
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
||||
@@ -538,9 +534,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
flip_aug: {subset.flip_aug}
|
||||
face_crop_aug_range: {subset.face_crop_aug_range}
|
||||
random_crop: {subset.random_crop}
|
||||
token_warmup_min: {subset.token_warmup_min},
|
||||
token_warmup_step: {subset.token_warmup_step},
|
||||
alpha_mask: {subset.alpha_mask},
|
||||
token_warmup_min: {subset.token_warmup_min}
|
||||
token_warmup_step: {subset.token_warmup_step}
|
||||
alpha_mask: {subset.alpha_mask}
|
||||
custom_attributes: {subset.custom_attributes}
|
||||
"""
|
||||
),
|
||||
" ",
|
||||
|
||||
@@ -396,6 +396,7 @@ class BaseSubset:
|
||||
caption_suffix: Optional[str],
|
||||
token_warmup_min: int,
|
||||
token_warmup_step: Union[float, int],
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.image_dir = image_dir
|
||||
self.alpha_mask = alpha_mask if alpha_mask is not None else False
|
||||
@@ -419,6 +420,8 @@ class BaseSubset:
|
||||
self.token_warmup_min = token_warmup_min # step=0におけるタグの数
|
||||
self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる
|
||||
|
||||
self.custom_attributes = custom_attributes if custom_attributes is not None else {}
|
||||
|
||||
self.img_count = 0
|
||||
|
||||
|
||||
@@ -449,6 +452,7 @@ class DreamBoothSubset(BaseSubset):
|
||||
caption_suffix,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
@@ -473,6 +477,7 @@ class DreamBoothSubset(BaseSubset):
|
||||
caption_suffix,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
custom_attributes=custom_attributes,
|
||||
)
|
||||
|
||||
self.is_reg = is_reg
|
||||
@@ -512,6 +517,7 @@ class FineTuningSubset(BaseSubset):
|
||||
caption_suffix,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
||||
|
||||
@@ -536,6 +542,7 @@ class FineTuningSubset(BaseSubset):
|
||||
caption_suffix,
|
||||
token_warmup_min,
|
||||
token_warmup_step,
|
||||
custom_attributes=custom_attributes,
|
||||
)
|
||||
|
||||
self.metadata_file = metadata_file
|
||||
@@ -1474,11 +1481,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
target_sizes_hw = []
|
||||
flippeds = [] # 変数名が微妙
|
||||
text_encoder_outputs_list = []
|
||||
custom_attributes = []
|
||||
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
|
||||
custom_attributes.append(subset.custom_attributes)
|
||||
|
||||
# in case of fine tuning, is_reg is always False
|
||||
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
||||
|
||||
@@ -1646,7 +1656,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
return None
|
||||
return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))]
|
||||
|
||||
# set example
|
||||
example = {}
|
||||
example["custom_attributes"] = custom_attributes # may be list of empty dict
|
||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
||||
example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor)
|
||||
example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x)
|
||||
@@ -2630,7 +2642,9 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}'
|
||||
)
|
||||
if "network_multipliers" in example:
|
||||
print(f"network multiplier: {example['network_multipliers'][j]}")
|
||||
logger.info(f"network multiplier: {example['network_multipliers'][j]}")
|
||||
if "custom_attributes" in example:
|
||||
logger.info(f"custom attributes: {example['custom_attributes'][j]}")
|
||||
|
||||
# if show_input_ids:
|
||||
# logger.info(f"input ids: {iid}")
|
||||
@@ -4091,6 +4105,7 @@ def enable_high_vram(args: argparse.Namespace):
|
||||
global HIGH_VRAM
|
||||
HIGH_VRAM = True
|
||||
|
||||
|
||||
def verify_training_args(args: argparse.Namespace):
|
||||
r"""
|
||||
Verify training arguments. Also reflect highvram option to global variable
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
@@ -172,7 +173,18 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
||||
|
||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||
def call_unet(
|
||||
self,
|
||||
args,
|
||||
accelerator,
|
||||
unet,
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
text_conds,
|
||||
batch,
|
||||
weight_dtype,
|
||||
indices: Optional[List[int]] = None,
|
||||
):
|
||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
# get size embeddings
|
||||
@@ -186,6 +198,12 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
||||
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
||||
|
||||
if indices is not None and len(indices) > 0:
|
||||
noisy_latents = noisy_latents[indices]
|
||||
timesteps = timesteps[indices]
|
||||
text_embedding = text_embedding[indices]
|
||||
vector_embedding = vector_embedding[indices]
|
||||
|
||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
return noise_pred
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ class NetworkTrainer:
|
||||
for t_enc in text_encoders:
|
||||
t_enc.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs):
|
||||
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
|
||||
return noise_pred
|
||||
|
||||
@@ -218,6 +218,30 @@ class NetworkTrainer:
|
||||
else:
|
||||
target = noise
|
||||
|
||||
# differential output preservation
|
||||
if "custom_attributes" in batch:
|
||||
diff_output_pr_indices = []
|
||||
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
||||
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
||||
diff_output_pr_indices.append(i)
|
||||
|
||||
if len(diff_output_pr_indices) > 0:
|
||||
network.set_multiplier(0.0)
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
noise_pred_prior = self.call_unet(
|
||||
args,
|
||||
accelerator,
|
||||
unet,
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
text_encoder_conds,
|
||||
batch,
|
||||
weight_dtype,
|
||||
indices=diff_output_pr_indices,
|
||||
)
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
||||
|
||||
return noise_pred, target, timesteps, huber_c, None
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
@@ -1123,15 +1147,6 @@ class NetworkTrainer:
|
||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
# # SD only
|
||||
# encoded_text_encoder_conds = get_weighted_text_embeddings(
|
||||
# tokenizers[0],
|
||||
# text_encoder,
|
||||
# batch["captions"],
|
||||
# accelerator.device,
|
||||
# args.max_token_length // 75 if args.max_token_length else 1,
|
||||
# clip_skip=args.clip_skip,
|
||||
# )
|
||||
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
|
||||
tokenize_strategy,
|
||||
|
||||
Reference in New Issue
Block a user