mirror of https://github.com/kohya-ss/sd-scripts.git
fix errors in fine tuning
@@ -65,7 +65,7 @@ class BucketBatchIndex(NamedTuple):
 
 
 class BaseDataset(torch.utils.data.Dataset):
-  def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, debug_dataset: bool) -> None:
+  def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
     super().__init__()
     self.tokenizer: CLIPTokenizer = tokenizer
     self.max_token_length = max_token_length
@@ -77,6 +77,7 @@ class BaseDataset(torch.utils.data.Dataset):
     self.flip_aug = flip_aug
     self.color_aug = color_aug
     self.debug_dataset = debug_dataset
+    self.random_crop = random_crop
     self.token_padding_disabled = False
 
     self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
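
These first two hunks thread the new random_crop flag through BaseDataset.__init__ and store it on the instance, so every dataset subclass shares one copy instead of keeping its own. A minimal sketch of the pattern, with made-up, much shorter signatures than the real ones:

import torch

class BaseDataset(torch.utils.data.Dataset):
  def __init__(self, resolution, flip_aug, random_crop, debug_dataset):
    super().__init__()
    self.resolution = resolution
    self.flip_aug = flip_aug
    self.random_crop = random_crop  # now owned by the base class
    self.debug_dataset = debug_dataset

class FineTuningDataset(BaseDataset):
  def __init__(self, json_file_name, resolution, flip_aug, random_crop, debug_dataset):
    # forward the flag instead of keeping a subclass-local copy
    super().__init__(resolution, flip_aug, random_crop, debug_dataset)
    self.json_file_name = json_file_name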
@@ -265,8 +266,9 @@ class BaseDataset(torch.utils.data.Dataset):
       if info.latents_npz is not None:
         info.latents = self.load_latents_from_npz(info, False)
         info.latents = torch.FloatTensor(info.latents)
-        info.latents_flipped = self.load_latents_from_npz(info, True)
-        info.latents_flipped = torch.FloatTensor(info.latents_flipped)
+        info.latents_flipped = self.load_latents_from_npz(info, True)  # might be None
+        if info.latents_flipped is not None:
+          info.latents_flipped = torch.FloatTensor(info.latents_flipped)
         continue
 
       image = self.load_image(info.absolute_path)
@@ -349,6 +351,8 @@ class BaseDataset(torch.utils.data.Dataset):
 
   def load_latents_from_npz(self, image_info: ImageInfo, flipped):
     npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
+    if npz_file is None:
+      return None
     return np.load(npz_file)['arr_0']
 
   def __len__(self):
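
Together these two hunks make the flipped-latents cache optional: load_latents_from_npz now returns None when no flipped .npz exists for an image, and the caller converts to a tensor only when something came back. A self-contained sketch of the pattern, assuming the cache was written with np.savez (which stores a single unnamed array under 'arr_0'); the file name img_0001.npz is made up for the example:

import numpy as np
import torch

def load_latents_from_npz(npz_file):
  if npz_file is None:  # this variant was never cached
    return None
  return np.load(npz_file)['arr_0']

# cache only the unflipped latents
np.savez("img_0001.npz", np.zeros((4, 64, 64), dtype=np.float32))

latents = torch.FloatTensor(load_latents_from_npz("img_0001.npz"))
flipped = load_latents_from_npz(None)  # might be None
if flipped is not None:  # guard before the tensor conversion
  flipped = torch.FloatTensor(flipped)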
@@ -444,14 +448,13 @@ class BaseDataset(torch.utils.data.Dataset):
 class DreamBoothDataset(BaseDataset):
   def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
     super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
-                     resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset)
+                     resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
 
     assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
 
     self.batch_size = batch_size
     self.size = min(self.width, self.height)  # the shorter side
     self.prior_loss_weight = prior_loss_weight
-    self.random_crop = random_crop
     self.latents_cache = None
 
     self.enable_bucket = enable_bucket
@@ -563,9 +566,9 @@ class DreamBoothDataset(BaseDataset):
 
 
 class FineTuningDataset(BaseDataset):
-  def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, dataset_repeats, debug_dataset) -> None:
+  def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
     super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
-                     resolution, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, debug_dataset)
+                     resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
 
     # load metadata
     if os.path.exists(json_file_name):
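
This hunk is the heart of the commit: the old FineTuningDataset passed min_bucket_reso and max_bucket_reso positionally into a BaseDataset.__init__ that has no such parameters, so the call supplied two arguments too many and fine tuning crashed with a TypeError before training began. A stripped-down reproduction of the failure mode, with hypothetical short signatures:

class Base:
  def __init__(self, resolution, flip_aug, debug_dataset):
    self.resolution = resolution
    self.flip_aug = flip_aug
    self.debug_dataset = debug_dataset

class Broken(Base):
  def __init__(self, resolution, min_bucket_reso, max_bucket_reso, flip_aug, debug_dataset):
    # two extra positional arguments:
    # TypeError: __init__() takes 4 positional arguments but 6 were given
    super().__init__(resolution, min_bucket_reso, max_bucket_reso, flip_aug, debug_dataset)

try:
  Broken(512, 256, 1024, True, False)
except TypeError as e:
  print(e)

Passing long argument lists like these by keyword would have surfaced the mismatch as a named error at the call site; the commit instead corrects the call to match the base signature.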
@@ -639,7 +642,7 @@ class FineTuningDataset(BaseDataset):
           break
         sizes.add(image_info.image_size[0])
         sizes.add(image_info.image_size[1])
-        resos.add(image_info.image_size)
+        resos.add(tuple(image_info.image_size))
 
     if sizes is None:
       assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
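
The resos change matters because image_size comes back from the JSON metadata as a list, and Python lists are unhashable, so adding one to a set raises a TypeError; converting to a tuple first restores hashability. A short demonstration:

resos = set()
image_size = [512, 768]  # what json.load returns where a tuple was saved
# resos.add(image_size) would raise TypeError: unhashable type: 'list'
resos.add(tuple(image_size))
print(resos)  # {(512, 768)}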
@@ -708,6 +711,7 @@ def debug_dataset(train_dataset):
     if k == 27 or example['images'] is None:
       break
 
+
 def glob_images(dir, base):
   img_paths = []
   for ext in IMAGE_EXTENSIONS:
@@ -986,7 +990,7 @@ def replace_unet_cross_attn_to_xformers():
 # endregion
 
 
-# region utils
+# region arguments
 
 def add_sd_models_arguments(parser: argparse.ArgumentParser):
   # for pretrained models
@@ -1101,6 +1105,10 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
   parser.add_argument("--use_safetensors", action='store_true',
                       help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
 
 # endregion
+
+# region utils
+
+
 def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
   # backward compatibility