mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix convert_diffusers20_original_sd.py and add metadata & variant options
This commit is contained in:
@@ -23,7 +23,7 @@ def convert(args):
|
|||||||
is_load_ckpt = os.path.isfile(args.model_to_load)
|
is_load_ckpt = os.path.isfile(args.model_to_load)
|
||||||
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
|
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
|
||||||
|
|
||||||
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
|
assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
|
||||||
# assert (
|
# assert (
|
||||||
# is_save_ckpt or args.reference_model is not None
|
# is_save_ckpt or args.reference_model is not None
|
||||||
# ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
|
# ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
|
||||||
@@ -37,7 +37,7 @@ def convert(args):
|
|||||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection)
|
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection)
|
||||||
else:
|
else:
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None
|
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant
|
||||||
)
|
)
|
||||||
text_encoder = pipe.text_encoder
|
text_encoder = pipe.text_encoder
|
||||||
vae = pipe.vae
|
vae = pipe.vae
|
||||||
@@ -57,7 +57,7 @@ def convert(args):
|
|||||||
if is_save_ckpt:
|
if is_save_ckpt:
|
||||||
original_model = args.model_to_load if is_load_ckpt else None
|
original_model = args.model_to_load if is_load_ckpt else None
|
||||||
key_count = model_util.save_stable_diffusion_checkpoint(
|
key_count = model_util.save_stable_diffusion_checkpoint(
|
||||||
v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae
|
v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, None if args.metadata is None else eval(args.metadata), save_dtype=save_dtype, vae=vae
|
||||||
)
|
)
|
||||||
print(f"model saved. total converted state_dict keys: {key_count}")
|
print(f"model saved. total converted state_dict keys: {key_count}")
|
||||||
else:
|
else:
|
||||||
@@ -65,7 +65,7 @@ def convert(args):
|
|||||||
model_util.save_diffusers_checkpoint(
|
model_util.save_diffusers_checkpoint(
|
||||||
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
|
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
|
||||||
)
|
)
|
||||||
print(f"model saved.")
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
@@ -99,6 +99,18 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
|
"--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--metadata",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help='metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--variant",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="variant: Diffusers variant to load. Example: fp16",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--reference_model",
|
"--reference_model",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
Reference in New Issue
Block a user