diff --git a/train_network.py b/train_network.py index b01ec117..4c588641 100644 --- a/train_network.py +++ b/train_network.py @@ -275,9 +275,11 @@ def train(args): "ss_shuffle_caption": bool(args.shuffle_caption), "ss_cache_latents": bool(args.cache_latents), "ss_enable_bucket": bool(train_dataset.enable_bucket), + "ss_bucket_no_upscale": bool(train_dataset.bucket_no_upscale), "ss_min_bucket_reso": train_dataset.min_bucket_reso, "ss_max_bucket_reso": train_dataset.max_bucket_reso, "ss_seed": args.seed, + "ss_lowram": args.lowram, "ss_keep_tokens": args.keep_tokens, "ss_noise_offset": args.noise_offset, "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info), @@ -286,7 +288,12 @@ def train(args): "ss_bucket_info": json.dumps(train_dataset.bucket_info), "ss_training_comment": args.training_comment, # will not be updated after training "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), - "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else "") + "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), + "ss_caption_dropout_rate": args.caption_dropout_rate, + "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs, + "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate, + "ss_face_crop_aug_range": args.face_crop_aug_range, + "ss_prior_loss_weight": args.prior_loss_weight, } # uncomment if another network is added @@ -422,6 +429,7 @@ def train(args): def save_func(): ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as ckpt_file = os.path.join(args.output_dir, ckpt_name) + metadata["ss_training_finished_at"] = str(time.time()) print(f"saving checkpoint: {ckpt_file}") unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata) @@ -439,6 +447,7 @@ def train(args): # end of epoch metadata["ss_epoch"] = str(num_train_epochs) + metadata["ss_training_finished_at"] = str(time.time()) is_main_process = accelerator.is_main_process if is_main_process: