Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 22:35:09 +00:00
Diff Output Preservation loss for SDXL
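This commit threads a per-subset custom_attributes dictionary through the dataset classes so that training scripts can attach arbitrary per-sample flags, such as the one used for diff output preservation. Conceptually, a diff-output-preservation style loss keeps the adapted model's prediction on designated (typically regularization) images close to the prediction of the unmodified base model, instead of matching the usual noise target, so the base model's behavior is preserved while the training images are learned. The sketch below only illustrates that idea under assumed names (network.set_multiplier, dop_weight, preserve_mask); it is not the implementation added by this commit, whose diff covers just the custom_attributes plumbing.

import torch
import torch.nn.functional as F

def dop_style_loss(unet, network, noisy_latents, timesteps, text_conds, noise, preserve_mask, dop_weight=1.0):
    # Prediction with the trainable network (e.g. LoRA) active.
    pred = unet(noisy_latents, timesteps, text_conds)

    # Prediction of the unmodified base model, taken with the network switched off.
    with torch.no_grad():
        network.set_multiplier(0.0)  # assumed toggle API for disabling the LoRA
        base_pred = unet(noisy_latents, timesteps, text_conds)
        network.set_multiplier(1.0)

    loss = torch.zeros((), device=pred.device)
    if (~preserve_mask).any():
        # Ordinary denoising loss on the training images.
        loss = loss + F.mse_loss(pred[~preserve_mask], noise[~preserve_mask])
    if preserve_mask.any():
        # Preservation term: stay close to the base model's output on flagged images.
        loss = loss + dop_weight * F.mse_loss(pred[preserve_mask], base_pred[preserve_mask])
    return loss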
@@ -396,6 +396,7 @@ class BaseSubset:
         caption_suffix: Optional[str],
         token_warmup_min: int,
         token_warmup_step: Union[float, int],
+        custom_attributes: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.image_dir = image_dir
         self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -419,6 +420,8 @@ class BaseSubset:
         self.token_warmup_min = token_warmup_min # step=0におけるタグの数
         self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる
 
+        self.custom_attributes = custom_attributes if custom_attributes is not None else {}
+
         self.img_count = 0
 
 
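BaseSubset now stores the dictionary as-is and falls back to an empty dict, so every subset exposes a custom_attributes mapping that downstream code can query with .get(). In practice the values would come from the per-subset section of the dataset configuration; the snippet below only illustrates the resulting objects, and the diff_output_preservation key name is an assumption, not something shown in this diff.

# What self.custom_attributes ends up holding (illustration only).
attrs_from_config = None
custom_attributes = attrs_from_config if attrs_from_config is not None else {}
print(custom_attributes.get("diff_output_preservation", False))  # False: key absent, safe default

# A subset whose config carries custom attributes would instead hold something like:
custom_attributes = {"diff_output_preservation": True}  # key name is an assumed example
print(custom_attributes.get("diff_output_preservation", False))  # True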
@@ -449,6 +452,7 @@ class DreamBoothSubset(BaseSubset):
         caption_suffix,
         token_warmup_min,
         token_warmup_step,
+        custom_attributes: Optional[Dict[str, Any]] = None,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
@@ -473,6 +477,7 @@ class DreamBoothSubset(BaseSubset):
             caption_suffix,
             token_warmup_min,
             token_warmup_step,
+            custom_attributes=custom_attributes,
         )
 
         self.is_reg = is_reg
@@ -512,6 +517,7 @@ class FineTuningSubset(BaseSubset):
         caption_suffix,
         token_warmup_min,
         token_warmup_step,
+        custom_attributes: Optional[Dict[str, Any]] = None,
     ) -> None:
         assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
 
@@ -536,6 +542,7 @@ class FineTuningSubset(BaseSubset):
             caption_suffix,
             token_warmup_min,
             token_warmup_step,
+            custom_attributes=custom_attributes,
         )
 
         self.metadata_file = metadata_file
@@ -1474,11 +1481,14 @@ class BaseDataset(torch.utils.data.Dataset):
         target_sizes_hw = []
         flippeds = []  # 変数名が微妙
         text_encoder_outputs_list = []
+        custom_attributes = []
 
         for image_key in bucket[image_index : image_index + bucket_batch_size]:
             image_info = self.image_data[image_key]
             subset = self.image_to_subset[image_key]
 
+            custom_attributes.append(subset.custom_attributes)
+
             # in case of fine tuning, is_reg is always False
             loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
 
@@ -1646,7 +1656,9 @@ class BaseDataset(torch.utils.data.Dataset):
                 return None
             return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))]
 
         # set example
         example = {}
+        example["custom_attributes"] = custom_attributes  # may be list of empty dict
         example["loss_weights"] = torch.FloatTensor(loss_weights)
         example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor)
         example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x)
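With the two hunks above, every batch example carries one attribute dict per sample, aligned with the images drawn from the bucket, and the list may consist entirely of empty dicts when no subset defines attributes. A training step can therefore decide per sample how to treat the loss. The helper below is a hedged sketch of reading that list; the diff_output_preservation key is again an assumed attribute name rather than something defined in this diff.

import torch

def preservation_mask_from_batch(batch):
    # One dict per sample, possibly empty, exactly as assembled in __getitem__ above.
    attrs_list = batch["custom_attributes"]
    flags = [bool(attrs.get("diff_output_preservation", False)) for attrs in attrs_list]
    return torch.tensor(flags, dtype=torch.bool)

# Usage inside a training step (illustrative): samples where the mask is True would
# use a preservation target instead of the ordinary noise target.
# mask = preservation_mask_from_batch(batch)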
@@ -2630,7 +2642,9 @@ def debug_dataset(train_dataset, show_input_ids=False):
                 f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}'
             )
             if "network_multipliers" in example:
-                print(f"network multiplier: {example['network_multipliers'][j]}")
+                logger.info(f"network multiplier: {example['network_multipliers'][j]}")
+            if "custom_attributes" in example:
+                logger.info(f"custom attributes: {example['custom_attributes'][j]}")
 
             # if show_input_ids:
             #     logger.info(f"input ids: {iid}")
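Because debug_dataset now logs the per-sample attributes, running a training script's existing dataset debug path (for example sdxl_train_network.py --dataset_config config.toml --debug_dataset, with config.toml standing in for your own dataset config) prints a "custom attributes: {...}" line for each sample, which is a quick way to confirm that attributes defined in the config actually reach the subsets.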
@@ -4091,6 +4105,7 @@ def enable_high_vram(args: argparse.Namespace):
     global HIGH_VRAM
     HIGH_VRAM = True
 
 
 def verify_training_args(args: argparse.Namespace):
     r"""
     Verify training arguments. Also reflect highvram option to global variable