mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Add network_multiplier for dataset and train LoRA
This commit is contained in:
@@ -558,6 +558,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
|
||||
max_token_length: int,
|
||||
resolution: Optional[Tuple[int, int]],
|
||||
network_multiplier: float,
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -567,6 +568,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.max_token_length = max_token_length
|
||||
# width/height is used when enable_bucket==False
|
||||
self.width, self.height = (None, None) if resolution is None else resolution
|
||||
self.network_multiplier = network_multiplier
|
||||
self.debug_dataset = debug_dataset
|
||||
|
||||
self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
|
||||
@@ -1106,7 +1108,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
||||
loss_weights.append(
|
||||
self.prior_loss_weight if image_info.is_reg else 1.0
|
||||
) # in case of fine tuning, is_reg is always False
|
||||
|
||||
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
|
||||
|
||||
@@ -1272,6 +1276,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw])
|
||||
example["flippeds"] = flippeds
|
||||
|
||||
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
|
||||
|
||||
if self.debug_dataset:
|
||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||
return example
|
||||
@@ -1346,15 +1352,16 @@ class DreamBoothDataset(BaseDataset):
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
enable_bucket: bool,
|
||||
min_bucket_reso: int,
|
||||
max_bucket_reso: int,
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
prior_loss_weight: float,
|
||||
debug_dataset,
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
|
||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||
|
||||
@@ -1520,14 +1527,15 @@ class FineTuningDataset(BaseDataset):
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
enable_bucket: bool,
|
||||
min_bucket_reso: int,
|
||||
max_bucket_reso: int,
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset,
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
|
||||
self.batch_size = batch_size
|
||||
|
||||
@@ -1724,14 +1732,15 @@ class ControlNetDataset(BaseDataset):
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
enable_bucket: bool,
|
||||
min_bucket_reso: int,
|
||||
max_bucket_reso: int,
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset,
|
||||
debug_dataset: float,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
|
||||
db_subsets = []
|
||||
for subset in subsets:
|
||||
@@ -2039,6 +2048,8 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
print(
|
||||
f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}'
|
||||
)
|
||||
if "network_multipliers" in example:
|
||||
print(f"network multiplier: {example['network_multipliers'][j]}")
|
||||
|
||||
if show_input_ids:
|
||||
print(f"input ids: {iid}")
|
||||
@@ -2105,8 +2116,8 @@ def glob_images_pathlib(dir_path, recursive):
|
||||
|
||||
|
||||
class MinimalDataset(BaseDataset):
|
||||
def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False):
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
|
||||
self.num_train_images = 0 # update in subclass
|
||||
self.num_reg_images = 0 # update in subclass
|
||||
@@ -2850,14 +2861,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
)
|
||||
parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う")
|
||||
parser.add_argument(
|
||||
"--dynamo_backend",
|
||||
type=str,
|
||||
default="inductor",
|
||||
"--dynamo_backend",
|
||||
type=str,
|
||||
default="inductor",
|
||||
# available backends:
|
||||
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
|
||||
# https://pytorch.org/docs/stable/torch.compiler.html
|
||||
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
|
||||
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)"
|
||||
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
|
||||
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)",
|
||||
)
|
||||
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||
parser.add_argument(
|
||||
@@ -2904,9 +2915,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
|
||||
) # TODO move to SDXL training, because it is not supported by SD1/2
|
||||
parser.add_argument(
|
||||
"--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う"
|
||||
)
|
||||
parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
|
||||
parser.add_argument(
|
||||
"--ddp_timeout",
|
||||
type=int,
|
||||
@@ -3889,7 +3898,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
os.environ["WANDB_DIR"] = logging_dir
|
||||
if args.wandb_api_key is not None:
|
||||
wandb.login(key=args.wandb_api_key)
|
||||
|
||||
|
||||
# torch.compile のオプション。 NO の場合は torch.compile は使わない
|
||||
dynamo_backend = "NO"
|
||||
if args.torch_compile:
|
||||
|
||||
Reference in New Issue
Block a user