diff --git a/fine_tune.py b/fine_tune.py
index fa3c81be..1a94870f 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -33,7 +33,8 @@ def train(args):
   train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, tokenizer, args.max_token_length,
                                                args.shuffle_caption, args.keep_tokens, args.resolution, args.enable_bucket,
                                                args.min_bucket_reso, args.max_bucket_reso,
-                                               args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset)
+                                               args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
+                                               args.dataset_repeats, args.debug_dataset)
   train_dataset.make_buckets()
 
   if args.debug_dataset:
@@ -198,7 +199,7 @@ def train(args):
   # 学習する
   total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
   print("running training / 学習開始")
-  print(f"  num examples / サンプル数: {train_dataset.images_count}")
+  print(f"  num examples / サンプル数: {train_dataset.num_train_images}")
   print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
   print(f"  num epochs / epoch数: {num_train_epochs}")
   print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -223,8 +224,13 @@ def train(args):
     loss_total = 0
     for step, batch in enumerate(train_dataloader):
       with accelerator.accumulate(training_models[0]):  # 複数モデルに対応していない模様だがとりあえずこうしておく
-        latents = batch["latents"].to(accelerator.device)
-        latents = latents * 0.18215
+        with torch.no_grad():
+          if "latents" in batch and batch["latents"] is not None:
+            latents = batch["latents"].to(accelerator.device)
+          else:
+            # latentに変換
+            latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
+          latents = latents * 0.18215
         b_size = latents.shape[0]
 
         with torch.set_grad_enabled(args.train_text_encoder):
@@ -310,7 +316,7 @@ def train(args):
   if is_main_process:
     src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
     train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors,
-                                            save_dtype, epoch, global_step, text_encoder, unet, vae)
+                                          save_dtype, epoch, global_step, text_encoder, unet, vae)
     print("model saved.")
 
 
@@ -324,6 +330,7 @@ if __name__ == '__main__':
 
   parser.add_argument("--diffusers_xformers", action='store_true',
                       help='use xformers by diffusers / Diffusersでxformersを使用する')
+  parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
 
   args = parser.parse_args()
   train(args)
diff --git a/library/train_util.py b/library/train_util.py
index 98ad10ef..bad954c2 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -65,7 +65,7 @@ class BucketBatchIndex(NamedTuple):
 
 
 class BaseDataset(torch.utils.data.Dataset):
-  def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, debug_dataset: bool) -> None:
+  def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
     super().__init__()
     self.tokenizer: CLIPTokenizer = tokenizer
     self.max_token_length = max_token_length
@@ -77,6 +77,7 @@ class BaseDataset(torch.utils.data.Dataset):
     self.flip_aug = flip_aug
     self.color_aug = color_aug
     self.debug_dataset = debug_dataset
+    self.random_crop = random_crop
     self.token_padding_disabled = False
     self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
 
@@ -265,8 +266,9 @@ class BaseDataset(torch.utils.data.Dataset):
       if info.latents_npz is not None:
         info.latents = self.load_latents_from_npz(info, False)
         info.latents = torch.FloatTensor(info.latents)
-        info.latents_flipped = self.load_latents_from_npz(info, True)
-        info.latents_flipped = torch.FloatTensor(info.latents_flipped)
+        info.latents_flipped = self.load_latents_from_npz(info, True)  # might be None
+        if info.latents_flipped is not None:
+          info.latents_flipped = torch.FloatTensor(info.latents_flipped)
         continue
 
       image = self.load_image(info.absolute_path)
@@ -349,6 +351,8 @@ class BaseDataset(torch.utils.data.Dataset):
 
   def load_latents_from_npz(self, image_info: ImageInfo, flipped):
     npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
+    if npz_file is None:
+      return None
     return np.load(npz_file)['arr_0']
 
   def __len__(self):
@@ -444,14 +448,13 @@ class BaseDataset(torch.utils.data.Dataset):
 class DreamBoothDataset(BaseDataset):
   def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens,
                resolution, enable_bucket, min_bucket_reso, max_bucket_reso, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
     super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
-                     resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset)
+                     resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
 
     assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
 
     self.batch_size = batch_size
     self.size = min(self.width, self.height)  # 短いほう
     self.prior_loss_weight = prior_loss_weight
-    self.random_crop = random_crop
     self.latents_cache = None
     self.enable_bucket = enable_bucket
@@ -563,9 +566,9 @@ class DreamBoothDataset(BaseDataset):
 
 
 class FineTuningDataset(BaseDataset):
-  def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, dataset_repeats, debug_dataset) -> None:
+  def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
     super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
-                     resolution, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, debug_dataset)
+                     resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
 
     # メタデータを読み込む
     if os.path.exists(json_file_name):
@@ -639,7 +642,7 @@ class FineTuningDataset(BaseDataset):
         break
       sizes.add(image_info.image_size[0])
       sizes.add(image_info.image_size[1])
-      resos.add(image_info.image_size)
+      resos.add(tuple(image_info.image_size))
 
     if sizes is None:
       assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
@@ -708,6 +711,7 @@ def debug_dataset(train_dataset):
       if k == 27 or example['images'] is None:
         break
 
+
 def glob_images(dir, base):
   img_paths = []
   for ext in IMAGE_EXTENSIONS:
@@ -986,7 +990,7 @@ def replace_unet_cross_attn_to_xformers():
 # endregion
 
-# region utils
+# region arguments
 
 
 def add_sd_models_arguments(parser: argparse.ArgumentParser):
   # for pretrained models
@@ -1101,6 +1105,10 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
   parser.add_argument("--use_safetensors", action='store_true',
                       help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
 
+# endregion
+
+# region utils
+
 
 def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
   # backward compatibility
diff --git a/train_network.py b/train_network.py
index 24dfa5b0..e557b1de 100644
--- a/train_network.py
+++ b/train_network.py
@@ -49,7 +49,8 @@ def train(args):
   train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, tokenizer, args.max_token_length,
                                     args.shuffle_caption, args.keep_tokens, args.resolution, args.enable_bucket,
                                     args.min_bucket_reso, args.max_bucket_reso,
-                                    args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset)
+                                    args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
+                                    args.dataset_repeats, args.debug_dataset)
   train_dataset.make_buckets()
 
   if args.debug_dataset:
@@ -315,7 +316,7 @@ def train(args):
      saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
      if saving and args.save_state:
        train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1, remove_epoch_no)
-      
+
     # end of epoch
 
   is_main_process = accelerator.is_main_process
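For reference, the main behavioral change in fine_tune.py above is the latents fallback: the training loop now uses cached latents from the batch when they are present, and otherwise encodes the batch images through the VAE on the fly (needed when the cache is unavailable, e.g. with augmentations such as random_crop or color_aug). A minimal standalone sketch of that pattern, assuming `vae` is a diffusers AutoencoderKL and `batch`, `weight_dtype`, and the target `device` come from the surrounding training setup; the helper function itself is hypothetical and not part of the diff:

import torch

def get_latents(batch, vae, device, weight_dtype):
  # Hypothetical helper mirroring the fallback added to the training loop in fine_tune.py.
  with torch.no_grad():
    if batch.get("latents") is not None:
      # cached latents were pre-computed by the dataset
      latents = batch["latents"].to(device)
    else:
      # no cache available: encode the images with the VAE on the fly
      latents = vae.encode(batch["images"].to(device, dtype=weight_dtype)).latent_dist.sample()
    # scale to the Stable Diffusion latent space
    latents = latents * 0.18215
  return latents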