mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Add VAE to meatada, add no_metadata option
This commit is contained in:
@@ -178,6 +178,9 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
return self.parameters()
|
return self.parameters()
|
||||||
|
|
||||||
def save_weights(self, file, dtype, metadata):
|
def save_weights(self, file, dtype, metadata):
|
||||||
|
if len(metadata) == 0:
|
||||||
|
metadata = None
|
||||||
|
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
@@ -191,5 +194,6 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
save_file(state_dict, file, metadata)
|
save_file(state_dict, file, metadata)
|
||||||
else:
|
else:
|
||||||
torch.save(state_dict, file)
|
torch.save(state_dict, file)
|
||||||
|
if metadata is not None:
|
||||||
with zipfile.ZipFile(file, "w") as zipf:
|
with zipfile.ZipFile(file, "w") as zipf:
|
||||||
zipf.writestr(LoRANetwork.METADATA_FILENAME, json.dumps(metadata))
|
zipf.writestr(LoRANetwork.METADATA_FILENAME, json.dumps(metadata))
|
||||||
|
|||||||
@@ -201,7 +201,7 @@ def train(args):
|
|||||||
"ss_learning_rate": args.learning_rate,
|
"ss_learning_rate": args.learning_rate,
|
||||||
"ss_text_encoder_lr": args.text_encoder_lr,
|
"ss_text_encoder_lr": args.text_encoder_lr,
|
||||||
"ss_unet_lr": args.unet_lr,
|
"ss_unet_lr": args.unet_lr,
|
||||||
"ss_num_train_images": train_dataset.num_train_images,
|
"ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data
|
||||||
"ss_num_reg_images": train_dataset.num_reg_images,
|
"ss_num_reg_images": train_dataset.num_reg_images,
|
||||||
"ss_num_batches_per_epoch": len(train_dataloader),
|
"ss_num_batches_per_epoch": len(train_dataloader),
|
||||||
"ss_num_epochs": num_train_epochs,
|
"ss_num_epochs": num_train_epochs,
|
||||||
@@ -212,7 +212,8 @@ def train(args):
|
|||||||
"ss_lr_warmup_steps": args.lr_warmup_steps,
|
"ss_lr_warmup_steps": args.lr_warmup_steps,
|
||||||
"ss_lr_scheduler": args.lr_scheduler,
|
"ss_lr_scheduler": args.lr_scheduler,
|
||||||
"ss_network_module": args.network_module,
|
"ss_network_module": args.network_module,
|
||||||
"ss_network_dim": 4 if args.network_dim is None else args.network_dim,
|
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
|
||||||
|
"ss_mixed_precision": args.mixed_precision,
|
||||||
"ss_full_fp16": bool(args.full_fp16),
|
"ss_full_fp16": bool(args.full_fp16),
|
||||||
"ss_v2": bool(args.v2),
|
"ss_v2": bool(args.v2),
|
||||||
"ss_resolution": args.resolution,
|
"ss_resolution": args.resolution,
|
||||||
@@ -223,12 +224,16 @@ def train(args):
|
|||||||
"ss_random_crop": bool(args.random_crop),
|
"ss_random_crop": bool(args.random_crop),
|
||||||
"ss_shuffle_caption": bool(args.shuffle_caption),
|
"ss_shuffle_caption": bool(args.shuffle_caption),
|
||||||
"ss_cache_latents": bool(args.cache_latents),
|
"ss_cache_latents": bool(args.cache_latents),
|
||||||
"ss_enable_bucket": bool(args.enable_bucket),
|
"ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT
|
||||||
"ss_min_bucket_reso": args.min_bucket_reso,
|
"ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset
|
||||||
"ss_max_bucket_reso": args.max_bucket_reso,
|
"ss_max_bucket_reso": args.max_bucket_reso,
|
||||||
"ss_seed": args.seed
|
"ss_seed": args.seed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# uncomment if another network is added
|
||||||
|
# for key, value in net_kwargs.items():
|
||||||
|
# metadata["ss_arg_" + key] = value
|
||||||
|
|
||||||
if args.pretrained_model_name_or_path is not None:
|
if args.pretrained_model_name_or_path is not None:
|
||||||
sd_model_name = args.pretrained_model_name_or_path
|
sd_model_name = args.pretrained_model_name_or_path
|
||||||
if os.path.exists(sd_model_name):
|
if os.path.exists(sd_model_name):
|
||||||
@@ -236,6 +241,13 @@ def train(args):
|
|||||||
sd_model_name = os.path.basename(sd_model_name)
|
sd_model_name = os.path.basename(sd_model_name)
|
||||||
metadata["ss_sd_model_name"] = sd_model_name
|
metadata["ss_sd_model_name"] = sd_model_name
|
||||||
|
|
||||||
|
if args.vae is not None:
|
||||||
|
vae_name = args.vae
|
||||||
|
if os.path.exists(vae_name):
|
||||||
|
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
|
||||||
|
vae_name = os.path.basename(vae_name)
|
||||||
|
metadata["ss_vae_name"] = vae_name
|
||||||
|
|
||||||
metadata = {k: str(v) for k, v in metadata.items()}
|
metadata = {k: str(v) for k, v in metadata.items()}
|
||||||
|
|
||||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||||
@@ -338,7 +350,7 @@ def train(args):
|
|||||||
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
||||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
print(f"saving checkpoint: {ckpt_file}")
|
print(f"saving checkpoint: {ckpt_file}")
|
||||||
unwrap_model(network).save_weights(ckpt_file, save_dtype, metadata)
|
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
||||||
|
|
||||||
def remove_old_func(old_epoch_no):
|
def remove_old_func(old_epoch_no):
|
||||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
||||||
@@ -374,7 +386,7 @@ def train(args):
|
|||||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
|
|
||||||
print(f"save trained model to {ckpt_file}")
|
print(f"save trained model to {ckpt_file}")
|
||||||
network.save_weights(ckpt_file, save_dtype, metadata)
|
network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
||||||
print("model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
@@ -385,6 +397,7 @@ if __name__ == '__main__':
|
|||||||
train_util.add_dataset_arguments(parser, True, True)
|
train_util.add_dataset_arguments(parser, True, True)
|
||||||
train_util.add_training_arguments(parser, True)
|
train_util.add_training_arguments(parser, True)
|
||||||
|
|
||||||
|
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
||||||
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
||||||
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user