diff --git a/README.md b/README.md index 663f52e8..c1043783 100644 --- a/README.md +++ b/README.md @@ -257,11 +257,42 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version. - The sample image generation during training consumes a lot of memory. It is recommended to turn it off. +- [Experimental] The network multiplier can be specified for each dataset in the training scripts for LoRA etc. + - This is an experimental option and may be removed or changed in the future. + - For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate. + - Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate. + - Please specify `network_multiplier` in `[[datasets]]` in `.toml` file. 
+ - (実験的) LoRA等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。 - `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。 - PyTorch 2.1 以降が必要です。 - PyTorch 2.1 で xformers を使用する場合は、[xformers のリポジトリ](https://github.com/facebookresearch/xformers) を参照し、CUDA バージョンに応じて適切なバージョンをインストールしてください。 - 学習中のサンプル画像生成はメモリを大量に消費するため、オフにすることをお勧めします。 +- (実験的) LoRA 等の学習で、データセットごとに異なるネットワーク適用率を指定できるようになりました。 + - 実験的オプションのため、将来的に削除または仕様変更される可能性があります。 + - たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。 + - また、五段階の状態を用意し、それぞれ `0.2`、`0.4`、`0.6`、`0.8`、`1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。 + - `.toml` ファイルで `[[datasets]]` に `network_multiplier` を指定してください。 + +- `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例 + +```toml +[general] +[[datasets]] +resolution = 512 +batch_size = 8 +network_multiplier = 1.0 + +... subset settings ... + +[[datasets]] +resolution = 512 +batch_size = 8 +network_multiplier = -1.0 + +... subset settings ... 
+``` + ### Jan 17, 2024 / 2024/1/17: v0.8.1 diff --git a/library/config_util.py b/library/config_util.py index 716cecff..a98c2b90 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -93,6 +93,7 @@ class BaseDatasetParams: tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None max_token_length: int = None resolution: Optional[Tuple[int, int]] = None + network_multiplier: float = 1.0 debug_dataset: bool = False @@ -219,6 +220,7 @@ class ConfigSanitizer: "max_bucket_reso": int, "min_bucket_reso": int, "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), + "network_multiplier": float, } # options handled by argparse but not handled by user config @@ -469,6 +471,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} enable_bucket: {dataset.enable_bucket} + network_multiplier: {dataset.network_multiplier} """ ) diff --git a/library/train_util.py b/library/train_util.py index 21e7638d..4ac6728b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -558,6 +558,7 @@ class BaseDataset(torch.utils.data.Dataset): tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]], max_token_length: int, resolution: Optional[Tuple[int, int]], + network_multiplier: float, debug_dataset: bool, ) -> None: super().__init__() @@ -567,6 +568,7 @@ class BaseDataset(torch.utils.data.Dataset): self.max_token_length = max_token_length # width/height is used when enable_bucket==False self.width, self.height = (None, None) if resolution is None else resolution + self.network_multiplier = network_multiplier self.debug_dataset = debug_dataset self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = [] @@ -1106,7 +1108,9 @@ class BaseDataset(torch.utils.data.Dataset): for image_key in bucket[image_index : image_index + bucket_batch_size]: image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] - 
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + loss_weights.append( + self.prior_loss_weight if image_info.is_reg else 1.0 + ) # in case of fine tuning, is_reg is always False flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance @@ -1272,6 +1276,8 @@ class BaseDataset(torch.utils.data.Dataset): example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw]) example["flippeds"] = flippeds + example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) + if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] return example @@ -1346,15 +1352,16 @@ class DreamBoothDataset(BaseDataset): tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, prior_loss_weight: float, - debug_dataset, + debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -1520,14 +1527,15 @@ class FineTuningDataset(BaseDataset): tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, bucket_no_upscale: bool, - debug_dataset, + debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) self.batch_size = batch_size @@ -1724,14 +1732,15 @@ class ControlNetDataset(BaseDataset): tokenizer, max_token_length, resolution, + network_multiplier: float, enable_bucket: bool, min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int, 
bucket_no_upscale: bool, - debug_dataset, + debug_dataset: bool, ) -> None: - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) db_subsets = [] for subset in subsets: @@ -2039,6 +2048,8 @@ def debug_dataset(train_dataset, show_input_ids=False): print( f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' ) + if "network_multipliers" in example: + print(f"network multiplier: {example['network_multipliers'][j]}") if show_input_ids: print(f"input ids: {iid}") @@ -2105,8 +2116,8 @@ def glob_images_pathlib(dir_path, recursive): class MinimalDataset(BaseDataset): - def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False): - super().__init__(tokenizer, max_token_length, resolution, debug_dataset) + def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False): + super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset) self.num_train_images = 0 # update in subclass self.num_reg_images = 0 # update in subclass @@ -2850,14 +2861,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ) parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う") parser.add_argument( - "--dynamo_backend", - type=str, - default="inductor", + "--dynamo_backend", + type=str, + default="inductor", # available backends: # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 # https://pytorch.org/docs/stable/torch.compiler.html - choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], - help="dynamo backend type 
(default is inductor) / dynamoのbackendの種類(デフォルトは inductor)" + choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], + help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") parser.add_argument( @@ -2904,9 +2915,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する" ) # TODO move to SDXL training, because it is not supported by SD1/2 - parser.add_argument( - "--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う" - ) + parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う") parser.add_argument( "--ddp_timeout", type=int, @@ -3889,7 +3898,7 @@ def prepare_accelerator(args: argparse.Namespace): os.environ["WANDB_DIR"] = logging_dir if args.wandb_api_key is not None: wandb.login(key=args.wandb_api_key) - + # torch.compile のオプション。 NO の場合は torch.compile は使わない dynamo_backend = "NO" if args.torch_compile: diff --git a/train_network.py b/train_network.py index b1291ed1..ef7d4197 100644 --- a/train_network.py +++ b/train_network.py @@ -310,6 +310,7 @@ class NetworkTrainer: ) if network is None: return + network_has_multiplier = hasattr(network, "set_multiplier") if hasattr(network, "prepare_network"): network.prepare_network(args) @@ -768,7 +769,17 @@ class NetworkTrainer: accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) latents = latents * self.vae_scale_factor - b_size = latents.shape[0] + + # get multiplier for each sample + if network_has_multiplier: + multipliers = batch["network_multipliers"] + # if all multipliers are same, use single multiplier + if 
torch.all(multipliers == multipliers[0]): + multipliers = multipliers[0].item() + else: + raise NotImplementedError("multipliers for each sample is not supported yet") + # print(f"set multiplier: {multipliers}") + network.set_multiplier(multipliers) with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning