diff --git a/library/train_util.py b/library/train_util.py index 70af44c9..ade66a38 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -749,16 +749,16 @@ def default(val, d): def model_hash(filename): - try: - with open(filename, "rb") as file: - import hashlib - m = hashlib.sha256() + try: + with open(filename, "rb") as file: + import hashlib + m = hashlib.sha256() - file.seek(0x100000) - m.update(file.read(0x10000)) - return m.hexdigest()[0:8] - except FileNotFoundError: - return 'NOFILE' + file.seek(0x100000) + m.update(file.read(0x10000)) + return m.hexdigest()[0:8] + except FileNotFoundError: + return 'NOFILE' # flash attention forwards and backwards diff --git a/networks/lora.py b/networks/lora.py index 98e8e4a4..77fe26a7 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -178,6 +178,9 @@ class LoRANetwork(torch.nn.Module): return self.parameters() def save_weights(self, file, dtype, metadata): + if len(metadata) == 0: + metadata = None + state_dict = self.state_dict() if dtype is not None: @@ -191,5 +194,6 @@ class LoRANetwork(torch.nn.Module): save_file(state_dict, file, metadata) else: torch.save(state_dict, file) - with zipfile.ZipFile(file, "w") as zipf: - zipf.writestr(LoRANetwork.METADATA_FILENAME, json.dumps(metadata)) + if metadata is not None: + with zipfile.ZipFile(file, "w") as zipf: + zipf.writestr(LoRANetwork.METADATA_FILENAME, json.dumps(metadata)) diff --git a/train_network.py b/train_network.py index c5593c46..c920c5ed 100644 --- a/train_network.py +++ b/train_network.py @@ -198,37 +198,42 @@ def train(args): print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") metadata = { - "ss_learning_rate": args.learning_rate, - "ss_text_encoder_lr": args.text_encoder_lr, - "ss_unet_lr": args.unet_lr, - "ss_num_train_images": train_dataset.num_train_images, - "ss_num_reg_images": train_dataset.num_reg_images, - "ss_num_batches_per_epoch": len(train_dataloader), - "ss_num_epochs": num_train_epochs, - "ss_batch_size_per_device": args.train_batch_size, - "ss_total_batch_size": total_batch_size, - "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, - "ss_max_train_steps": args.max_train_steps, - "ss_lr_warmup_steps": args.lr_warmup_steps, - "ss_lr_scheduler": args.lr_scheduler, - "ss_network_module": args.network_module, - "ss_network_dim": 4 if args.network_dim is None else args.network_dim, - "ss_full_fp16": bool(args.full_fp16), - "ss_v2": bool(args.v2), - "ss_resolution": args.resolution, - "ss_clip_skip": args.clip_skip, - "ss_max_token_length": args.max_token_length, - "ss_color_aug": bool(args.color_aug), - "ss_flip_aug": bool(args.flip_aug), - "ss_random_crop": bool(args.random_crop), - "ss_shuffle_caption": bool(args.shuffle_caption), - "ss_cache_latents": bool(args.cache_latents), - "ss_enable_bucket": bool(args.enable_bucket), - "ss_min_bucket_reso": args.min_bucket_reso, - "ss_max_bucket_reso": args.max_bucket_reso, - "ss_seed": args.seed + "ss_learning_rate": args.learning_rate, + "ss_text_encoder_lr": args.text_encoder_lr, + "ss_unet_lr": args.unet_lr, + "ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data + "ss_num_reg_images": train_dataset.num_reg_images, + "ss_num_batches_per_epoch": len(train_dataloader), + "ss_num_epochs": num_train_epochs, + "ss_batch_size_per_device": args.train_batch_size, + "ss_total_batch_size": total_batch_size, + "ss_gradient_accumulation_steps": args.gradient_accumulation_steps, + "ss_max_train_steps": args.max_train_steps, + "ss_lr_warmup_steps": args.lr_warmup_steps, + "ss_lr_scheduler": args.lr_scheduler, + "ss_network_module": args.network_module, + "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim + "ss_mixed_precision": args.mixed_precision, + "ss_full_fp16": bool(args.full_fp16), + "ss_v2": bool(args.v2), + "ss_resolution": args.resolution, + "ss_clip_skip": args.clip_skip, + "ss_max_token_length": args.max_token_length, + "ss_color_aug": bool(args.color_aug), + "ss_flip_aug": bool(args.flip_aug), + "ss_random_crop": bool(args.random_crop), + "ss_shuffle_caption": bool(args.shuffle_caption), + "ss_cache_latents": bool(args.cache_latents), + "ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT + "ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset + "ss_max_bucket_reso": args.max_bucket_reso, + "ss_seed": args.seed } + # uncomment if another network is added + # for key, value in net_kwargs.items(): + # metadata["ss_arg_" + key] = value + if args.pretrained_model_name_or_path is not None: sd_model_name = args.pretrained_model_name_or_path if os.path.exists(sd_model_name): @@ -236,6 +241,13 @@ def train(args): sd_model_name = os.path.basename(sd_model_name) metadata["ss_sd_model_name"] = sd_model_name + if args.vae is not None: + vae_name = args.vae + if os.path.exists(vae_name): + metadata["ss_vae_hash"] = train_util.model_hash(vae_name) + vae_name = os.path.basename(vae_name) + metadata["ss_vae_name"] = vae_name + metadata = {k: str(v) for k, v in metadata.items()} progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") @@ -338,7 +350,7 @@ def train(args): ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"saving checkpoint: {ckpt_file}") - unwrap_model(network).save_weights(ckpt_file, save_dtype, metadata) + unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata) def remove_old_func(old_epoch_no): old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as @@ -374,7 +386,7 @@ def train(args): ckpt_file = os.path.join(args.output_dir, ckpt_name) print(f"save trained model to {ckpt_file}") - network.save_weights(ckpt_file, save_dtype, metadata) + network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata) print("model saved.") @@ -385,6 +397,7 @@ if __name__ == '__main__': train_util.add_dataset_arguments(parser, True, True) train_util.add_training_arguments(parser, True) + parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない") parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")