diff --git a/library/train_util.py b/library/train_util.py index 94175b98..0fdbadc1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -80,8 +80,11 @@ class BaseDataset(torch.utils.data.Dataset): self.debug_dataset = debug_dataset self.random_crop = random_crop self.token_padding_disabled = False - self.dataset_dirs = {} - self.reg_dataset_dirs = {} + self.dataset_dirs_info = {} + self.reg_dataset_dirs_info = {} + self.enable_bucket = False + self.min_bucket_reso = None + self.max_bucket_reso = None self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 @@ -466,6 +469,8 @@ class DreamBoothDataset(BaseDataset): assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions( (self.width, self.height), min_bucket_reso, max_bucket_reso) + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso else: self.bucket_resos = [(self.width, self.height)] self.bucket_aspect_ratios = [self.width / self.height] @@ -526,7 +531,7 @@ class DreamBoothDataset(BaseDataset): for img_path, caption in zip(img_paths, captions): info = ImageInfo(img_path, n_repeats, caption, False, img_path) self.register_image(info) - self.dataset_dirs[dir] = {"n_repeats": n_repeats, "img_count": len(img_paths)} + self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)} print(f"{num_train_images} train images with repeating.") self.num_train_images = num_train_images @@ -543,7 +548,7 @@ class DreamBoothDataset(BaseDataset): for img_path, caption in zip(img_paths, captions): info = ImageInfo(img_path, n_repeats, caption, True, img_path) reg_infos.append(info) - self.reg_dataset_dirs[dir] = {"n_repeats": n_repeats, "img_count": len(img_paths)} + self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)} print(f"{num_reg_images} reg images.") if num_train_images < num_reg_images: @@ -616,6 +621,8 @@ class FineTuningDataset(BaseDataset): self.num_train_images = len(metadata) * dataset_repeats self.num_reg_images = 0 + self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)} + # check existence of all npz files if not self.color_aug: npz_any = False @@ -658,6 +665,8 @@ class FineTuningDataset(BaseDataset): assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions( (self.width, self.height), min_bucket_reso, max_bucket_reso) + self.min_bucket_reso = min_bucket_reso + self.max_bucket_reso = max_bucket_reso else: self.bucket_resos = [(self.width, self.height)] self.bucket_aspect_ratios = [self.width / self.height] @@ -670,6 +679,9 @@ class FineTuningDataset(BaseDataset): self.bucket_resos.sort() self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos] + self.min_bucket_reso = min([min(reso) for reso in resos]) + self.max_bucket_reso = max([max(reso) for reso in resos]) + def image_key_to_npz_file(self, image_key): base_name = os.path.splitext(image_key)[0] npz_file_norm = base_name + '.npz' @@ -1046,7 +1058,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--save_every_n_epochs", type=int, default=None, help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") - parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)") + parser.add_argument("--save_last_n_epochs_state", type=int, default=None, + help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)") parser.add_argument("--save_state", action="store_true", help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") @@ -1065,8 +1078,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") - parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)") - parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)") + parser.add_argument("--max_train_epochs", type=int, default=None, + help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)") + parser.add_argument("--max_data_loader_n_workers", type=int, default=8, + help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)") parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument("--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする") diff --git a/train_network.py b/train_network.py index 73370ee2..8b4e008b 100644 --- a/train_network.py +++ b/train_network.py @@ -223,6 +223,7 @@ def train(args): "ss_num_epochs": num_train_epochs, "ss_batch_size_per_device": args.train_batch_size, "ss_total_batch_size": total_batch_size, + "ss_gradient_checkpointing": args.gradient_checkpointing, "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, "ss_max_train_steps": args.max_train_steps, "ss_lr_warmup_steps": args.lr_warmup_steps, @@ -240,13 +241,13 @@ def train(args): "ss_random_crop": bool(args.random_crop), "ss_shuffle_caption": bool(args.shuffle_caption), "ss_cache_latents": bool(args.cache_latents), - "ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT - "ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset - "ss_max_bucket_reso": args.max_bucket_reso, + "ss_enable_bucket": bool(train_dataset.enable_bucket), + "ss_min_bucket_reso": train_dataset.min_bucket_reso, + "ss_max_bucket_reso": train_dataset.max_bucket_reso, "ss_seed": args.seed, "ss_keep_tokens": args.keep_tokens, - "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs), - "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs), + "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info), + "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info), } # uncomment if another network is added