From 6f80fe17fcac026ab85004d5a701835c14f5da84 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 8 Aug 2023 21:03:16 +0900 Subject: [PATCH] fix crashing in saving lora with clipskip --- library/sai_model_spec.py | 12 ++++++++---- library/train_util.py | 20 ++++++++++---------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index 88c2cb77..472686ba 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -194,8 +194,8 @@ def build_metadata( # comma separated to tuple if isinstance(reso, str): reso = tuple(map(int, reso.split(","))) - if len(reso) == 1: - reso = (reso[0], reso[0]) + if len(reso) == 1: + reso = (reso[0], reso[0]) else: # resolution is defined in dataset, so use default if sdxl: @@ -215,7 +215,11 @@ def build_metadata( metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON if timesteps is not None: - metadata["modelspec.timestep_range"] = timesteps + if isinstance(timesteps, str) or isinstance(timesteps, int): + timesteps = (timesteps, timesteps) + if len(timesteps) == 1: + timesteps = (timesteps[0], timesteps[0]) + metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}" else: del metadata["modelspec.timestep_range"] @@ -228,7 +232,7 @@ def build_metadata( # assert all([v is not None for v in metadata.values()]), metadata if not all([v is not None for v in metadata.values()]): print(f"Internal error: some metadata values are None: {metadata}") - + return metadata diff --git a/library/train_util.py b/library/train_util.py index dbfe41e8..34e477ed 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2521,7 +2521,7 @@ def get_sai_model_spec( sdxl: bool, lora: bool, textual_inversion: bool, - is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA + is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA ): timestamp = time.time() @@ -2546,15 +2546,15 @@ def get_sai_model_spec( lora, textual_inversion, timestamp, - title, - reso, - is_stable_diffusion_ckpt, - args.metadata_author, - args.metadata_description, - args.metadata_license, - args.metadata_tags, - timesteps, - args.clip_skip, # None or int + title=title, + reso=reso, + is_stable_diffusion_ckpt=is_stable_diffusion_ckpt, + author=args.metadata_author, + description=args.metadata_description, + license=args.metadata_license, + tags=args.metadata_tags, + timesteps=timesteps, + clip_skip=args.clip_skip, # None or int ) return metadata