mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
unify dataset and save functions
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import gc
|
||||
import importlib
|
||||
import json
|
||||
import shutil
|
||||
import time
|
||||
import argparse
|
||||
import math
|
||||
@@ -143,8 +144,6 @@ def train(args):
|
||||
if args.full_fp16:
|
||||
assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
print("enable full fp16 training.")
|
||||
# unet.to(weight_dtype)
|
||||
# text_encoder.to(weight_dtype)
|
||||
network.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
@@ -163,10 +162,14 @@ def train(args):
|
||||
|
||||
unet.requires_grad_(False)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
unet.eval()
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.eval()
|
||||
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
|
||||
unet.train()
|
||||
text_encoder.train()
|
||||
else:
|
||||
unet.eval()
|
||||
text_encoder.eval()
|
||||
|
||||
network.prepare_grad_etc(text_encoder, unet)
|
||||
|
||||
@@ -294,9 +297,29 @@ def train(args):
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
def save_func(file):
|
||||
unwrap_model(network).save_weights(file, save_dtype)
|
||||
train_util.save_on_epoch_end(args, accelerator, epoch, num_train_epochs, save_func)
|
||||
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
||||
|
||||
def save_func():
|
||||
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
unwrap_model(network).save_weights(ckpt_file, save_dtype)
|
||||
|
||||
def remove_old_func(old_epoch_no):
|
||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
||||
if saving and args.save_state:
|
||||
print("saving state.")
|
||||
accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, epoch + 1)))
|
||||
if remove_epoch_no is not None:
|
||||
state_dir_old = os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
|
||||
if os.path.exists(state_dir_old):
|
||||
shutil.rmtree(state_dir_old)
|
||||
|
||||
# end of epoch
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
@@ -305,14 +328,20 @@ def train(args):
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_last_state(args, accelerator)
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
def last_save_func(file):
|
||||
network.save_weights(file, save_dtype)
|
||||
train_util.save_last_model(args, last_save_func)
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||
ckpt_name = model_name + '.' + args.save_model_as
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
print(f"save trained model to {ckpt_file}")
|
||||
network.save_weights(ckpt_file, save_dtype)
|
||||
print("model saved.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -322,6 +351,9 @@ if __name__ == '__main__':
|
||||
train_util.add_dataset_arguments(parser, True, True)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
|
||||
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
||||
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
||||
|
||||
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
||||
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user