diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 94ec8179..690d111e 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2300,7 +2300,10 @@ def main(args): else: data = torch.load(embeds_file, map_location="cpu") + if "string_to_param" in data: + data = data["string_to_param"] embeds = next(iter(data.values())) + if type(embeds) != torch.Tensor: raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}")