mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Update to add grad_ckpting etc to metadata
This commit is contained in:
@@ -80,8 +80,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
self.debug_dataset = debug_dataset
|
self.debug_dataset = debug_dataset
|
||||||
self.random_crop = random_crop
|
self.random_crop = random_crop
|
||||||
self.token_padding_disabled = False
|
self.token_padding_disabled = False
|
||||||
self.dataset_dirs = {}
|
self.dataset_dirs_info = {}
|
||||||
self.reg_dataset_dirs = {}
|
self.reg_dataset_dirs_info = {}
|
||||||
|
self.enable_bucket = False
|
||||||
|
self.min_bucket_reso = None
|
||||||
|
self.max_bucket_reso = None
|
||||||
|
|
||||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||||
|
|
||||||
@@ -466,6 +469,8 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||||
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
|
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
|
||||||
(self.width, self.height), min_bucket_reso, max_bucket_reso)
|
(self.width, self.height), min_bucket_reso, max_bucket_reso)
|
||||||
|
self.min_bucket_reso = min_bucket_reso
|
||||||
|
self.max_bucket_reso = max_bucket_reso
|
||||||
else:
|
else:
|
||||||
self.bucket_resos = [(self.width, self.height)]
|
self.bucket_resos = [(self.width, self.height)]
|
||||||
self.bucket_aspect_ratios = [self.width / self.height]
|
self.bucket_aspect_ratios = [self.width / self.height]
|
||||||
@@ -526,7 +531,7 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
for img_path, caption in zip(img_paths, captions):
|
for img_path, caption in zip(img_paths, captions):
|
||||||
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
||||||
self.register_image(info)
|
self.register_image(info)
|
||||||
self.dataset_dirs[dir] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
||||||
print(f"{num_train_images} train images with repeating.")
|
print(f"{num_train_images} train images with repeating.")
|
||||||
self.num_train_images = num_train_images
|
self.num_train_images = num_train_images
|
||||||
|
|
||||||
@@ -543,7 +548,7 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
for img_path, caption in zip(img_paths, captions):
|
for img_path, caption in zip(img_paths, captions):
|
||||||
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
||||||
reg_infos.append(info)
|
reg_infos.append(info)
|
||||||
self.reg_dataset_dirs[dir] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
||||||
|
|
||||||
print(f"{num_reg_images} reg images.")
|
print(f"{num_reg_images} reg images.")
|
||||||
if num_train_images < num_reg_images:
|
if num_train_images < num_reg_images:
|
||||||
@@ -616,6 +621,8 @@ class FineTuningDataset(BaseDataset):
|
|||||||
self.num_train_images = len(metadata) * dataset_repeats
|
self.num_train_images = len(metadata) * dataset_repeats
|
||||||
self.num_reg_images = 0
|
self.num_reg_images = 0
|
||||||
|
|
||||||
|
self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
||||||
|
|
||||||
# check existence of all npz files
|
# check existence of all npz files
|
||||||
if not self.color_aug:
|
if not self.color_aug:
|
||||||
npz_any = False
|
npz_any = False
|
||||||
@@ -658,6 +665,8 @@ class FineTuningDataset(BaseDataset):
|
|||||||
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||||
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
|
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
|
||||||
(self.width, self.height), min_bucket_reso, max_bucket_reso)
|
(self.width, self.height), min_bucket_reso, max_bucket_reso)
|
||||||
|
self.min_bucket_reso = min_bucket_reso
|
||||||
|
self.max_bucket_reso = max_bucket_reso
|
||||||
else:
|
else:
|
||||||
self.bucket_resos = [(self.width, self.height)]
|
self.bucket_resos = [(self.width, self.height)]
|
||||||
self.bucket_aspect_ratios = [self.width / self.height]
|
self.bucket_aspect_ratios = [self.width / self.height]
|
||||||
@@ -670,6 +679,9 @@ class FineTuningDataset(BaseDataset):
|
|||||||
self.bucket_resos.sort()
|
self.bucket_resos.sort()
|
||||||
self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos]
|
self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos]
|
||||||
|
|
||||||
|
self.min_bucket_reso = min([min(reso) for reso in resos])
|
||||||
|
self.max_bucket_reso = max([max(reso) for reso in resos])
|
||||||
|
|
||||||
def image_key_to_npz_file(self, image_key):
|
def image_key_to_npz_file(self, image_key):
|
||||||
base_name = os.path.splitext(image_key)[0]
|
base_name = os.path.splitext(image_key)[0]
|
||||||
npz_file_norm = base_name + '.npz'
|
npz_file_norm = base_name + '.npz'
|
||||||
@@ -1046,7 +1058,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
||||||
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
||||||
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
||||||
parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
|
parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
|
||||||
|
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
|
||||||
parser.add_argument("--save_state", action="store_true",
|
parser.add_argument("--save_state", action="store_true",
|
||||||
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
||||||
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
||||||
@@ -1065,8 +1078,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
|
|
||||||
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
||||||
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
||||||
parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
parser.add_argument("--max_train_epochs", type=int, default=None,
|
||||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
|
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
||||||
|
parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
|
||||||
|
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
|
||||||
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
||||||
parser.add_argument("--gradient_checkpointing", action="store_true",
|
parser.add_argument("--gradient_checkpointing", action="store_true",
|
||||||
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
||||||
|
|||||||
@@ -223,6 +223,7 @@ def train(args):
|
|||||||
"ss_num_epochs": num_train_epochs,
|
"ss_num_epochs": num_train_epochs,
|
||||||
"ss_batch_size_per_device": args.train_batch_size,
|
"ss_batch_size_per_device": args.train_batch_size,
|
||||||
"ss_total_batch_size": total_batch_size,
|
"ss_total_batch_size": total_batch_size,
|
||||||
|
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
||||||
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
||||||
"ss_max_train_steps": args.max_train_steps,
|
"ss_max_train_steps": args.max_train_steps,
|
||||||
"ss_lr_warmup_steps": args.lr_warmup_steps,
|
"ss_lr_warmup_steps": args.lr_warmup_steps,
|
||||||
@@ -240,13 +241,13 @@ def train(args):
|
|||||||
"ss_random_crop": bool(args.random_crop),
|
"ss_random_crop": bool(args.random_crop),
|
||||||
"ss_shuffle_caption": bool(args.shuffle_caption),
|
"ss_shuffle_caption": bool(args.shuffle_caption),
|
||||||
"ss_cache_latents": bool(args.cache_latents),
|
"ss_cache_latents": bool(args.cache_latents),
|
||||||
"ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT
|
"ss_enable_bucket": bool(train_dataset.enable_bucket),
|
||||||
"ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset
|
"ss_min_bucket_reso": train_dataset.min_bucket_reso,
|
||||||
"ss_max_bucket_reso": args.max_bucket_reso,
|
"ss_max_bucket_reso": train_dataset.max_bucket_reso,
|
||||||
"ss_seed": args.seed,
|
"ss_seed": args.seed,
|
||||||
"ss_keep_tokens": args.keep_tokens,
|
"ss_keep_tokens": args.keep_tokens,
|
||||||
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs),
|
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
||||||
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs),
|
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
||||||
}
|
}
|
||||||
|
|
||||||
# uncomment if another network is added
|
# uncomment if another network is added
|
||||||
|
|||||||
Reference in New Issue
Block a user