mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
set minimum metadata even with no_metadata
This commit is contained in:
@@ -427,10 +427,13 @@ def train(args):
|
||||
"ss_bucket_info": json.dumps(dataset.bucket_info),
|
||||
})
|
||||
|
||||
# add extra args
|
||||
if args.network_args:
|
||||
for key, value in net_kwargs.items():
|
||||
metadata["ss_arg_" + key] = value
|
||||
metadata["ss_network_args"] = json.dumps(net_kwargs)
|
||||
# for key, value in net_kwargs.items():
|
||||
# metadata["ss_arg_" + key] = value
|
||||
|
||||
# model name and hash
|
||||
if args.pretrained_model_name_or_path is not None:
|
||||
sd_model_name = args.pretrained_model_name_or_path
|
||||
if os.path.exists(sd_model_name):
|
||||
@@ -449,6 +452,13 @@ def train(args):
|
||||
|
||||
metadata = {k: str(v) for k, v in metadata.items()}
|
||||
|
||||
# make minimum metadata for filtering
|
||||
minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
|
||||
minimum_metadata = {}
|
||||
for key in minimum_keys:
|
||||
if key in metadata:
|
||||
minimum_metadata[key] = metadata[key]
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
@@ -564,7 +574,7 @@ def train(args):
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
metadata["ss_training_finished_at"] = str(time.time())
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
||||
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
||||
|
||||
def remove_old_func(old_epoch_no):
|
||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
||||
@@ -603,7 +613,7 @@ def train(args):
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
print(f"save trained model to {ckpt_file}")
|
||||
network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
||||
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
||||
print("model saved.")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user