unify dataset and save functions

This commit is contained in:
Kohya S
2023-01-05 08:10:22 +09:00
parent 4c35006731
commit f56988b252
5 changed files with 287 additions and 1016 deletions

View File

@@ -1,6 +1,7 @@
import gc
import importlib
import json
import shutil
import time
import argparse
import math
@@ -143,8 +144,6 @@ def train(args):
if args.full_fp16:
assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enable full fp16 training.")
# unet.to(weight_dtype)
# text_encoder.to(weight_dtype)
network.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
@@ -163,10 +162,14 @@ def train(args):
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
unet.eval()
text_encoder.requires_grad_(False)
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.eval()
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
unet.train()
text_encoder.train()
else:
unet.eval()
text_encoder.eval()
network.prepare_grad_etc(text_encoder, unet)
@@ -294,9 +297,29 @@ def train(args):
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
def save_func(file):
unwrap_model(network).save_weights(file, save_dtype)
train_util.save_on_epoch_end(args, accelerator, epoch, num_train_epochs, save_func)
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
def save_func():
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
unwrap_model(network).save_weights(ckpt_file, save_dtype)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
os.remove(old_ckpt_file)
saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
if saving and args.save_state:
print("saving state.")
accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, epoch + 1)))
if remove_epoch_no is not None:
state_dir_old = os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
if os.path.exists(state_dir_old):
shutil.rmtree(state_dir_old)
# end of epoch
is_main_process = accelerator.is_main_process
if is_main_process:
@@ -305,14 +328,20 @@ def train(args):
accelerator.end_training()
if args.save_state:
train_util.save_last_state(args, accelerator)
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
def last_save_func(file):
network.save_weights(file, save_dtype)
train_util.save_last_model(args, last_save_func)
os.makedirs(args.output_dir, exist_ok=True)
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
ckpt_name = model_name + '.' + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"save trained model to {ckpt_file}")
network.save_weights(ckpt_file, save_dtype)
print("model saved.")
if __name__ == '__main__':
@@ -322,6 +351,9 @@ if __name__ == '__main__':
train_util.add_dataset_arguments(parser, True, True)
train_util.add_training_arguments(parser, True)
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt")
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")