Merge branch 'main' into textual_inversion

This commit is contained in:
Kohya S
2023-01-15 16:08:53 +09:00
9 changed files with 182 additions and 106 deletions

View File

@@ -1,8 +1,19 @@
This repository contains training, generation and utility scripts for Stable Diffusion.
**January 9, 2023: Information about the update can be found at [the end of the page](#updates-jan-9-2023).**
## Updates
**20231/1/9: 更新情報が[ページ末尾](#更新情報-202319)にありますのでご覧ください。**
- 15 Jan. 2023, 2023/1/15
- Added ``--max_train_epochs`` and ``--max_data_loader_n_workers`` option for each training script.
- If you specify the number of training epochs with ``--max_train_epochs``, the number of steps is calculated from the number of epochs automatically.
- You can set the number of workers for DataLoader with ``--max_data_loader_n_workers``, default is 8. The lower number may reduce the main memory usage and the time between epochs, but may cause slower dataloading (training).
- ``--max_train_epochs`` と ``--max_data_loader_n_workers`` のオプションが学習スクリプトに追加されました。
- ``--max_train_epochs`` で学習したいエポック数を指定すると、必要なステップ数が自動的に計算され設定されます。
- ``--max_data_loader_n_workers`` で DataLoader の worker 数が指定できますデフォルトは8。値を小さくするとメインメモリの使用量が減り、エポック間の待ち時間も短くなるようです。ただしデータ読み込み学習時間は長くなる可能性があります。
Please read [release version 0.3.0](https://github.com/kohya-ss/sd-scripts/releases/tag/v0.3.0) for recent updates.
最近の更新情報は [release version 0.3.0](https://github.com/kohya-ss/sd-scripts/releases/tag/v0.3.0) をご覧ください。
##
[日本語版README](./README-ja.md)
@@ -112,79 +123,3 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
# Updates: Jan 9. 2023
All training scripts are updated.
## Breaking Changes
- The ``fine_tuning`` option in ``train_db.py`` is removed. Please use DreamBooth with captions or ``fine_tune.py``.
- The Hypernet feature in ``fine_tune.py`` is removed, will be implemented in ``train_network.py`` in future.
## Features, Improvements and Bug Fixes
### for all script: train_db.py, fine_tune.py and train_network.py
- Added ``output_name`` option. The name of output file can be specified.
- With ``--output_name style1``, the output file is like ``style1_000001.ckpt`` (or ``.safetensors``) for each epoch and ``style1.ckpt`` for last.
- If ommitted (default), same to previous. ``epoch-000001.ckpt`` and ``last.ckpt``.
- Added ``save_last_n_epochs`` option. Keep only latest n files for the checkpoints and the states. Older files are removed. (Thanks to shirayu!)
- If the options are ``--save_every_n_epochs=2 --save_last_n_epochs=3``, in the end of epoch 8, ``epoch-000008.ckpt`` is created and ``epoch-000002.ckpt`` is removed.
### train_db.py
- Added ``max_token_length`` option. Captions can have more than 75 tokens.
### fine_tune.py
- The script now works without .npz files. If .npz is not found, the scripts get the latents with VAE.
- You can omit ``prepare_buckets_latents.py`` in preprocessing. However, it is recommended if you train more than 1 or 2 epochs.
- ``--resolution`` option is required to specify the training resolution.
- Added ``cache_latents`` and ``color_aug`` options.
### train_network.py
- Now ``--gradient_checkpointing`` is effective for U-Net and Text Encoder.
- The memory usage is reduced. The larger batch size is avilable, but the training speed will be slow.
- The training might be possible with 6GB VRAM for dimension=4 with batch size=1.
Documents are not updated now, I will update one by one.
# 更新情報 (2023/1/9)
学習スクリプトを更新しました。
## 削除された機能
- ``train_db.py`` の ``fine_tuning`` は削除されました。キャプション付きの DreamBooth または ``fine_tune.py`` を使ってください。
- ``fine_tune.py`` の Hypernet学習の機能は削除されました。将来的に``train_network.py``に追加される予定です。
## その他の機能追加、バグ修正など
### 学習スクリプトに共通: train_db.py, fine_tune.py and train_network.py
- ``output_name``オプションを追加しました。保存されるモデルファイルの名前を指定できます。
- ``--output_name style1``と指定すると、エポックごとに保存されるファイル名は``style1_000001.ckpt`` (または ``.safetensors``) に、最後に保存されるファイル名は``style1.ckpt``になります。
- 省略時は今までと同じです(``epoch-000001.ckpt``および``last.ckpt``)。
- ``save_last_n_epochs``オプションを追加しました。最新の n ファイル、stateだけ保存し、古いものは削除します。shirayu氏に感謝します。)
- たとえば``--save_every_n_epochs=2 --save_last_n_epochs=3``と指定した時、8エポック目の終了時には、``epoch-000008.ckpt``が保存され``epoch-000002.ckpt``が削除されます。
### train_db.py
- ``max_token_length``オプションを追加しました。75文字を超えるキャプションが使えるようになります。
### fine_tune.py
- .npzファイルがなくても動作するようになりました。.npzファイルがない場合、VAEからlatentsを取得して動作します。
- ``prepare_buckets_latents.py``を前処理で実行しなくても良くなります。ただし事前取得をしておいたほうが、2エポック以上学習する場合にはトータルで高速です。
- この場合、解像度を指定するために``--resolution``オプションが必要です。
- ``cache_latents``と``color_aug``オプションを追加しました。
### train_network.py
- ``--gradient_checkpointing``がU-NetとText Encoderにも有効になりました。
- メモリ消費が減ります。バッチサイズを大きくできますが、トータルでの学習時間は長くなるかもしれません。
- dimension=4のLoRAはバッチサイズ1で6GB VRAMで学習できるかもしれません。
ドキュメントは未更新ですが少しずつ更新の予定です。

View File

@@ -161,10 +161,15 @@ def train(args):
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)

View File

@@ -46,11 +46,13 @@ VGG(
)
"""
import json
from typing import List, Optional, Union
import glob
import importlib
import inspect
import time
import zipfile
from diffusers.utils import deprecate
from diffusers.configuration_utils import FrozenDict
import argparse
@@ -555,6 +557,7 @@ class PipelineLike():
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_scale: float = None,
strength: float = 0.8,
# num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
@@ -673,6 +676,11 @@ class PipelineLike():
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
if not do_classifier_free_guidance and negative_scale is not None:
print(f"negative_scale is ignored if guidance scalle <= 1.0")
negative_scale = None
# get unconditional embeddings for classifier free guidance
if negative_prompt is None:
negative_prompt = [""] * batch_size
@@ -694,8 +702,21 @@ class PipelineLike():
**kwargs,
)
if negative_scale is not None:
_, real_uncond_embeddings, _ = get_weighted_text_embeddings(
pipe=self,
prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須
uncond_prompt=[""]*batch_size,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
**kwargs,
)
if do_classifier_free_guidance:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
if negative_scale is None:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
else:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
# CLIP guidanceで使用するembeddingsを取得する
if self.clip_guidance_scale > 0:
@@ -826,22 +847,28 @@ class PipelineLike():
if accepts_eta:
extra_step_kwargs["eta"] = eta
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
for i, t in enumerate(tqdm(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = latents.repeat((2, 1, 1, 1)) if do_classifier_free_guidance else latents
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if negative_scale is None:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
else:
noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk(num_latent_input) # uncond is real uncond
noise_pred = noise_pred_uncond + guidance_scale * \
(noise_pred_text - noise_pred_uncond) - negative_scale * (noise_pred_negative - noise_pred_uncond)
# perform clip guidance
if self.clip_guidance_scale > 0 or self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0:
text_embeddings_for_guidance = (text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings)
text_embeddings_for_guidance = (text_embeddings.chunk(num_latent_input)[
1] if do_classifier_free_guidance else text_embeddings)
if self.clip_guidance_scale > 0:
noise_pred, latents = self.cond_fn(latents, t, i, text_embeddings_for_guidance, noise_pred,
@@ -1972,6 +1999,14 @@ def main(args):
if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if os.path.splitext(network_weight)[1] == '.safetensors':
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network.load_weights(network_weight)
network.apply_to(text_encoder, unet)
@@ -2136,12 +2171,12 @@ def main(args):
# 1st stageのバッチを作成して呼び出す
print("process 1st stage1")
batch_1st = []
for params1, (width, height, steps, scale, strength) in batch:
for params1, (width, height, steps, scale, negative_scale, strength) in batch:
width_1st = int(width * args.highres_fix_scale + .5)
height_1st = int(height * args.highres_fix_scale + .5)
width_1st = width_1st - width_1st % 32
height_1st = height_1st - height_1st % 32
batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, strength)))
batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する
@@ -2153,7 +2188,8 @@ def main(args):
batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
batch = batch_2nd
(step_first, _, _, _, init_image, mask_image, _, guide_image), (width, height, steps, scale, strength) = batch[0]
(step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
height, steps, scale, negative_scale, strength) = batch[0]
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
prompts = []
@@ -2229,7 +2265,7 @@ def main(args):
guide_images = guide_images[0]
# generate
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, strength, latents=start_code,
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
if highres_1st and not args.highres_fix_save_1st:
return images
@@ -2246,6 +2282,8 @@ def main(args):
metadata.add_text("scale", str(scale))
if negative_prompt is not None:
metadata.add_text("negative-prompt", negative_prompt)
if negative_scale is not None:
metadata.add_text("negative-scale", str(negative_scale))
if clip_prompt is not None:
metadata.add_text("clip-prompt", clip_prompt)
@@ -2298,6 +2336,7 @@ def main(args):
width = args.W
height = args.H
scale = args.scale
negative_scale = args.negative_scale
steps = args.steps
seeds = None
strength = 0.8 if args.strength is None else args.strength
@@ -2340,6 +2379,15 @@ def main(args):
print(f"scale: {scale}")
continue
m = re.match(r'nl ([\d\.]+|none|None)', parg, re.IGNORECASE)
if m: # negative scale
if m.group(1).lower() == 'none':
negative_scale = None
else:
negative_scale = float(m.group(1))
print(f"negative scale: {negative_scale}")
continue
m = re.match(r't ([\d\.]+)', parg, re.IGNORECASE)
if m: # strength
strength = float(m.group(1))
@@ -2402,8 +2450,9 @@ def main(args):
print("Use previous image as guide image.")
guide_image = prev_image
# TODO named tupleか何かにする
b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
(width, height, steps, scale, strength))
(width, height, steps, scale, negative_scale, strength))
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
process_batch(batch_data, highres_fix)
batch_data.clear()
@@ -2499,6 +2548,8 @@ if __name__ == '__main__':
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数")
parser.add_argument("--highres_fix_save_1st", action='store_true',
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する")
parser.add_argument("--negative_scale", type=float, default=None,
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
args = parser.parse_args()
main(args)

View File

@@ -632,7 +632,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
del new_sd[ANOTHER_POSITION_IDS_KEY]
else:
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
new_sd["text_model.embeddings.position_ids"] = position_ids
return new_sd
@@ -886,7 +886,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
vae = AutoencoderKL(**vae_config)
info = vae.load_state_dict(converted_vae_checkpoint)
print("loadint vae:", info)
print("loading vae:", info)
# convert text_model
if v2:
@@ -1105,12 +1105,12 @@ def load_vae(vae_id, dtype):
if vae_id.endswith(".bin"):
# SD 1.5 VAE on Huggingface
vae_sd = torch.load(vae_id, map_location="cpu")
converted_vae_checkpoint = vae_sd
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
else:
# StableDiffusion
vae_model = torch.load(vae_id, map_location="cpu")
vae_sd = vae_model['state_dict']
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
else torch.load(vae_id, map_location="cpu"))
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
# vae only or full model
full_model = False
@@ -1132,7 +1132,6 @@ def load_vae(vae_id, dtype):
vae.load_state_dict(converted_vae_checkpoint)
return vae
# endregion

View File

@@ -765,6 +765,20 @@ def exists(val):
def default(val, d):
return val if exists(val) else d
def model_hash(filename):
try:
with open(filename, "rb") as file:
import hashlib
m = hashlib.sha256()
file.seek(0x100000)
m.update(file.read(0x10000))
return m.hexdigest()[0:8]
except FileNotFoundError:
return 'NOFILE'
# flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135
@@ -1051,6 +1065,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数max_train_stepsを上書きします")
parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります")
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument("--gradient_checkpointing", action="store_true",
help="enable gradient checkpointing / grandient checkpointingを有効にする")

View File

@@ -135,7 +135,7 @@ def svd(args):
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
lora_network_o.save_weights(args.save_to, save_dtype)
lora_network_o.save_weights(args.save_to, save_dtype, {})
print(f"LoRA weights are saved to: {args.save_to}")

View File

@@ -92,7 +92,7 @@ class LoRANetwork(torch.nn.Module):
def load_weights(self, file):
if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import load_file
from safetensors.torch import load_file, safe_open
self.weights_sd = load_file(file)
else:
self.weights_sd = torch.load(file, map_location='cpu')
@@ -174,7 +174,10 @@ class LoRANetwork(torch.nn.Module):
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype):
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
@@ -185,6 +188,6 @@ class LoRANetwork(torch.nn.Module):
if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import save_file
save_file(state_dict, file)
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)

View File

@@ -134,10 +134,15 @@ def train(args):
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)

View File

@@ -126,10 +126,15 @@ def train(args):
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
@@ -194,9 +199,62 @@ def train(args):
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
metadata = {
"ss_learning_rate": args.learning_rate,
"ss_text_encoder_lr": args.text_encoder_lr,
"ss_unet_lr": args.unet_lr,
"ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data
"ss_num_reg_images": train_dataset.num_reg_images,
"ss_num_batches_per_epoch": len(train_dataloader),
"ss_num_epochs": num_train_epochs,
"ss_batch_size_per_device": args.train_batch_size,
"ss_total_batch_size": total_batch_size,
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
"ss_max_train_steps": args.max_train_steps,
"ss_lr_warmup_steps": args.lr_warmup_steps,
"ss_lr_scheduler": args.lr_scheduler,
"ss_network_module": args.network_module,
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
"ss_mixed_precision": args.mixed_precision,
"ss_full_fp16": bool(args.full_fp16),
"ss_v2": bool(args.v2),
"ss_resolution": args.resolution,
"ss_clip_skip": args.clip_skip,
"ss_max_token_length": args.max_token_length,
"ss_color_aug": bool(args.color_aug),
"ss_flip_aug": bool(args.flip_aug),
"ss_random_crop": bool(args.random_crop),
"ss_shuffle_caption": bool(args.shuffle_caption),
"ss_cache_latents": bool(args.cache_latents),
"ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT
"ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset
"ss_max_bucket_reso": args.max_bucket_reso,
"ss_seed": args.seed
}
# uncomment if another network is added
# for key, value in net_kwargs.items():
# metadata["ss_arg_" + key] = value
if args.pretrained_model_name_or_path is not None:
sd_model_name = args.pretrained_model_name_or_path
if os.path.exists(sd_model_name):
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
sd_model_name = os.path.basename(sd_model_name)
metadata["ss_sd_model_name"] = sd_model_name
if args.vae is not None:
vae_name = args.vae
if os.path.exists(vae_name):
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
vae_name = os.path.basename(vae_name)
metadata["ss_vae_name"] = vae_name
metadata = {k: str(v) for k, v in metadata.items()}
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
@@ -208,6 +266,7 @@ def train(args):
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
metadata["ss_epoch"] = str(epoch+1)
network.on_epoch_start(text_encoder, unet)
@@ -296,7 +355,7 @@ def train(args):
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
unwrap_model(network).save_weights(ckpt_file, save_dtype)
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
@@ -311,6 +370,8 @@ def train(args):
# end of epoch
metadata["ss_epoch"] = str(num_train_epochs)
is_main_process = accelerator.is_main_process
if is_main_process:
network = unwrap_model(network)
@@ -330,7 +391,7 @@ def train(args):
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"save trained model to {ckpt_file}")
network.save_weights(ckpt_file, save_dtype)
network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
print("model saved.")
@@ -341,6 +402,7 @@ if __name__ == '__main__':
train_util.add_dataset_arguments(parser, True, True)
train_util.add_training_arguments(parser, True)
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt")