diff --git a/README.md b/README.md index 7df4f26c..14970cc0 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,33 @@ This repository contains training, generation and utility scripts for Stable Dif ## Updates -- 22 Jan. 2023, 2023/1/22 - - Fix script to check LoRA weights ``check_lora_weights.py``. Some layer weights were shown as ``0.0`` even if the layer is trained, because of the overflow of ``torch.mean``. Sorry for the confusion. - - Noe the script shows the mean of the absolute values of the weights, and the minimum of the absolute values of the weights. - - LoRAの重みをチェックするスクリプト ``check_lora_weights.py`` を修正しました。一部のレイヤーで学習されているにもかかわらず重みが ``0.0`` と表示されていました。混乱を招き申し訳ありません。 - - スクリプトを「重みの絶対の平均」と「重みの絶対値の最小値」を表示するよう修正しました。 +__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ Thank you for the great work!!! + +Note: Currently the LoRA models trained by release 0.4.0 do not seem to be supported. If you use Web UI native LoRA support, please use release 0.3.2 for now. + +The LoRA models for SD 2.x are not supported either. + +- 22 Jan. 2023 + - Add ``--network_alpha`` option to specify ``alpha`` value to prevent underflows for stable training. Thanks to CCRcmcpe! + - Details of the issue are described in https://github.com/kohya-ss/sd-webui-additional-networks/issues/49 . + - The default value is ``1``, scale ``1 / rank (or dimension)``. Set the same value as ``network_dim`` for the same behavior as the old version. + - Add logging for the learning rate for U-Net and Text Encoder independently, and for running average epoch loss. Thanks to mgz-dev! + - Add more metadata such as dataset/reg image dirs, session ID, output name etc... See #77 for details. Thanks to space-nuko! + - __Now the metadata includes the folder name (the basename of the folder containing the image files, not the full path).__ If you do not want it, disable metadata storing with ``--no_metadata`` option. + - Add ``--training_comment`` option. 
You can specify an arbitrary string and refer to it by the extension. + +Stable Diffusion web UI本体で当リポジトリで学習したLoRAモデルによる画像生成がサポートされたようです。 + +注:現時点ではversion 0.4.0で学習したモデルはサポートされないようです。Web UI本体の生成機能を使う場合には、version 0.3.2を引き続きご利用ください。またSD2.x用のLoRAモデルもサポートされないようです。 + +- 2023/1/22 + - アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定する、``--network_alpha`` オプションを追加しました。CCRcmcpe 氏に感謝します。 + - 問題の詳細はこちらをご覧ください: https://github.com/kohya-ss/sd-webui-additional-networks/issues/49 + - デフォルト値は ``1`` で、重みを ``1 / rank (dimension・次元数)`` します。``network_dim`` と同じ値を指定すると旧バージョンと同じ動作になります。 + - U-Net と Text Encoder のそれぞれの学習率、エポックの平均lossをログに記録するようになりました。mgz-dev 氏に感謝します。 + - 画像ディレクトリ、セッションID、出力名などいくつかの項目がメタデータに追加されました(詳細は #77 を参照)。space-nuko氏に感謝します。 + - __メタデータにフォルダ名が含まれるようになりました(画像を含むフォルダの名前のみで、フルパスではありません)。__ もし望まない場合には ``--no_metadata`` オプションでメタデータの記録を止めてください。 + - ``--training_comment`` オプションを追加しました。任意の文字列を指定でき、Web UI拡張から参照できます。 Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates. 最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。 diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 7b4ef2e5..19c63acf 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -1981,7 +1981,6 @@ def main(args): imported_module = importlib.import_module(network_module) network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] - network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i] net_kwargs = {} if args.network_args and i < len(args.network_args): @@ -1992,22 +1991,22 @@ def main(args): key, value = net_arg.split("=") net_kwargs[key] = value - network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs) - if network is None: - return - if args.network_weights and i < len(args.network_weights): network_weight = args.network_weights[i] print("load network weights from:", network_weight) - if 
os.path.splitext(network_weight)[1] == '.safetensors': + if model_util.is_safetensors(network_weight): from safetensors.torch import safe_open with safe_open(network_weight, framework="pt") as f: metadata = f.metadata() if metadata is not None: print(f"metadata for: {network_weight}: {metadata}") - network.load_weights(network_weight) + network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs) + else: + raise ValueError("No weight. Weight is required.") + if network is None: + return network.apply_to(text_encoder, unet) @@ -2526,8 +2525,6 @@ if __name__ == '__main__': parser.add_argument("--network_weights", type=str, default=None, nargs='*', help='Hypernetwork weights to load / Hypernetworkの重み') parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率') - parser.add_argument("--network_dim", type=int, default=None, nargs='*', - help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') parser.add_argument("--network_args", type=str, default=None, nargs='*', help='additional argmuments for network (key=value) / ネットワークへの追加の引数') parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') diff --git a/library/train_util.py b/library/train_util.py index aa65dc3c..0fdbadc1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -11,6 +11,7 @@ import glob import math import os import random +import hashlib from tqdm import tqdm import torch @@ -79,6 +80,11 @@ class BaseDataset(torch.utils.data.Dataset): self.debug_dataset = debug_dataset self.random_crop = random_crop self.token_padding_disabled = False + self.dataset_dirs_info = {} + self.reg_dataset_dirs_info = {} + self.enable_bucket = False + self.min_bucket_reso = None + self.max_bucket_reso = None self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None 
else max_token_length + 2 @@ -463,6 +469,8 @@ class DreamBoothDataset(BaseDataset): assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions( (self.width, self.height), min_bucket_reso, max_bucket_reso) + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso else: self.bucket_resos = [(self.width, self.height)] self.bucket_aspect_ratios = [self.width / self.height] @@ -523,6 +531,7 @@ class DreamBoothDataset(BaseDataset): for img_path, caption in zip(img_paths, captions): info = ImageInfo(img_path, n_repeats, caption, False, img_path) self.register_image(info) + self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)} print(f"{num_train_images} train images with repeating.") self.num_train_images = num_train_images @@ -539,6 +548,7 @@ class DreamBoothDataset(BaseDataset): for img_path, caption in zip(img_paths, captions): info = ImageInfo(img_path, n_repeats, caption, True, img_path) reg_infos.append(info) + self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)} print(f"{num_reg_images} reg images.") if num_train_images < num_reg_images: @@ -611,6 +621,8 @@ class FineTuningDataset(BaseDataset): self.num_train_images = len(metadata) * dataset_repeats self.num_reg_images = 0 + self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)} + # check existence of all npz files if not self.color_aug: npz_any = False @@ -653,6 +665,8 @@ class FineTuningDataset(BaseDataset): assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" self.bucket_resos, self.bucket_aspect_ratios = 
model_util.make_bucket_resolutions( (self.width, self.height), min_bucket_reso, max_bucket_reso) + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso else: self.bucket_resos = [(self.width, self.height)] self.bucket_aspect_ratios = [self.width / self.height] @@ -665,6 +679,9 @@ class FineTuningDataset(BaseDataset): self.bucket_resos.sort() self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos] + self.min_bucket_reso = min([min(reso) for reso in resos]) + self.max_bucket_reso = max([max(reso) for reso in resos]) + def image_key_to_npz_file(self, image_key): base_name = os.path.splitext(image_key)[0] npz_file_norm = base_name + '.npz' @@ -749,9 +766,9 @@ def default(val, d): def model_hash(filename): + """Old model hash used by stable-diffusion-webui""" try: with open(filename, "rb") as file: - import hashlib m = hashlib.sha256() file.seek(0x100000) @@ -761,6 +778,18 @@ def model_hash(filename): return 'NOFILE' +def calculate_sha256(filename): + """New model hash used by stable-diffusion-webui""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + # flash attention forwards and backwards # https://arxiv.org/abs/2205.14135 @@ -1029,7 +1058,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") - parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)") + parser.add_argument("--save_last_n_epochs_state", type=int, default=None, + help="save last N 
checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)") parser.add_argument("--save_state", action="store_true", help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") @@ -1048,8 +1078,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") - parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)") - parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)") + parser.add_argument("--max_train_epochs", type=int, default=None, + help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)") + parser.add_argument("--max_data_loader_n_workers", type=int, default=8, + help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)") parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument("--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする") diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 0a4c3a00..84d705cf 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -44,9 +44,9 @@ def 
svd(args): print(f"loading SD model : {args.model_tuned}") text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) - # create LoRA network to extract weights - lora_network_o = lora.create_network(1.0, args.dim, None, text_encoder_o, unet_o) - lora_network_t = lora.create_network(1.0, args.dim, None, text_encoder_t, unet_t) + # create LoRA network to extract weights: Use dim (rank) as alpha + lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o) + lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t) assert len(lora_network_o.text_encoder_loras) == len( lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " @@ -77,10 +77,10 @@ def svd(args): module_t = lora_t.org_module diff = module_t.weight - module_o.weight diff = diff.float() - + if args.device: diff = diff.to(args.device) - + diffs[lora_name] = diff # make LoRA with svd @@ -116,6 +116,9 @@ def svd(args): print(f"LoRA has {len(lora_sd)} weights.") for key in list(lora_sd.keys()): + if "alpha" in key: + continue + lora_name = key.split('.')[0] i = 0 if "lora_up" in key else 1 @@ -124,7 +127,7 @@ def svd(args): if len(lora_sd[key].size()) == 4: weights = weights.unsqueeze(2).unsqueeze(3) - assert weights.size() == lora_sd[key].size() + assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}" lora_sd[key] = weights # load state dict to LoRA and save it @@ -135,7 +138,10 @@ def svd(args): if dir_name and not os.path.exists(dir_name): os.makedirs(dir_name, exist_ok=True) - lora_network_o.save_weights(args.save_to, save_dtype, {}) + # minimum metadata + metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)} + + lora_network_o.save_weights(args.save_to, save_dtype, metadata) print(f"LoRA weights are saved to: {args.save_to}") @@ -151,8 +157,8 @@ if __name__ == '__main__': help="Stable 
Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors") parser.add_argument("--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") - parser.add_argument("--dim", type=int, default=4, help="dimension of LoRA (default 4) / LoRAの次元数(デフォルト4)") - parser.add_argument("--device", type=str, default=None, help="device to use, 'cuda' for GPU / 計算を行うデバイス、'cuda'でGPUを使う") + parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)") + parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") args = parser.parse_args() svd(args) diff --git a/networks/lora.py b/networks/lora.py index 3f8244e0..9243f1e1 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -13,9 +13,11 @@ class LoRAModule(torch.nn.Module): replaces forward method of the original Linear, instead of replacing the original Linear module. """ - def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4): + def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1): + """ if alpha == 0 or None, alpha is rank (no scaling). 
""" super().__init__() self.lora_name = lora_name + self.lora_dim = lora_dim if org_module.__class__.__name__ == 'Conv2d': in_dim = org_module.in_channels @@ -28,6 +30,12 @@ class LoRAModule(torch.nn.Module): self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False) self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False) + if type(alpha) == torch.Tensor: + alpha = alpha.detach().numpy() + alpha = lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える + # same as microsoft's torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) torch.nn.init.zeros_(self.lora_up.weight) @@ -41,13 +49,37 @@ class LoRAModule(torch.nn.Module): del self.org_module def forward(self, x): - return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier + return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale -def create_network(multiplier, network_dim, vae, text_encoder, unet, **kwargs): +def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs): if network_dim is None: network_dim = 4 # default - network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim) + network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha) + return network + + +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs): + if os.path.splitext(file)[1] == '.safetensors': + from safetensors.torch import load_file, safe_open + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location='cpu') + + # get dim (rank) + network_alpha = None + network_dim = None + for key, value in weights_sd.items(): + if network_alpha is None and 'alpha' in key: + network_alpha = value + if network_dim is None and 'lora_down' in key and len(value.size()) == 2: + network_dim = value.size()[0] + + if 
network_alpha is None: + network_alpha = network_dim + + network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha) + network.weights_sd = weights_sd return network @@ -57,10 +89,11 @@ class LoRANetwork(torch.nn.Module): LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_TEXT_ENCODER = 'lora_te' - def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None: + def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None: super().__init__() self.multiplier = multiplier self.lora_dim = lora_dim + self.alpha = alpha # create module instances def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]: @@ -71,7 +104,7 @@ class LoRANetwork(torch.nn.Module): if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): lora_name = prefix + '.' + name + '.' + child_name lora_name = lora_name.replace('.', '_') - lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim) + lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha) loras.append(lora) return loras @@ -149,21 +182,21 @@ class LoRANetwork(torch.nn.Module): return params self.requires_grad_(True) - params = [] + all_params = [] if self.text_encoder_loras: param_data = {'params': enumerate_params(self.text_encoder_loras)} if text_encoder_lr is not None: param_data['lr'] = text_encoder_lr - params.append(param_data) + all_params.append(param_data) if self.unet_loras: param_data = {'params': enumerate_params(self.unet_loras)} if unet_lr is not None: param_data['lr'] = unet_lr - params.append(param_data) + all_params.append(param_data) - return params + return all_params def prepare_grad_etc(self, text_encoder, unet): self.requires_grad_(True) diff --git a/networks/merge_lora.py b/networks/merge_lora.py index d873a8ef..1d4cb3b5 100644 --- a/networks/merge_lora.py +++ 
b/networks/merge_lora.py @@ -61,6 +61,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): for key in lora_sd.keys(): if "lora_down" in key: up_key = key.replace("lora_down", "lora_up") + alpha_key = key[:key.index("lora_down")] + 'alpha' # find original module for this lora module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" @@ -73,14 +74,18 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): down_weight = lora_sd[key] up_weight = lora_sd[up_key] + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + # W <- W + U * D weight = module.weight if len(weight.size()) == 2: # linear - weight = weight + ratio * (up_weight @ down_weight) + weight = weight + ratio * (up_weight @ down_weight) * scale else: # conv2d - weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale module.weight = torch.nn.Parameter(weight) @@ -88,20 +93,35 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype): def merge_lora_models(models, ratios, merge_dtype): merged_sd = {} + alpha = None + dim = None for model, ratio in zip(models, ratios): print(f"loading: {model}") lora_sd = load_state_dict(model, merge_dtype) print(f"merging...") for key in lora_sd.keys(): - if key in merged_sd: - assert merged_sd[key].size() == lora_sd[key].size( - ), f"weights shape mismatch merging v1 and v2, different dims? 
/ 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" - merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio + if 'alpha' in key: + if key in merged_sd: + assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません" + else: + alpha = lora_sd[key].detach().numpy() + merged_sd[key] = lora_sd[key] else: - merged_sd[key] = lora_sd[key] * ratio + if key in merged_sd: + assert merged_sd[key].size() == lora_sd[key].size( + ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" + merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio + else: + if "lora_down" in key: + dim = lora_sd[key].size()[0] + merged_sd[key] = lora_sd[key] * ratio - return merged_sd + print(f"dim (rank): {dim}, alpha: {alpha}") + if alpha is None: + alpha = dim + + return merged_sd, dim, alpha def merge(args): @@ -132,7 +152,7 @@ def merge(args): model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae) else: - state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) + state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) print(f"saving model to: {args.save_to}") save_to_file(args.save_to, state_dict, state_dict, save_dtype) @@ -145,7 +165,7 @@ if __name__ == '__main__': parser.add_argument("--save_precision", type=str, default=None, choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") parser.add_argument("--precision", type=str, default="float", - choices=["float", "fp16", "bf16"], help="precision in merging / マージの計算時の精度") + choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)") parser.add_argument("--sd_model", type=str, default=None, help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") 
parser.add_argument("--save_to", type=str, default=None, diff --git a/train_network.py b/train_network.py index b2c7b579..d60ae9a0 100644 --- a/train_network.py +++ b/train_network.py @@ -3,6 +3,9 @@ import argparse import gc import math import os +import random +import time +import json from tqdm import tqdm import torch @@ -18,7 +21,23 @@ def collate_fn(examples): return examples[0] +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = {"loss/current": current_loss, "loss/average": avr_loss} + + if args.network_train_unet_only: + logs["lr/unet"] = lr_scheduler.get_last_lr()[0] + elif args.network_train_text_encoder_only: + logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0] + else: + logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0] + logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder + + return logs + + def train(args): + session_id = random.randint(0, 2**32) + training_started_at = time.time() train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) @@ -88,7 +107,8 @@ def train(args): key, value = net_arg.split('=') net_kwargs[key] = value - network = network_module.create_network(1.0, args.network_dim, vae, text_encoder, unet, **net_kwargs) + # if a new network is added in future, add if ~ then blocks for each network (;'∀') + network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) if network is None: return @@ -206,21 +226,26 @@ def train(args): print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") metadata = { + "ss_session_id": session_id, # random integer indicating which group of epochs the model came from + "ss_training_started_at": training_started_at, # unix timestamp + "ss_output_name": args.output_name, "ss_learning_rate": args.learning_rate, "ss_text_encoder_lr": args.text_encoder_lr, "ss_unet_lr": args.unet_lr, - "ss_num_train_images": 
train_dataset.num_train_images, # includes repeating TODO more detailed data + "ss_num_train_images": train_dataset.num_train_images, # includes repeating "ss_num_reg_images": train_dataset.num_reg_images, "ss_num_batches_per_epoch": len(train_dataloader), "ss_num_epochs": num_train_epochs, "ss_batch_size_per_device": args.train_batch_size, "ss_total_batch_size": total_batch_size, + "ss_gradient_checkpointing": args.gradient_checkpointing, "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, "ss_max_train_steps": args.max_train_steps, "ss_lr_warmup_steps": args.lr_warmup_steps, "ss_lr_scheduler": args.lr_scheduler, "ss_network_module": args.network_module, - "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim + "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim + "ss_network_alpha": args.network_alpha, # some networks may not use this value "ss_mixed_precision": args.mixed_precision, "ss_full_fp16": bool(args.full_fp16), "ss_v2": bool(args.v2), @@ -232,10 +257,14 @@ def train(args): "ss_random_crop": bool(args.random_crop), "ss_shuffle_caption": bool(args.shuffle_caption), "ss_cache_latents": bool(args.cache_latents), - "ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT - "ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset - "ss_max_bucket_reso": args.max_bucket_reso, - "ss_seed": args.seed + "ss_enable_bucket": bool(train_dataset.enable_bucket), + "ss_min_bucket_reso": train_dataset.min_bucket_reso, + "ss_max_bucket_reso": train_dataset.max_bucket_reso, + "ss_seed": args.seed, + "ss_keep_tokens": args.keep_tokens, + "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info), + "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info), + "ss_training_comment": args.training_comment # will not be updated after training } # uncomment if another 
network is added @@ -246,6 +275,7 @@ def train(args): sd_model_name = args.pretrained_model_name_or_path if os.path.exists(sd_model_name): metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) + metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name) sd_model_name = os.path.basename(sd_model_name) metadata["ss_sd_model_name"] = sd_model_name @@ -253,6 +283,7 @@ def train(args): vae_name = args.vae if os.path.exists(vae_name): metadata["ss_vae_hash"] = train_util.model_hash(vae_name) + metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name) vae_name = os.path.basename(vae_name) metadata["ss_vae_name"] = vae_name @@ -333,20 +364,20 @@ def train(args): global_step += 1 current_loss = loss.detach().item() - if args.logging_dir is not None: - logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]} - accelerator.log(logs, step=global_step) - loss_total += current_loss avr_loss = loss_total / (step+1) logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + if global_step >= args.max_train_steps: break if args.logging_dir is not None: - logs = {"epoch_loss": loss_total / len(train_dataloader)} + logs = {"loss/epoch": loss_total / len(train_dataloader)} accelerator.log(logs, step=epoch+1) accelerator.wait_for_everyone() @@ -417,11 +448,15 @@ if __name__ == '__main__': parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール') parser.add_argument("--network_dim", type=int, default=None, help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') + parser.add_argument("--network_alpha", type=float, default=1, + help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / 
LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)') parser.add_argument("--network_args", type=str, default=None, nargs='*', help='additional argmuments for network (key=value) / ネットワークへの追加の引数') parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") parser.add_argument("--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する") + parser.add_argument("--training_comment", type=str, default=None, + help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列") args = parser.parse_args() train(args)