diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index a9d7fc4f..69c0bd1d 100644 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -371,6 +371,8 @@ class PipelineLike: width: int = 1024, original_height: int = None, original_width: int = None, + original_height_negative: int = None, + original_width_negative: int = None, crop_top: int = 0, crop_left: int = 0, num_inference_steps: int = 50, @@ -505,15 +507,22 @@ class PipelineLike: original_height = height if original_width is None: original_width = width + if original_height_negative is None: + original_height_negative = original_height + if original_width_negative is None: + original_width_negative = original_width if crop_top is None: crop_top = 0 if crop_left is None: crop_left = 0 emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) + uc_emb1 = sdxl_train_util.get_timestep_embedding( + torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256 + ) emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) - uc_vector = c_vector.clone().to(self.device, dtype=text_embeddings.dtype) + uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) c_vector = torch.cat([text_pool, c_vector], dim=1) uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) @@ -1260,6 +1269,8 @@ class BatchDataExt(NamedTuple): height: int original_width: int original_height: int + original_width_negative: int + original_height_negative: int crop_left: int crop_top: int steps: int @@ -1820,6 +1831,8 @@ def main(args): original_width_1st = scale_and_round(ext.original_width) original_height_1st = scale_and_round(ext.original_height) + original_width_negative_1st = scale_and_round(ext.original_width_negative) + original_height_negative_1st = scale_and_round(ext.original_height_negative) crop_left_1st = scale_and_round(ext.crop_left) crop_top_1st = scale_and_round(ext.crop_top) @@ -1830,6 +1843,8 @@ def main(args): height_1st, original_width_1st, original_height_1st, + original_width_negative_1st, + original_height_negative_1st, crop_left_1st, crop_top_1st, args.highres_fix_steps, @@ -1897,6 +1912,8 @@ def main(args): height, original_width, original_height, + original_width_negative, + original_height_negative, crop_left, crop_top, steps, @@ -2020,6 +2037,8 @@ def main(args): width, original_height, original_width, + original_height_negative, + original_width_negative, crop_top, crop_left, steps, @@ -2060,6 +2079,8 @@ def main(args): metadata.add_text("clip-prompt", clip_prompt) metadata.add_text("original-height", str(original_height)) metadata.add_text("original-width", str(original_width)) + metadata.add_text("original-height-negative", str(original_height_negative)) + metadata.add_text("original-width-negative", str(original_width_negative)) metadata.add_text("crop-top", str(crop_top)) metadata.add_text("crop-left", str(crop_left)) @@ -2123,6 +2144,8 @@ def main(args): height = args.H original_width = args.original_width original_height = args.original_height + original_width_negative = args.original_width_negative + original_height_negative = args.original_height_negative crop_top = args.crop_top crop_left = args.crop_left scale = args.scale @@ -2165,6 +2188,18 @@ def main(args): print(f"original height: {original_height}") continue + m = re.match(r"nw (\d+)", parg, re.IGNORECASE) + if m: + original_width_negative = int(m.group(1)) + print(f"original width negative: {original_width_negative}") + continue + + m = re.match(r"nh (\d+)", parg, re.IGNORECASE) + if m: + original_height_negative = int(m.group(1)) + print(f"original height negative: {original_height_negative}") + continue + m = re.match(r"ct (\d+)", parg, re.IGNORECASE) if m: crop_top = int(m.group(1)) @@ -2301,6 +2336,8 @@ def main(args): height, original_width, original_height, + original_width_negative, + original_height_negative, crop_left, crop_top, steps, @@ -2367,6 +2404,18 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値" ) + parser.add_argument( + "--original_height_negative", + type=int, + default=None, + help="original height for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal heightの値", + ) + parser.add_argument( + "--original_width_negative", + type=int, + default=None, + help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", + ) parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値") parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値") parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")