add save_state_on_train_end, fix reg imgs repeats

This commit is contained in:
Kohya S
2023-01-07 20:20:37 +09:00
parent 2efced0a9a
commit 9f1d3aca24
2 changed files with 40 additions and 28 deletions

View File

@@ -76,7 +76,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.flip_aug = flip_aug
self.color_aug = color_aug
self.debug_dataset = debug_dataset
self.padding_disabled = False
self.token_padding_disabled = False
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
@@ -102,8 +102,8 @@ class BaseDataset(torch.utils.data.Dataset):
self.image_data: dict[str, ImageInfo] = {}
def disable_padding(self):
self.padding_disabled = True
def disable_token_padding(self):
self.token_padding_disabled = True
def process_caption(self, caption):
if self.shuffle_caption:
@@ -412,13 +412,13 @@ class BaseDataset(torch.utils.data.Dataset):
caption = self.process_caption(image_info.caption)
captions.append(caption)
if not self.padding_disabled: # this option might be omitted in future
if not self.token_padding_disabled: # this option might be omitted in future
input_ids_list.append(self.get_input_ids(caption))
example = {}
example['loss_weights'] = torch.FloatTensor(loss_weights)
if self.padding_disabled:
if self.token_padding_disabled:
# padding=True means pad in the batch
example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
else:
@@ -540,13 +540,20 @@ class DreamBoothDataset(BaseDataset):
if num_reg_images == 0:
print("no regularization images / 正則化画像が見つかりませんでした")
else:
# num_repeatsを計算するどうせ大した数ではないのでループで処理する
n = 0
first_loop = True
while n < num_train_images:
for info in reg_infos:
self.register_image(info)
n += info.num_repeats
if n >= num_train_images: # reg画像にnum_repeats>1のときはまずありえないので考慮しない
if first_loop:
self.register_image(info)
n += info.num_repeats
else:
info.num_repeats += 1
n += 1
if n >= num_train_images:
break
first_loop = False
self.num_reg_images = num_reg_images
@@ -1253,7 +1260,6 @@ def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoc
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
remove_epoch_no = None
if saving:
print("saving checkpoint.")
os.makedirs(args.output_dir, exist_ok=True)
save_func()
@@ -1270,6 +1276,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path:
if save_stable_diffusion_format:
def save_sd():
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
src_path, epoch_no, global_step, save_dtype, vae)
@@ -1277,6 +1284,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path:
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
save_func = save_sd
@@ -1284,6 +1292,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path:
else:
def save_du():
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no))
print(f"saving model: {out_dir}")
os.makedirs(out_dir, exist_ok=True)
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
src_path, vae=vae, use_safetensors=use_safetensors)
@@ -1291,6 +1300,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path:
def remove_du(old_epoch_no):
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
if os.path.exists(out_dir_old):
print(f"removing old model: {out_dir_old}")
shutil.rmtree(out_dir_old)
save_func = save_du
@@ -1298,19 +1308,17 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path:
saving, remove_epoch_no = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
if saving and args.save_state:
print("saving state.")
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
if remove_epoch_no is not None:
state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
if os.path.exists(state_dir_old):
shutil.rmtree(state_dir_old)
save_state_on_epoch_end(args, accelerator, model_name, epoch_no, remove_epoch_no)
def save_state_on_train_end(args: argparse.Namespace, accelerator):
    """Persist the final accelerator training state under the output directory.

    The state directory name comes from ``args.output_name`` when given,
    otherwise from the module-level default last-output name.
    """
    print("saving last state.")
    # Make sure the destination exists before Accelerate writes into it.
    os.makedirs(args.output_dir, exist_ok=True)
    if args.output_name is None:
        name = DEFAULT_LAST_OUTPUT_NAME
    else:
        name = args.output_name
    state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(name))
    accelerator.save_state(state_dir)
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no, remove_epoch_no):
    """Save the accelerator state for this epoch and drop a superseded one.

    When ``remove_epoch_no`` is not ``None``, the state directory previously
    saved for that epoch is deleted after the new state has been written.
    """
    print("saving state.")
    new_state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
    accelerator.save_state(new_state_dir)
    if remove_epoch_no is None:
        return
    # Clean up the older epoch's state only if it actually exists on disk.
    old_state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
    if os.path.exists(old_state_dir):
        print(f"removing old state: {old_state_dir}")
        shutil.rmtree(old_state_dir)
def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae):
@@ -1326,12 +1334,19 @@ def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_sta
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
src_path, epoch, global_step, save_dtype, vae)
else:
print(f"save trained model as Diffusers to {args.output_dir}")
out_dir = os.path.join(args.output_dir, model_name)
os.makedirs(out_dir, exist_ok=True)
print(f"save trained model as Diffusers to {out_dir}")
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
src_path, vae=vae, use_safetensors=use_safetensors)
def save_state_on_train_end(args: argparse.Namespace, accelerator):
    # Save the final (end-of-training) accelerator state into
    # <output_dir>/<LAST_STATE_NAME formatted with the model name>.
    print("saving last state.")
    os.makedirs(args.output_dir, exist_ok=True)
    # Fall back to the module-level default name when --output_name was not given.
    model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
    accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
# endregion

View File

@@ -302,22 +302,19 @@ def train(args):
def save_func():
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
unwrap_model(network).save_weights(ckpt_file, save_dtype)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
if saving and args.save_state:
print("saving state.")
accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, epoch + 1)))
if remove_epoch_no is not None:
state_dir_old = os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
if os.path.exists(state_dir_old):
shutil.rmtree(state_dir_old)
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1, remove_epoch_no)
# end of epoch