mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add separate original size (height/width) for negative conditioning
This commit is contained in:
@@ -371,6 +371,8 @@ class PipelineLike:
|
|||||||
width: int = 1024,
|
width: int = 1024,
|
||||||
original_height: int = None,
|
original_height: int = None,
|
||||||
original_width: int = None,
|
original_width: int = None,
|
||||||
|
original_height_negative: int = None,
|
||||||
|
original_width_negative: int = None,
|
||||||
crop_top: int = 0,
|
crop_top: int = 0,
|
||||||
crop_left: int = 0,
|
crop_left: int = 0,
|
||||||
num_inference_steps: int = 50,
|
num_inference_steps: int = 50,
|
||||||
@@ -505,15 +507,22 @@ class PipelineLike:
|
|||||||
original_height = height
|
original_height = height
|
||||||
if original_width is None:
|
if original_width is None:
|
||||||
original_width = width
|
original_width = width
|
||||||
|
if original_height_negative is None:
|
||||||
|
original_height_negative = original_height
|
||||||
|
if original_width_negative is None:
|
||||||
|
original_width_negative = original_width
|
||||||
if crop_top is None:
|
if crop_top is None:
|
||||||
crop_top = 0
|
crop_top = 0
|
||||||
if crop_left is None:
|
if crop_left is None:
|
||||||
crop_left = 0
|
crop_left = 0
|
||||||
emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
|
emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
|
||||||
|
uc_emb1 = sdxl_train_util.get_timestep_embedding(
|
||||||
|
torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256
|
||||||
|
)
|
||||||
emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
|
emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
|
||||||
emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256)
|
emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256)
|
||||||
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1)
|
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1)
|
||||||
uc_vector = c_vector.clone().to(self.device, dtype=text_embeddings.dtype)
|
uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1)
|
||||||
|
|
||||||
c_vector = torch.cat([text_pool, c_vector], dim=1)
|
c_vector = torch.cat([text_pool, c_vector], dim=1)
|
||||||
uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
|
uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
|
||||||
@@ -1260,6 +1269,8 @@ class BatchDataExt(NamedTuple):
|
|||||||
height: int
|
height: int
|
||||||
original_width: int
|
original_width: int
|
||||||
original_height: int
|
original_height: int
|
||||||
|
original_width_negative: int
|
||||||
|
original_height_negative: int
|
||||||
crop_left: int
|
crop_left: int
|
||||||
crop_top: int
|
crop_top: int
|
||||||
steps: int
|
steps: int
|
||||||
@@ -1820,6 +1831,8 @@ def main(args):
|
|||||||
|
|
||||||
original_width_1st = scale_and_round(ext.original_width)
|
original_width_1st = scale_and_round(ext.original_width)
|
||||||
original_height_1st = scale_and_round(ext.original_height)
|
original_height_1st = scale_and_round(ext.original_height)
|
||||||
|
original_width_negative_1st = scale_and_round(ext.original_width_negative)
|
||||||
|
original_height_negative_1st = scale_and_round(ext.original_height_negative)
|
||||||
crop_left_1st = scale_and_round(ext.crop_left)
|
crop_left_1st = scale_and_round(ext.crop_left)
|
||||||
crop_top_1st = scale_and_round(ext.crop_top)
|
crop_top_1st = scale_and_round(ext.crop_top)
|
||||||
|
|
||||||
@@ -1830,6 +1843,8 @@ def main(args):
|
|||||||
height_1st,
|
height_1st,
|
||||||
original_width_1st,
|
original_width_1st,
|
||||||
original_height_1st,
|
original_height_1st,
|
||||||
|
original_width_negative_1st,
|
||||||
|
original_height_negative_1st,
|
||||||
crop_left_1st,
|
crop_left_1st,
|
||||||
crop_top_1st,
|
crop_top_1st,
|
||||||
args.highres_fix_steps,
|
args.highres_fix_steps,
|
||||||
@@ -1897,6 +1912,8 @@ def main(args):
|
|||||||
height,
|
height,
|
||||||
original_width,
|
original_width,
|
||||||
original_height,
|
original_height,
|
||||||
|
original_width_negative,
|
||||||
|
original_height_negative,
|
||||||
crop_left,
|
crop_left,
|
||||||
crop_top,
|
crop_top,
|
||||||
steps,
|
steps,
|
||||||
@@ -2020,6 +2037,8 @@ def main(args):
|
|||||||
width,
|
width,
|
||||||
original_height,
|
original_height,
|
||||||
original_width,
|
original_width,
|
||||||
|
original_height_negative,
|
||||||
|
original_width_negative,
|
||||||
crop_top,
|
crop_top,
|
||||||
crop_left,
|
crop_left,
|
||||||
steps,
|
steps,
|
||||||
@@ -2060,6 +2079,8 @@ def main(args):
|
|||||||
metadata.add_text("clip-prompt", clip_prompt)
|
metadata.add_text("clip-prompt", clip_prompt)
|
||||||
metadata.add_text("original-height", str(original_height))
|
metadata.add_text("original-height", str(original_height))
|
||||||
metadata.add_text("original-width", str(original_width))
|
metadata.add_text("original-width", str(original_width))
|
||||||
|
metadata.add_text("original-height-negative", str(original_height_negative))
|
||||||
|
metadata.add_text("original-width-negative", str(original_width_negative))
|
||||||
metadata.add_text("crop-top", str(crop_top))
|
metadata.add_text("crop-top", str(crop_top))
|
||||||
metadata.add_text("crop-left", str(crop_left))
|
metadata.add_text("crop-left", str(crop_left))
|
||||||
|
|
||||||
@@ -2123,6 +2144,8 @@ def main(args):
|
|||||||
height = args.H
|
height = args.H
|
||||||
original_width = args.original_width
|
original_width = args.original_width
|
||||||
original_height = args.original_height
|
original_height = args.original_height
|
||||||
|
original_width_negative = args.original_width_negative
|
||||||
|
original_height_negative = args.original_height_negative
|
||||||
crop_top = args.crop_top
|
crop_top = args.crop_top
|
||||||
crop_left = args.crop_left
|
crop_left = args.crop_left
|
||||||
scale = args.scale
|
scale = args.scale
|
||||||
@@ -2165,6 +2188,18 @@ def main(args):
|
|||||||
print(f"original height: {original_height}")
|
print(f"original height: {original_height}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"nw (\d+)", parg, re.IGNORECASE)
|
||||||
|
if m:
|
||||||
|
original_width_negative = int(m.group(1))
|
||||||
|
print(f"original width negative: {original_width_negative}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"nh (\d+)", parg, re.IGNORECASE)
|
||||||
|
if m:
|
||||||
|
original_height_negative = int(m.group(1))
|
||||||
|
print(f"original height negative: {original_height_negative}")
|
||||||
|
continue
|
||||||
|
|
||||||
m = re.match(r"ct (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"ct (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
crop_top = int(m.group(1))
|
crop_top = int(m.group(1))
|
||||||
@@ -2301,6 +2336,8 @@ def main(args):
|
|||||||
height,
|
height,
|
||||||
original_width,
|
original_width,
|
||||||
original_height,
|
original_height,
|
||||||
|
original_width_negative,
|
||||||
|
original_height_negative,
|
||||||
crop_left,
|
crop_left,
|
||||||
crop_top,
|
crop_top,
|
||||||
steps,
|
steps,
|
||||||
@@ -2367,6 +2404,18 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値"
|
"--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--original_height_negative",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="original height for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal heightの値",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--original_width_negative",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値",
|
||||||
|
)
|
||||||
parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値")
|
parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値")
|
||||||
parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値")
|
parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値")
|
||||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
|
||||||
|
|||||||
Reference in New Issue
Block a user