fix errors in fine tuning

Kohya S
2023-01-08 21:40:40 +09:00
parent 1945fa186d
commit 6b62c44022
3 changed files with 32 additions and 16 deletions

View File

@@ -33,7 +33,8 @@ def train(args):
   train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
                                                tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
                                                args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
-                                               args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset)
+                                               args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
+                                               args.dataset_repeats, args.debug_dataset)
   train_dataset.make_buckets()
 
   if args.debug_dataset:
@@ -198,7 +199,7 @@ def train(args):
   # train
   total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
   print("running training / 学習開始")
-  print(f"  num examples / サンプル数: {train_dataset.images_count}")
+  print(f"  num examples / サンプル数: {train_dataset.num_train_images}")
   print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
   print(f"  num epochs / epoch数: {num_train_epochs}")
   print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
@@ -223,8 +224,13 @@ def train(args):
     loss_total = 0
     for step, batch in enumerate(train_dataloader):
       with accelerator.accumulate(training_models[0]):  # doesn't seem to support multiple models, but leave it like this for now
-        latents = batch["latents"].to(accelerator.device)
-        latents = latents * 0.18215
+        with torch.no_grad():
+          if "latents" in batch and batch["latents"] is not None:
+            latents = batch["latents"].to(accelerator.device)
+          else:
+            # convert to latents
+            latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
+          latents = latents * 0.18215
         b_size = latents.shape[0]
 
         with torch.set_grad_enabled(args.train_text_encoder):
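
The rewritten block makes cached latents optional: if the dataset loaded latents from .npz files, the batch carries them directly; otherwise the pixel images are encoded through the VAE on the fly. Either way the result is scaled by 0.18215, the latent scaling factor of Stable Diffusion v1.x VAEs. A minimal sketch of the pattern (the function name and the placement of no_grad are illustrative, not the script's exact code):

    import torch

    SD_LATENT_SCALE = 0.18215  # Stable Diffusion v1.x VAE scaling factor

    def get_latents(batch, vae, weight_dtype, device):
      # Prefer latents cached ahead of time as .npz files, when the batch has them.
      if "latents" in batch and batch["latents"] is not None:
        latents = batch["latents"].to(device)
      else:
        # No cache: encode pixels through the VAE; the VAE stays frozen during
        # fine tuning, so no gradients are needed here.
        with torch.no_grad():
          latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
      return latents * SD_LATENT_SCALE
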
@@ -310,7 +316,7 @@ def train(args):
   if is_main_process:
     src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
     train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors,
-                                          save_dtype, epoch, global_step, text_encoder, unet, vae)
+                                           save_dtype, epoch, global_step, text_encoder, unet, vae)
     print("model saved.")
@@ -324,6 +330,7 @@ if __name__ == '__main__':
parser.add_argument("--diffusers_xformers", action='store_true',
help='use xformers by diffusers / Diffusersでxformersを使用する')
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
args = parser.parse_args()
train(args)
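
The new --train_text_encoder flag feeds the torch.set_grad_enabled(args.train_text_encoder) context already visible in the training-loop hunk above: with the flag off, the text encoder still runs forward for conditioning but builds no autograd graph, so only the U-Net is updated. A hedged sketch of that gating (the bare text_encoder(input_ids)[0] call is an assumption; the actual script also handles long-prompt token windows):

    import torch

    def encode_captions(text_encoder, input_ids, train_text_encoder: bool):
      # With train_text_encoder=False this behaves like torch.no_grad(), so the
      # encoder is evaluated but receives no gradient updates.
      with torch.set_grad_enabled(train_text_encoder):
        return text_encoder(input_ids)[0]  # last_hidden_state of the CLIP text model
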

View File

@@ -65,7 +65,7 @@ class BucketBatchIndex(NamedTuple):
 class BaseDataset(torch.utils.data.Dataset):
-  def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, debug_dataset: bool) -> None:
+  def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
     super().__init__()
 
     self.tokenizer: CLIPTokenizer = tokenizer
     self.max_token_length = max_token_length
@@ -77,6 +77,7 @@ class BaseDataset(torch.utils.data.Dataset):
     self.flip_aug = flip_aug
     self.color_aug = color_aug
     self.debug_dataset = debug_dataset
+    self.random_crop = random_crop
     self.token_padding_disabled = False
 
     self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
@@ -265,8 +266,9 @@ class BaseDataset(torch.utils.data.Dataset):
       if info.latents_npz is not None:
         info.latents = self.load_latents_from_npz(info, False)
         info.latents = torch.FloatTensor(info.latents)
-        info.latents_flipped = self.load_latents_from_npz(info, True)
-        info.latents_flipped = torch.FloatTensor(info.latents_flipped)
+        info.latents_flipped = self.load_latents_from_npz(info, True)  # might be None
+        if info.latents_flipped is not None:
+          info.latents_flipped = torch.FloatTensor(info.latents_flipped)
         continue
 
       image = self.load_image(info.absolute_path)
@@ -349,6 +351,8 @@ class BaseDataset(torch.utils.data.Dataset):
   def load_latents_from_npz(self, image_info: ImageInfo, flipped):
     npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
+    if npz_file is None:
+      return None
     return np.load(npz_file)['arr_0']
 
   def __len__(self):
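
Together, these two hunks make the flipped-latents cache optional: load_latents_from_npz now returns None when no flipped .npz file exists, and the caller only wraps the result in a FloatTensor when there is one. For context, a sketch of the caching side this guards against; np.savez with one positional array stores it under the key 'arr_0', which matches the np.load(npz_file)['arr_0'] read path above, while the _flip filename suffix is illustrative:

    import numpy as np

    def cache_latents(npz_path, latents, latents_flipped=None):
      # A single positional array is saved under 'arr_0'.
      np.savez(npz_path, latents)
      if latents_flipped is not None:
        # Only written when a flipped copy was produced; otherwise the flipped
        # file is absent and the loader's new None guard takes over.
        np.savez(npz_path.replace('.npz', '_flip.npz'), latents_flipped)
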
@@ -444,14 +448,13 @@ class BaseDataset(torch.utils.data.Dataset):
 class DreamBoothDataset(BaseDataset):
   def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
     super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
-                     resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset)
+                     resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
 
     assert resolution is not None, f"resolution is required / resolution（解像度）指定は必須です"
 
     self.batch_size = batch_size
     self.size = min(self.width, self.height)  # the shorter side
     self.prior_loss_weight = prior_loss_weight
-    self.random_crop = random_crop
     self.latents_cache = None
 
     self.enable_bucket = enable_bucket
@@ -563,9 +566,9 @@ class DreamBoothDataset(BaseDataset):
 class FineTuningDataset(BaseDataset):
-  def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, dataset_repeats, debug_dataset) -> None:
+  def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
     super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
-                     resolution, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, debug_dataset)
+                     resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
 
     # load metadata
     if os.path.exists(json_file_name):
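
This super().__init__ call is the main error the commit title refers to: the old call passed min_bucket_reso and max_bucket_reso positionally into BaseDataset, whose __init__ has no such parameters, so Python raised a TypeError (too many positional arguments) as soon as a FineTuningDataset was constructed. A reduced reproduction of the bug's shape (class and parameter names simplified):

    class Base:
      def __init__(self, resolution, flip_aug, color_aug, debug_dataset):
        self.resolution = resolution

    class FineTuning(Base):
      def __init__(self, resolution, min_bucket_reso, max_bucket_reso,
                   flip_aug, color_aug, debug_dataset):
        # Pre-commit shape: two extra positionals ->
        # TypeError: __init__() takes 5 positional arguments but 7 were given
        # super().__init__(resolution, min_bucket_reso, max_bucket_reso,
        #                  flip_aug, color_aug, debug_dataset)

        # Post-commit shape: the arguments line up with Base's signature again.
        super().__init__(resolution, flip_aug, color_aug, debug_dataset)
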
@@ -639,7 +642,7 @@ class FineTuningDataset(BaseDataset):
           break
         sizes.add(image_info.image_size[0])
         sizes.add(image_info.image_size[1])
-        resos.add(image_info.image_size)
+        resos.add(tuple(image_info.image_size))
 
     if sizes is None:
       assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
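
The tuple() conversion matters because image_size is read from the metadata JSON, where json.load yields a list; lists are unhashable, so the old resos.add(image_info.image_size) raised a TypeError. A tuple is hashable and compares equal to the (width, height) tuples used elsewhere for bucket resolutions:

    resos = set()
    image_size = [512, 768]       # what json.load produces for the metadata field
    # resos.add(image_size)       # TypeError: unhashable type: 'list'
    resos.add(tuple(image_size))  # ok
    print((512, 768) in resos)    # True
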
@@ -708,6 +711,7 @@ def debug_dataset(train_dataset):
     if k == 27 or example['images'] is None:
       break
+
 
 def glob_images(dir, base):
   img_paths = []
   for ext in IMAGE_EXTENSIONS:
@@ -986,7 +990,7 @@ def replace_unet_cross_attn_to_xformers():
 # endregion
 
-# region utils
+# region arguments
 
 def add_sd_models_arguments(parser: argparse.ArgumentParser):
   # for pretrained models
@@ -1101,6 +1105,10 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--use_safetensors", action='store_true',
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存するsave_model_as未指定時")
# endregion
# region utils
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
# backward compatibility

View File

@@ -49,7 +49,8 @@ def train(args):
   train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
                                     tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
                                     args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
-                                    args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset)
+                                    args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
+                                    args.dataset_repeats, args.debug_dataset)
   train_dataset.make_buckets()
   if args.debug_dataset: