fix errors in fine tuning

This commit is contained in:
Kohya S
2023-01-08 21:40:40 +09:00
parent 1945fa186d
commit 6b62c44022
3 changed files with 32 additions and 16 deletions

View File

@@ -49,7 +49,8 @@ def train(args):
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.dataset_repeats, args.debug_dataset)
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
args.dataset_repeats, args.debug_dataset)
train_dataset.make_buckets()
if args.debug_dataset:
@@ -315,7 +316,7 @@ def train(args):
saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1, remove_epoch_no)
# end of epoch
is_main_process = accelerator.is_main_process