mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix errors in fine tuning
This commit is contained in:
11
fine_tune.py
11
fine_tune.py
@@ -33,7 +33,8 @@ def train(args):
|
|||||||
train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
||||||
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
||||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset)
|
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
||||||
|
args.dataset_repeats, args.debug_dataset)
|
||||||
train_dataset.make_buckets()
|
train_dataset.make_buckets()
|
||||||
|
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
@@ -198,7 +199,7 @@ def train(args):
|
|||||||
# 学習する
|
# 学習する
|
||||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||||
print("running training / 学習開始")
|
print("running training / 学習開始")
|
||||||
print(f" num examples / サンプル数: {train_dataset.images_count}")
|
print(f" num examples / サンプル数: {train_dataset.num_train_images}")
|
||||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||||
@@ -223,7 +224,12 @@ def train(args):
|
|||||||
loss_total = 0
|
loss_total = 0
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||||
|
with torch.no_grad():
|
||||||
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
latents = batch["latents"].to(accelerator.device)
|
latents = batch["latents"].to(accelerator.device)
|
||||||
|
else:
|
||||||
|
# latentに変換
|
||||||
|
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||||
latents = latents * 0.18215
|
latents = latents * 0.18215
|
||||||
b_size = latents.shape[0]
|
b_size = latents.shape[0]
|
||||||
|
|
||||||
@@ -324,6 +330,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
parser.add_argument("--diffusers_xformers", action='store_true',
|
parser.add_argument("--diffusers_xformers", action='store_true',
|
||||||
help='use xformers by diffusers / Diffusersでxformersを使用する')
|
help='use xformers by diffusers / Diffusersでxformersを使用する')
|
||||||
|
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
train(args)
|
train(args)
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ class BucketBatchIndex(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
class BaseDataset(torch.utils.data.Dataset):
|
class BaseDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, debug_dataset: bool) -> None:
|
def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tokenizer: CLIPTokenizer = tokenizer
|
self.tokenizer: CLIPTokenizer = tokenizer
|
||||||
self.max_token_length = max_token_length
|
self.max_token_length = max_token_length
|
||||||
@@ -77,6 +77,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
self.flip_aug = flip_aug
|
self.flip_aug = flip_aug
|
||||||
self.color_aug = color_aug
|
self.color_aug = color_aug
|
||||||
self.debug_dataset = debug_dataset
|
self.debug_dataset = debug_dataset
|
||||||
|
self.random_crop = random_crop
|
||||||
self.token_padding_disabled = False
|
self.token_padding_disabled = False
|
||||||
|
|
||||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||||
@@ -265,7 +266,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
if info.latents_npz is not None:
|
if info.latents_npz is not None:
|
||||||
info.latents = self.load_latents_from_npz(info, False)
|
info.latents = self.load_latents_from_npz(info, False)
|
||||||
info.latents = torch.FloatTensor(info.latents)
|
info.latents = torch.FloatTensor(info.latents)
|
||||||
info.latents_flipped = self.load_latents_from_npz(info, True)
|
info.latents_flipped = self.load_latents_from_npz(info, True) # might be None
|
||||||
|
if info.latents_flipped is not None:
|
||||||
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
|
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -349,6 +351,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
def load_latents_from_npz(self, image_info: ImageInfo, flipped):
|
def load_latents_from_npz(self, image_info: ImageInfo, flipped):
|
||||||
npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
|
npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
|
||||||
|
if npz_file is None:
|
||||||
|
return None
|
||||||
return np.load(npz_file)['arr_0']
|
return np.load(npz_file)['arr_0']
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
@@ -444,14 +448,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
class DreamBoothDataset(BaseDataset):
|
class DreamBoothDataset(BaseDataset):
|
||||||
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
|
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
|
||||||
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
||||||
resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset)
|
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
||||||
|
|
||||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||||
|
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.size = min(self.width, self.height) # 短いほう
|
self.size = min(self.width, self.height) # 短いほう
|
||||||
self.prior_loss_weight = prior_loss_weight
|
self.prior_loss_weight = prior_loss_weight
|
||||||
self.random_crop = random_crop
|
|
||||||
self.latents_cache = None
|
self.latents_cache = None
|
||||||
|
|
||||||
self.enable_bucket = enable_bucket
|
self.enable_bucket = enable_bucket
|
||||||
@@ -563,9 +566,9 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
|
|
||||||
|
|
||||||
class FineTuningDataset(BaseDataset):
|
class FineTuningDataset(BaseDataset):
|
||||||
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, dataset_repeats, debug_dataset) -> None:
|
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
|
||||||
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
||||||
resolution, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, debug_dataset)
|
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
||||||
|
|
||||||
# メタデータを読み込む
|
# メタデータを読み込む
|
||||||
if os.path.exists(json_file_name):
|
if os.path.exists(json_file_name):
|
||||||
@@ -639,7 +642,7 @@ class FineTuningDataset(BaseDataset):
|
|||||||
break
|
break
|
||||||
sizes.add(image_info.image_size[0])
|
sizes.add(image_info.image_size[0])
|
||||||
sizes.add(image_info.image_size[1])
|
sizes.add(image_info.image_size[1])
|
||||||
resos.add(image_info.image_size)
|
resos.add(tuple(image_info.image_size))
|
||||||
|
|
||||||
if sizes is None:
|
if sizes is None:
|
||||||
assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
|
assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
|
||||||
@@ -708,6 +711,7 @@ def debug_dataset(train_dataset):
|
|||||||
if k == 27 or example['images'] is None:
|
if k == 27 or example['images'] is None:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def glob_images(dir, base):
|
def glob_images(dir, base):
|
||||||
img_paths = []
|
img_paths = []
|
||||||
for ext in IMAGE_EXTENSIONS:
|
for ext in IMAGE_EXTENSIONS:
|
||||||
@@ -986,7 +990,7 @@ def replace_unet_cross_attn_to_xformers():
|
|||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
# region utils
|
# region arguments
|
||||||
|
|
||||||
def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
||||||
# for pretrained models
|
# for pretrained models
|
||||||
@@ -1101,6 +1105,10 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
|
|||||||
parser.add_argument("--use_safetensors", action='store_true',
|
parser.add_argument("--use_safetensors", action='store_true',
|
||||||
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
|
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
|
||||||
|
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# region utils
|
||||||
|
|
||||||
|
|
||||||
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
||||||
# backward compatibility
|
# backward compatibility
|
||||||
|
|||||||
@@ -49,7 +49,8 @@ def train(args):
|
|||||||
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
||||||
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
||||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset)
|
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
||||||
|
args.dataset_repeats, args.debug_dataset)
|
||||||
train_dataset.make_buckets()
|
train_dataset.make_buckets()
|
||||||
|
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
|
|||||||
Reference in New Issue
Block a user