fix errors in fine tuning

This commit is contained in:
Kohya S
2023-01-08 21:40:40 +09:00
parent 1945fa186d
commit 6b62c44022
3 changed files with 32 additions and 16 deletions

View File

@@ -65,7 +65,7 @@ class BucketBatchIndex(NamedTuple):
class BaseDataset(torch.utils.data.Dataset):
def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, debug_dataset: bool) -> None:
def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
super().__init__()
self.tokenizer: CLIPTokenizer = tokenizer
self.max_token_length = max_token_length
@@ -77,6 +77,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.flip_aug = flip_aug
self.color_aug = color_aug
self.debug_dataset = debug_dataset
self.random_crop = random_crop
self.token_padding_disabled = False
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
@@ -265,8 +266,9 @@ class BaseDataset(torch.utils.data.Dataset):
if info.latents_npz is not None:
info.latents = self.load_latents_from_npz(info, False)
info.latents = torch.FloatTensor(info.latents)
info.latents_flipped = self.load_latents_from_npz(info, True)
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
info.latents_flipped = self.load_latents_from_npz(info, True) # might be None
if info.latents_flipped is not None:
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
continue
image = self.load_image(info.absolute_path)
@@ -349,6 +351,8 @@ class BaseDataset(torch.utils.data.Dataset):
def load_latents_from_npz(self, image_info: ImageInfo, flipped):
npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
if npz_file is None:
return None
return np.load(npz_file)['arr_0']
def __len__(self):
@@ -444,14 +448,13 @@ class BaseDataset(torch.utils.data.Dataset):
class DreamBoothDataset(BaseDataset):
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset)
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
self.batch_size = batch_size
self.size = min(self.width, self.height) # 短いほう
self.prior_loss_weight = prior_loss_weight
self.random_crop = random_crop
self.latents_cache = None
self.enable_bucket = enable_bucket
@@ -563,9 +566,9 @@ class DreamBoothDataset(BaseDataset):
class FineTuningDataset(BaseDataset):
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, dataset_repeats, debug_dataset) -> None:
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
resolution, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, debug_dataset)
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
# メタデータを読み込む
if os.path.exists(json_file_name):
@@ -639,7 +642,7 @@ class FineTuningDataset(BaseDataset):
break
sizes.add(image_info.image_size[0])
sizes.add(image_info.image_size[1])
resos.add(image_info.image_size)
resos.add(tuple(image_info.image_size))
if sizes is None:
assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
@@ -708,6 +711,7 @@ def debug_dataset(train_dataset):
if k == 27 or example['images'] is None:
break
def glob_images(dir, base):
img_paths = []
for ext in IMAGE_EXTENSIONS:
@@ -986,7 +990,7 @@ def replace_unet_cross_attn_to_xformers():
# endregion
# region utils
# region arguments
def add_sd_models_arguments(parser: argparse.ArgumentParser):
# for pretrained models
@@ -1101,6 +1105,10 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--use_safetensors", action='store_true',
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存するsave_model_as未指定時")
# endregion
# region utils
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
# backward compatibility