From 9f1d3aca2416b1b7ab37cee01ddbca7bcff6bb88 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 7 Jan 2023 20:20:37 +0900 Subject: [PATCH] add save_state_on_train end, fix reg imgs repeats --- library/train_util.py | 59 +++++++++++++++++++++++++++---------------- train_network.py | 9 +++---- 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 5033a55b..2eb16c00 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -76,7 +76,7 @@ class BaseDataset(torch.utils.data.Dataset): self.flip_aug = flip_aug self.color_aug = color_aug self.debug_dataset = debug_dataset - self.padding_disabled = False + self.token_padding_disabled = False self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 @@ -102,8 +102,8 @@ class BaseDataset(torch.utils.data.Dataset): self.image_data: dict[str, ImageInfo] = {} - def disable_padding(self): - self.padding_disabled = True + def disable_token_padding(self): + self.token_padding_disabled = True def process_caption(self, caption): if self.shuffle_caption: @@ -412,13 +412,13 @@ class BaseDataset(torch.utils.data.Dataset): caption = self.process_caption(image_info.caption) captions.append(caption) - if not self.padding_disabled: # this option might be omitted in future + if not self.token_padding_disabled: # this option might be omitted in future input_ids_list.append(self.get_input_ids(caption)) example = {} example['loss_weights'] = torch.FloatTensor(loss_weights) - if self.padding_disabled: + if self.token_padding_disabled: # padding=True means pad in the batch example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids else: @@ -540,13 +540,20 @@ class DreamBoothDataset(BaseDataset): if num_reg_images == 0: print("no regularization images / 正則化画像が見つかりませんでした") else: + # num_repeatsを計算する:どうせ大した数ではないのでループで処理する n = 0 + first_loop = True while n < num_train_images: for info in reg_infos: - self.register_image(info) - n += info.num_repeats - if n >= num_train_images: # reg画像にnum_repeats>1のときはまずありえないので考慮しない + if first_loop: + self.register_image(info) + n += info.num_repeats + else: + info.num_repeats += 1 + n += 1 + if n >= num_train_images: break + first_loop = False self.num_reg_images = num_reg_images @@ -1253,7 +1260,6 @@ def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoc saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs remove_epoch_no = None if saving: - print("saving checkpoint.") os.makedirs(args.output_dir, exist_ok=True) save_func() @@ -1270,6 +1276,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: if save_stable_diffusion_format: def save_sd(): ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae) @@ -1277,6 +1284,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no) old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) save_func = save_sd @@ -1284,6 +1292,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: else: def save_du(): out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no)) + print(f"saving model: {out_dir}") os.makedirs(out_dir, exist_ok=True) model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors) @@ -1291,6 +1300,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: def remove_du(old_epoch_no): out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no)) if os.path.exists(out_dir_old): + print(f"removing old model: {out_dir_old}") shutil.rmtree(out_dir_old) save_func = save_du @@ -1298,19 +1308,17 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: saving, remove_epoch_no = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) if saving and args.save_state: - print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) - if remove_epoch_no is not None: - state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) - if os.path.exists(state_dir_old): - shutil.rmtree(state_dir_old) + save_state_on_epoch_end(args, accelerator, model_name, epoch_no, remove_epoch_no) -def save_state_on_train_end(args: argparse.Namespace, accelerator): - print("saving last state.") - os.makedirs(args.output_dir, exist_ok=True) - model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name - accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))) +def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no, remove_epoch_no): + print("saving state.") + accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) + if remove_epoch_no is not None: + state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) + if os.path.exists(state_dir_old): + print(f"removing old state: {state_dir_old}") + shutil.rmtree(state_dir_old) def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae): @@ -1326,12 +1334,19 @@ def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_sta model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae) else: - print(f"save trained model as Diffusers to {args.output_dir}") - out_dir = os.path.join(args.output_dir, model_name) os.makedirs(out_dir, exist_ok=True) + print(f"save trained model as Diffusers to {out_dir}") model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors) + +def save_state_on_train_end(args: argparse.Namespace, accelerator): + print("saving last state.") + os.makedirs(args.output_dir, exist_ok=True) + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))) + + # endregion diff --git a/train_network.py b/train_network.py index bfb2d860..24dfa5b0 100644 --- a/train_network.py +++ b/train_network.py @@ -302,22 +302,19 @@ def train(args): def save_func(): ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") unwrap_model(network).save_weights(ckpt_file, save_dtype) def remove_old_func(old_epoch_no): old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") os.remove(old_ckpt_file) saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) if saving and args.save_state: - print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, epoch + 1))) - if remove_epoch_no is not None: - state_dir_old = os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) - if os.path.exists(state_dir_old): - shutil.rmtree(state_dir_old) + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1, remove_epoch_no) # end of epoch