diff --git a/README.md b/README.md
index 2495a129..3f3ecf79 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,19 @@
 This repository contains training, generation and utility scripts for Stable Diffusion.

-**January 9, 2023: Information about the update can be found at [the end of the page](#updates-jan-9-2023).**
+## Updates

-**20231/1/9: 更新情報が[ページ末尾](#更新情報-202319)にありますのでご覧ください。**
+- 15 Jan. 2023, 2023/1/15
+  - Added ``--max_train_epochs`` and ``--max_data_loader_n_workers`` options for each training script.
+    - If you specify the number of training epochs with ``--max_train_epochs``, the number of steps is calculated from the number of epochs automatically.
+    - You can set the number of workers for the DataLoader with ``--max_data_loader_n_workers`` (default is 8). A lower number may reduce main memory usage and the time between epochs, but may make data loading (and therefore training) slower.
+  - ``--max_train_epochs`` と ``--max_data_loader_n_workers`` のオプションが学習スクリプトに追加されました。
+    - ``--max_train_epochs`` で学習したいエポック数を指定すると、必要なステップ数が自動的に計算され設定されます。
+    - ``--max_data_loader_n_workers`` で DataLoader の worker 数が指定できます(デフォルトは8)。値を小さくするとメインメモリの使用量が減り、エポック間の待ち時間も短くなるようです。ただしデータ読み込み(学習時間)は長くなる可能性があります。
+
+Please read [release version 0.3.0](https://github.com/kohya-ss/sd-scripts/releases/tag/v0.3.0) for recent updates.
+最近の更新情報は [release version 0.3.0](https://github.com/kohya-ss/sd-scripts/releases/tag/v0.3.0) をご覧ください。
+
+## [日本語版README](./README-ja.md)
@@ -112,79 +123,3 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 [bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
 [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
-
-
-# Updates: Jan 9. 2023
-
-All training scripts are updated.
-
-## Breaking Changes
-
-- The ``fine_tuning`` option in ``train_db.py`` is removed. Please use DreamBooth with captions or ``fine_tune.py``.
-- The Hypernet feature in ``fine_tune.py`` is removed, will be implemented in ``train_network.py`` in future.
-
-## Features, Improvements and Bug Fixes
-
-### for all script: train_db.py, fine_tune.py and train_network.py
-
-- Added ``output_name`` option. The name of output file can be specified.
-  - With ``--output_name style1``, the output file is like ``style1_000001.ckpt`` (or ``.safetensors``) for each epoch and ``style1.ckpt`` for last.
-  - If ommitted (default), same to previous. ``epoch-000001.ckpt`` and ``last.ckpt``.
-- Added ``save_last_n_epochs`` option. Keep only latest n files for the checkpoints and the states. Older files are removed. (Thanks to shirayu!)
-  - If the options are ``--save_every_n_epochs=2 --save_last_n_epochs=3``, in the end of epoch 8, ``epoch-000008.ckpt`` is created and ``epoch-000002.ckpt`` is removed.
-
-### train_db.py
-
-- Added ``max_token_length`` option. Captions can have more than 75 tokens.
-
-### fine_tune.py
-
-- The script now works without .npz files. If .npz is not found, the scripts get the latents with VAE.
-  - You can omit ``prepare_buckets_latents.py`` in preprocessing. However, it is recommended if you train more than 1 or 2 epochs.
-  - ``--resolution`` option is required to specify the training resolution.
-- Added ``cache_latents`` and ``color_aug`` options.
-
-### train_network.py
-
-- Now ``--gradient_checkpointing`` is effective for U-Net and Text Encoder.
-  - The memory usage is reduced. The larger batch size is avilable, but the training speed will be slow.
-  - The training might be possible with 6GB VRAM for dimension=4 with batch size=1.
-
-Documents are not updated now, I will update one by one.
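The ``save_last_n_epochs`` pruning rule described in the notes above (with ``--save_every_n_epochs=2 --save_last_n_epochs=3``, finishing epoch 8 writes ``epoch-000008.ckpt`` and removes ``epoch-000002.ckpt``) can be sketched as a small helper. ``epoch_to_remove`` is a hypothetical name for illustration, not a function in this repository:

```python
def epoch_to_remove(epoch: int, save_every_n_epochs: int, save_last_n_epochs: int):
    """Return the epoch whose checkpoint should be deleted when `epoch` is saved,
    or None if nothing is old enough yet (hypothetical sketch of the pruning rule)."""
    old_epoch = epoch - save_every_n_epochs * save_last_n_epochs
    return old_epoch if old_epoch > 0 else None
```

With the options above, `epoch_to_remove(8, 2, 3)` gives 2, matching the example in the notes.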
-
-# 更新情報 (2023/1/9)
-
-学習スクリプトを更新しました。
-
-## 削除された機能
-- ``train_db.py`` の ``fine_tuning`` は削除されました。キャプション付きの DreamBooth または ``fine_tune.py`` を使ってください。
-- ``fine_tune.py`` の Hypernet学習の機能は削除されました。将来的に``train_network.py``に追加される予定です。
-
-## その他の機能追加、バグ修正など
-
-### 学習スクリプトに共通: train_db.py, fine_tune.py and train_network.py
-
-- ``output_name``オプションを追加しました。保存されるモデルファイルの名前を指定できます。
-  - ``--output_name style1``と指定すると、エポックごとに保存されるファイル名は``style1_000001.ckpt`` (または ``.safetensors``) に、最後に保存されるファイル名は``style1.ckpt``になります。
-  - 省略時は今までと同じです(``epoch-000001.ckpt``および``last.ckpt``)。
-- ``save_last_n_epochs``オプションを追加しました。最新の n ファイル、stateだけ保存し、古いものは削除します。(shirayu氏に感謝します。)
-  - たとえば``--save_every_n_epochs=2 --save_last_n_epochs=3``と指定した時、8エポック目の終了時には、``epoch-000008.ckpt``が保存され``epoch-000002.ckpt``が削除されます。
-
-### train_db.py
-
-- ``max_token_length``オプションを追加しました。75文字を超えるキャプションが使えるようになります。
-
-### fine_tune.py
-
-- .npzファイルがなくても動作するようになりました。.npzファイルがない場合、VAEからlatentsを取得して動作します。
-  - ``prepare_buckets_latents.py``を前処理で実行しなくても良くなります。ただし事前取得をしておいたほうが、2エポック以上学習する場合にはトータルで高速です。
-  - この場合、解像度を指定するために``--resolution``オプションが必要です。
-- ``cache_latents``と``color_aug``オプションを追加しました。
-
-### train_network.py
-
-- ``--gradient_checkpointing``がU-NetとText Encoderにも有効になりました。
-  - メモリ消費が減ります。バッチサイズを大きくできますが、トータルでの学習時間は長くなるかもしれません。
-  - dimension=4のLoRAはバッチサイズ1で6GB VRAMで学習できるかもしれません。
-
-ドキュメントは未更新ですが少しずつ更新の予定です。
diff --git a/fine_tune.py b/fine_tune.py
index 1a94870f..02f665bd 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -161,10 +161,15 @@ def train(args):
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
-  n_workers = min(8, os.cpu_count() - 1)      # cpu_count-1 ただし最大8
+  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)      # cpu_count-1 ただし最大で指定された数まで
   train_dataloader = torch.utils.data.DataLoader(
       train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)

+  # 学習ステップ数を計算する
+  if args.max_train_epochs is not None:
+    args.max_train_steps = args.max_train_epochs * len(train_dataloader)
+    print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+
   # lr schedulerを用意する
   lr_scheduler = diffusers.optimization.get_scheduler(
       args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index 208b1b70..4edfe0b2 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -46,11 +46,13 @@ VGG(
 )
 """
+import json
 from typing import List, Optional, Union
 import glob
 import importlib
 import inspect
 import time
+import zipfile
 from diffusers.utils import deprecate
 from diffusers.configuration_utils import FrozenDict
 import argparse
@@ -555,6 +557,7 @@ class PipelineLike():
       width: int = 512,
       num_inference_steps: int = 50,
       guidance_scale: float = 7.5,
+      negative_scale: float = None,
       strength: float = 0.8,
       # num_images_per_prompt: Optional[int] = 1,
       eta: float = 0.0,
@@ -673,6 +676,11 @@ class PipelineLike():
     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
     # corresponds to doing no classifier free guidance.
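The ``--max_train_epochs`` override added to ``fine_tune.py`` above boils down to multiplying the epoch count by the number of batches per epoch, alongside the new worker cap. A minimal sketch under those assumptions; ``resolve_train_steps`` is a hypothetical helper, and ``steps_per_epoch`` stands in for ``len(train_dataloader)``:

```python
import os

def resolve_train_steps(max_train_steps, max_train_epochs, steps_per_epoch,
                        max_data_loader_n_workers=8):
    # --max_train_epochs, when given, overrides --max_train_steps:
    # total steps = epochs * batches per epoch (len(train_dataloader))
    if max_train_epochs is not None:
        max_train_steps = max_train_epochs * steps_per_epoch
    # DataLoader workers are capped by both the option and cpu_count - 1
    n_workers = min(max_data_loader_n_workers, os.cpu_count() - 1)
    return max_train_steps, n_workers
```

For example, 10 epochs over a 250-batch dataloader resolves to 2500 steps regardless of the ``--max_train_steps`` default of 1600.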
     do_classifier_free_guidance = guidance_scale > 1.0
+
+    if not do_classifier_free_guidance and negative_scale is not None:
+      print(f"negative_scale is ignored if guidance scale <= 1.0")
+      negative_scale = None
+
     # get unconditional embeddings for classifier free guidance
     if negative_prompt is None:
       negative_prompt = [""] * batch_size
@@ -694,8 +702,21 @@ class PipelineLike():
           **kwargs,
       )

+      if negative_scale is not None:
+        _, real_uncond_embeddings, _ = get_weighted_text_embeddings(
+            pipe=self,
+            prompt=prompt,                  # こちらのトークン長に合わせてuncondを作るので75トークン超で必須
+            uncond_prompt=[""]*batch_size,
+            max_embeddings_multiples=max_embeddings_multiples,
+            clip_skip=self.clip_skip,
+            **kwargs,
+        )
+
     if do_classifier_free_guidance:
-      text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+      if negative_scale is None:
+        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+      else:
+        text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])

     # CLIP guidanceで使用するembeddingsを取得する
     if self.clip_guidance_scale > 0:
@@ -826,22 +847,28 @@ class PipelineLike():
     if accepts_eta:
       extra_step_kwargs["eta"] = eta

+    num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
     for i, t in enumerate(tqdm(timesteps)):
       # expand the latents if we are doing classifier free guidance
-      latent_model_input = latents.repeat((2, 1, 1, 1)) if do_classifier_free_guidance else latents
+      latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
       latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
       # predict the noise residual
       noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

       # perform guidance
       if do_classifier_free_guidance:
-        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+        if negative_scale is None:
+          noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input)  # uncond by negative prompt
+          noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+        else:
+          noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(num_latent_input)  # uncond is real uncond
+          noise_pred = noise_pred_uncond + guidance_scale * \
+              (noise_pred_text - noise_pred_uncond) - negative_scale * (noise_pred_negative - noise_pred_uncond)

       # perform clip guidance
       if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0:
-        text_embeddings_for_guidance = (text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings)
+        text_embeddings_for_guidance = (text_embeddings.chunk(num_latent_input)[
+            1] if do_classifier_free_guidance else text_embeddings)

         if self.clip_guidance_scale > 0:
           noise_pred, latents = self.cond_fn(latents, t, i, text_embeddings_for_guidance, noise_pred,
@@ -1972,6 +1999,14 @@ def main(args):
         if args.network_weights and i < len(args.network_weights):
           network_weight = args.network_weights[i]
           print("load network weights from:", network_weight)
+
+          if os.path.splitext(network_weight)[1] == '.safetensors':
+            from safetensors.torch import safe_open
+            with safe_open(network_weight, framework="pt") as f:
+              metadata = f.metadata()
+            if metadata is not None:
+              print(f"metadata for: {network_weight}: {metadata}")
+
           network.load_weights(network_weight)

         network.apply_to(text_encoder, unet)
@@ -2136,12 +2171,12 @@ def main(args):
       # 1st stageのバッチを作成して呼び出す
       print("process 1st stage1")
       batch_1st = []
-      for params1, (width, height, steps, scale, strength) in batch:
+      for params1, (width, height, steps, scale, negative_scale, strength) in batch:
         width_1st = int(width * args.highres_fix_scale + .5)
         height_1st = int(height * args.highres_fix_scale + .5)
         width_1st = width_1st - width_1st % 32
         height_1st = height_1st - height_1st % 32
-        batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, strength)))
+        batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
       images_1st = process_batch(batch_1st, True, True)

       # 2nd stageのバッチを作成して以下処理する
@@ -2153,7 +2188,8 @@ def main(args):
         batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
       batch = batch_2nd

-    (step_first, _, _, _, init_image, mask_image, _, guide_image), (width, height, steps, scale, strength) = batch[0]
+    (step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
+     height, steps, scale, negative_scale, strength) = batch[0]
     noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)

     prompts = []
@@ -2229,7 +2265,7 @@ def main(args):
           guide_images = guide_images[0]

     # generate
-    images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, strength, latents=start_code,
+    images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
                   output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
     if highres_1st and not args.highres_fix_save_1st:
       return images
@@ -2246,6 +2282,8 @@ def main(args):
           metadata.add_text("scale", str(scale))
           if negative_prompt is not None:
             metadata.add_text("negative-prompt", negative_prompt)
+          if negative_scale is not None:
+            metadata.add_text("negative-scale", str(negative_scale))
           if clip_prompt is not None:
             metadata.add_text("clip-prompt", clip_prompt)

@@ -2298,6 +2336,7 @@ def main(args):
   width = args.W
   height = args.H
   scale = args.scale
+  negative_scale = args.negative_scale
   steps = args.steps
   seeds = None
   strength = 0.8 if args.strength is None else args.strength
@@ -2340,6 +2379,15 @@ def main(args):
           print(f"scale: {scale}")
           continue

+        m = re.match(r'nl ([\d\.]+|none|None)', parg, re.IGNORECASE)
+        if m:  # negative scale
+          if m.group(1).lower() == 'none':
+            negative_scale = None
+          else:
+            negative_scale = float(m.group(1))
+          print(f"negative scale: {negative_scale}")
+          continue
+
         m = re.match(r't ([\d\.]+)', parg, re.IGNORECASE)
         if m:  # strength
           strength = float(m.group(1))
@@ -2402,8 +2450,9 @@ def main(args):
             print("Use previous image as guide image.")
             guide_image = prev_image

+      # TODO named tupleか何かにする
       b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
-            (width, height, steps, scale, strength))
+            (width, height, steps, scale, negative_scale, strength))
       if len(batch_data) > 0 and batch_data[-1][1] != b1[1]:  # バッチ分割必要?
         process_batch(batch_data, highres_fix)
         batch_data.clear()
@@ -2499,6 +2548,8 @@ if __name__ == '__main__':
                       help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
   parser.add_argument("--highres_fix_save_1st", action='store_true',
                       help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
+  parser.add_argument("--negative_scale", type=float, default=None,
+                      help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")

   args = parser.parse_args()
   main(args)
diff --git a/library/model_util.py b/library/model_util.py
index bc824a12..6a1e656a 100644
--- a/library/model_util.py
+++ b/library/model_util.py
@@ -632,7 +632,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
       del new_sd[ANOTHER_POSITION_IDS_KEY]
   else:
     position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
-  
+
   new_sd["text_model.embeddings.position_ids"] = position_ids
   return new_sd
@@ -886,7 +886,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
   vae = AutoencoderKL(**vae_config)
   info = vae.load_state_dict(converted_vae_checkpoint)
-  print("loadint vae:", info)
+  print("loading vae:", info)

   # convert text_model
   if v2:
@@ -1105,12 +1105,12 @@ def load_vae(vae_id, dtype):
     if vae_id.endswith(".bin"):
       # SD 1.5 VAE on Huggingface
-      vae_sd = torch.load(vae_id, map_location="cpu")
-      converted_vae_checkpoint = vae_sd
+      converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
     else:
       # StableDiffusion
-      vae_model = torch.load(vae_id, map_location="cpu")
-      vae_sd = vae_model['state_dict']
+      vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
+                   else torch.load(vae_id, map_location="cpu"))
+      vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model

       # vae only or full model
       full_model = False
@@ -1132,7 +1132,6 @@ def load_vae(vae_id, dtype):
   vae.load_state_dict(converted_vae_checkpoint)
   return vae

-
 # endregion
diff --git a/library/train_util.py b/library/train_util.py
index 378c0c29..653e83eb 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -765,6 +765,20 @@ def exists(val):
 def default(val, d):
   return val if exists(val) else d

+
+def model_hash(filename):
+  try:
+    with open(filename, "rb") as file:
+      import hashlib
+      m = hashlib.sha256()
+
+      file.seek(0x100000)
+      m.update(file.read(0x10000))
+      return m.hexdigest()[0:8]
+  except FileNotFoundError:
+    return 'NOFILE'
+
+
 # flash attention forwards and backwards
 # https://arxiv.org/abs/2205.14135
@@ -1051,6 +1065,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
   parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
   parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
+  parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
+  parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
   parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
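The two guidance branches in the ``gen_img_diffusers.py`` hunk above combine the U-Net noise predictions with plain arithmetic: standard classifier-free guidance when ``negative_scale`` is None, and otherwise a second scale applied to the negative-prompt prediction, with the real unconditional (empty-prompt) prediction as the anchor for both terms. A sketch with floats standing in for tensors; ``guided_noise`` is a hypothetical helper, not part of the patch:

```python
def guided_noise(uncond, text, guidance_scale, negative=None, negative_scale=None):
    # standard classifier-free guidance: move from uncond toward text
    if negative_scale is None:
        return uncond + guidance_scale * (text - uncond)
    # with --negative_scale: `uncond` is the real ("") unconditional prediction and
    # `negative` is the negative-prompt prediction, pushed away with its own scale
    return uncond + guidance_scale * (text - uncond) - negative_scale * (negative - uncond)
```

This is why the batch grows to three latent copies (negative, text, real uncond) instead of two when ``negative_scale`` is set.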
   parser.add_argument("--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする")
diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py
index c882e88f..0a4c3a00 100644
--- a/networks/extract_lora_from_models.py
+++ b/networks/extract_lora_from_models.py
@@ -135,7 +135,7 @@ def svd(args):
     if dir_name and not os.path.exists(dir_name):
       os.makedirs(dir_name, exist_ok=True)

-  lora_network_o.save_weights(args.save_to, save_dtype)
+  lora_network_o.save_weights(args.save_to, save_dtype, {})
   print(f"LoRA weights are saved to: {args.save_to}")
diff --git a/networks/lora.py b/networks/lora.py
index 730a6376..3f8244e0 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -92,7 +92,7 @@ class LoRANetwork(torch.nn.Module):
   def load_weights(self, file):
     if os.path.splitext(file)[1] == '.safetensors':
-      from safetensors.torch import load_file
+      from safetensors.torch import load_file, safe_open
       self.weights_sd = load_file(file)
     else:
       self.weights_sd = torch.load(file, map_location='cpu')
@@ -174,7 +174,10 @@ class LoRANetwork(torch.nn.Module):
   def get_trainable_params(self):
     return self.parameters()

-  def save_weights(self, file, dtype):
+  def save_weights(self, file, dtype, metadata):
+    if metadata is not None and len(metadata) == 0:
+      metadata = None
+
     state_dict = self.state_dict()

     if dtype is not None:
@@ -185,6 +188,6 @@ class LoRANetwork(torch.nn.Module):
     if os.path.splitext(file)[1] == '.safetensors':
       from safetensors.torch import save_file
-      save_file(state_dict, file)
+      save_file(state_dict, file, metadata)
     else:
       torch.save(state_dict, file)
diff --git a/train_db.py b/train_db.py
index 8c9cdb95..bbef3da7 100644
--- a/train_db.py
+++ b/train_db.py
@@ -134,10 +134,15 @@ def train(args):
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
-  n_workers = min(8, os.cpu_count() - 1)      # cpu_count-1 ただし最大8
+  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)      # cpu_count-1 ただし最大で指定された数まで
   train_dataloader = torch.utils.data.DataLoader(
       train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)

+  # 学習ステップ数を計算する
+  if args.max_train_epochs is not None:
+    args.max_train_steps = args.max_train_epochs * len(train_dataloader)
+    print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+
   # lr schedulerを用意する
   lr_scheduler = diffusers.optimization.get_scheduler(
       args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
diff --git a/train_network.py b/train_network.py
index 9f292b97..c0a881ad 100644
--- a/train_network.py
+++ b/train_network.py
@@ -126,10 +126,15 @@ def train(args):
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
-  n_workers = min(8, os.cpu_count() - 1)      # cpu_count-1 ただし最大8
+  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)      # cpu_count-1 ただし最大で指定された数まで
   train_dataloader = torch.utils.data.DataLoader(
       train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)

+  # 学習ステップ数を計算する
+  if args.max_train_epochs is not None:
+    args.max_train_steps = args.max_train_epochs * len(train_dataloader)
+    print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+
   # lr schedulerを用意する
   lr_scheduler = diffusers.optimization.get_scheduler(
       args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
@@ -194,9 +199,62 @@ def train(args):
   print(f"  num epochs / epoch数: {num_train_epochs}")
   print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
   print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
-  print(f"  gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+  print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
   print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")

+  metadata = {
+      "ss_learning_rate": args.learning_rate,
+      "ss_text_encoder_lr": args.text_encoder_lr,
+      "ss_unet_lr": args.unet_lr,
+      "ss_num_train_images": train_dataset.num_train_images,          # includes repeating TODO more detailed data
+      "ss_num_reg_images": train_dataset.num_reg_images,
+      "ss_num_batches_per_epoch": len(train_dataloader),
+      "ss_num_epochs": num_train_epochs,
+      "ss_batch_size_per_device": args.train_batch_size,
+      "ss_total_batch_size": total_batch_size,
+      "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
+      "ss_max_train_steps": args.max_train_steps,
+      "ss_lr_warmup_steps": args.lr_warmup_steps,
+      "ss_lr_scheduler": args.lr_scheduler,
+      "ss_network_module": args.network_module,
+      "ss_network_dim": args.network_dim,          # None means default because another network than LoRA may have another default dim
+      "ss_mixed_precision": args.mixed_precision,
+      "ss_full_fp16": bool(args.full_fp16),
+      "ss_v2": bool(args.v2),
+      "ss_resolution": args.resolution,
+      "ss_clip_skip": args.clip_skip,
+      "ss_max_token_length": args.max_token_length,
+      "ss_color_aug": bool(args.color_aug),
+      "ss_flip_aug": bool(args.flip_aug),
+      "ss_random_crop": bool(args.random_crop),
+      "ss_shuffle_caption": bool(args.shuffle_caption),
+      "ss_cache_latents": bool(args.cache_latents),
+      "ss_enable_bucket": bool(train_dataset.enable_bucket),          # TODO move to BaseDataset from DB/FT
+      "ss_min_bucket_reso": args.min_bucket_reso,                     # TODO get from dataset
+      "ss_max_bucket_reso": args.max_bucket_reso,
+      "ss_seed": args.seed
+  }
+
+  # uncomment if another network is added
+  # for key, value in net_kwargs.items():
+  #   metadata["ss_arg_" + key] = value
+
+  if args.pretrained_model_name_or_path is not None:
+    sd_model_name = args.pretrained_model_name_or_path
+    if os.path.exists(sd_model_name):
+      metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
+      sd_model_name = os.path.basename(sd_model_name)
+    metadata["ss_sd_model_name"] = sd_model_name
+
+  if args.vae is not None:
+    vae_name = args.vae
+    if os.path.exists(vae_name):
+      metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
+      vae_name = os.path.basename(vae_name)
+    metadata["ss_vae_name"] = vae_name
+
+  metadata = {k: str(v) for k, v in metadata.items()}
+
   progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
   global_step = 0
@@ -208,6 +266,7 @@ def train(args):

   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
+    metadata["ss_epoch"] = str(epoch+1)

     network.on_epoch_start(text_encoder, unet)
@@ -296,7 +355,7 @@ def train(args):
           ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
           ckpt_file = os.path.join(args.output_dir, ckpt_name)
           print(f"saving checkpoint: {ckpt_file}")
-          unwrap_model(network).save_weights(ckpt_file, save_dtype)
+          unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)

         def remove_old_func(old_epoch_no):
           old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
@@ -311,6 +370,8 @@ def train(args):

     # end of epoch

+  metadata["ss_epoch"] = str(num_train_epochs)
+
   is_main_process = accelerator.is_main_process
   if is_main_process:
     network = unwrap_model(network)
@@ -330,7 +391,7 @@ def train(args):
     ckpt_file = os.path.join(args.output_dir, ckpt_name)

     print(f"save trained model to {ckpt_file}")
-    network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
+    network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
     print("model saved.")
@@ -341,6 +402,7 @@ if __name__ == '__main__':
   train_util.add_dataset_arguments(parser, True, True)
   train_util.add_training_arguments(parser, True)

+  parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
   parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
                       help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
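The metadata plumbing added to ``train_network.py`` and ``networks/lora.py`` above relies on two conventions: safetensors metadata must be a str-to-str dict (hence the ``str(v)`` pass over all values), and ``save_weights`` treats an empty dict like None, i.e. no metadata is embedded. A minimal sketch under those assumptions; ``prepare_metadata`` is a hypothetical helper, not part of the patch:

```python
def prepare_metadata(values: dict, no_metadata: bool = False):
    """Stringify training settings for safetensors' str->str metadata;
    return None when metadata should not be embedded (hypothetical sketch)."""
    if no_metadata:
        return None  # mirrors the --no_metadata option
    metadata = {"ss_" + k: str(v) for k, v in values.items()}
    return metadata or None  # empty dict behaves like None, as in save_weights
```

For example, `prepare_metadata({"learning_rate": 1e-4})` yields `{"ss_learning_rate": "0.0001"}`, which is the string form that later shows up when `safe_open(...).metadata()` reads the file back.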