mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
Support image2image mode for Flux
This commit is contained in:
@@ -137,10 +137,9 @@ def do_sample(
|
|||||||
l_pooled: torch.Tensor,
|
l_pooled: torch.Tensor,
|
||||||
t5_out: torch.Tensor,
|
t5_out: torch.Tensor,
|
||||||
txt_ids: torch.Tensor,
|
txt_ids: torch.Tensor,
|
||||||
num_steps: int,
|
timesteps: list[float],
|
||||||
guidance: float,
|
guidance: float,
|
||||||
t5_attn_mask: Optional[torch.Tensor],
|
t5_attn_mask: Optional[torch.Tensor],
|
||||||
is_schnell: bool,
|
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
flux_dtype: torch.dtype,
|
flux_dtype: torch.dtype,
|
||||||
neg_l_pooled: Optional[torch.Tensor] = None,
|
neg_l_pooled: Optional[torch.Tensor] = None,
|
||||||
@@ -148,8 +147,7 @@ def do_sample(
|
|||||||
neg_t5_attn_mask: Optional[torch.Tensor] = None,
|
neg_t5_attn_mask: Optional[torch.Tensor] = None,
|
||||||
cfg_scale: Optional[float] = None,
|
cfg_scale: Optional[float] = None,
|
||||||
):
|
):
|
||||||
logger.info(f"num_steps: {num_steps}")
|
logger.info(f"num_steps: {len(timesteps)}")
|
||||||
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
|
|
||||||
|
|
||||||
# denoise initial noise
|
# denoise initial noise
|
||||||
if accelerator:
|
if accelerator:
|
||||||
@@ -196,6 +194,7 @@ def generate_image(
|
|||||||
t5xxl,
|
t5xxl,
|
||||||
ae,
|
ae,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
image_path: Optional[str],
|
||||||
seed: Optional[int],
|
seed: Optional[int],
|
||||||
image_width: int,
|
image_width: int,
|
||||||
image_height: int,
|
image_height: int,
|
||||||
@@ -203,13 +202,18 @@ def generate_image(
|
|||||||
guidance: float,
|
guidance: float,
|
||||||
negative_prompt: Optional[str],
|
negative_prompt: Optional[str],
|
||||||
cfg_scale: float,
|
cfg_scale: float,
|
||||||
|
strength: float,
|
||||||
):
|
):
|
||||||
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
|
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
|
||||||
logger.info(f"Seed: {seed}")
|
logger.info(f"Seed: {seed}")
|
||||||
|
|
||||||
|
if steps is None:
|
||||||
|
steps = 4 if is_schnell else 50
|
||||||
|
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
|
||||||
|
timesteps = get_schedule(steps, packed_latent_height * packed_latent_width, shift=not is_schnell)
|
||||||
|
|
||||||
# make first noise with packed shape
|
# make first noise with packed shape
|
||||||
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
|
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
|
||||||
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
|
|
||||||
noise_dtype = torch.float32 if is_fp8(dtype) else dtype
|
noise_dtype = torch.float32 if is_fp8(dtype) else dtype
|
||||||
noise = torch.randn(
|
noise = torch.randn(
|
||||||
1,
|
1,
|
||||||
@@ -220,14 +224,21 @@ def generate_image(
|
|||||||
generator=torch.Generator(device=device).manual_seed(seed),
|
generator=torch.Generator(device=device).manual_seed(seed),
|
||||||
)
|
)
|
||||||
|
|
||||||
# prepare img and img ids
|
if image_path:
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
image = torch.tensor(np.array(image), device=device).permute(2, 0, 1).unsqueeze(0)
|
||||||
|
image = torch.nn.functional.interpolate(image, (image_height, image_width))
|
||||||
|
image = image / 255.0 * 2.0 - 1.0
|
||||||
|
image = image.to(device)
|
||||||
|
latents = ae.encode(image)
|
||||||
|
latents = flux_utils.pack_latents(latents)
|
||||||
|
|
||||||
# this is needed only for img2img
|
t_idx = int((1 - strength) * steps)
|
||||||
# img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
t = timesteps[t_idx]
|
||||||
# if img.shape[0] == 1 and bs > 1:
|
timesteps = timesteps[t_idx:]
|
||||||
# img = repeat(img, "1 ... -> bs ...", bs=bs)
|
|
||||||
|
noise = noise * t + latents * (1 - t)
|
||||||
|
|
||||||
# txt2img only needs img_ids
|
|
||||||
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
|
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
|
||||||
|
|
||||||
# prepare fp8 models
|
# prepare fp8 models
|
||||||
@@ -313,8 +324,6 @@ def generate_image(
|
|||||||
# generate image
|
# generate image
|
||||||
logger.info("Generating image...")
|
logger.info("Generating image...")
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
if steps is None:
|
|
||||||
steps = 4 if is_schnell else 50
|
|
||||||
|
|
||||||
img_ids = img_ids.to(device)
|
img_ids = img_ids.to(device)
|
||||||
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
|
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
|
||||||
@@ -327,10 +336,9 @@ def generate_image(
|
|||||||
l_pooled,
|
l_pooled,
|
||||||
t5_out,
|
t5_out,
|
||||||
txt_ids,
|
txt_ids,
|
||||||
steps,
|
timesteps,
|
||||||
guidance,
|
guidance,
|
||||||
t5_attn_mask,
|
t5_attn_mask,
|
||||||
is_schnell,
|
|
||||||
device,
|
device,
|
||||||
flux_dtype,
|
flux_dtype,
|
||||||
neg_l_pooled,
|
neg_l_pooled,
|
||||||
@@ -362,13 +370,13 @@ def generate_image(
|
|||||||
|
|
||||||
x = x.clamp(-1, 1)
|
x = x.clamp(-1, 1)
|
||||||
x = x.permute(0, 2, 3, 1)
|
x = x.permute(0, 2, 3, 1)
|
||||||
img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
||||||
|
|
||||||
# save image
|
# save image
|
||||||
output_dir = args.output_dir
|
output_dir = args.output_dir
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
|
output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
|
||||||
img.save(output_path)
|
image.save(output_path)
|
||||||
|
|
||||||
logger.info(f"Saved image to {output_path}")
|
logger.info(f"Saved image to {output_path}")
|
||||||
|
|
||||||
@@ -390,6 +398,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--ae", type=str, required=False)
|
parser.add_argument("--ae", type=str, required=False)
|
||||||
parser.add_argument("--apply_t5_attn_mask", action="store_true")
|
parser.add_argument("--apply_t5_attn_mask", action="store_true")
|
||||||
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
||||||
|
parser.add_argument("--image_path", type=str, default=None)
|
||||||
parser.add_argument("--output_dir", type=str, default=".")
|
parser.add_argument("--output_dir", type=str, default=".")
|
||||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
|
parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
|
||||||
parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
|
parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
|
||||||
@@ -401,6 +410,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--guidance", type=float, default=3.5)
|
parser.add_argument("--guidance", type=float, default=3.5)
|
||||||
parser.add_argument("--negative_prompt", type=str, default=None)
|
parser.add_argument("--negative_prompt", type=str, default=None)
|
||||||
parser.add_argument("--cfg_scale", type=float, default=1.0)
|
parser.add_argument("--cfg_scale", type=float, default=1.0)
|
||||||
|
parser.add_argument("--strength", type=float, default=0.8)
|
||||||
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
|
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora_weights",
|
"--lora_weights",
|
||||||
@@ -512,6 +522,7 @@ if __name__ == "__main__":
|
|||||||
t5xxl,
|
t5xxl,
|
||||||
ae,
|
ae,
|
||||||
args.prompt,
|
args.prompt,
|
||||||
|
args.image_path,
|
||||||
args.seed,
|
args.seed,
|
||||||
args.width,
|
args.width,
|
||||||
args.height,
|
args.height,
|
||||||
@@ -519,6 +530,7 @@ if __name__ == "__main__":
|
|||||||
args.guidance,
|
args.guidance,
|
||||||
args.negative_prompt,
|
args.negative_prompt,
|
||||||
args.cfg_scale,
|
args.cfg_scale,
|
||||||
|
args.strength,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# loop for interactive
|
# loop for interactive
|
||||||
@@ -527,11 +539,12 @@ if __name__ == "__main__":
|
|||||||
steps = None
|
steps = None
|
||||||
guidance = args.guidance
|
guidance = args.guidance
|
||||||
cfg_scale = args.cfg_scale
|
cfg_scale = args.cfg_scale
|
||||||
|
strength = args.strength
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
print(
|
print(
|
||||||
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
|
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --i <image_path> --r <strength> "
|
||||||
" --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
|
"--g <guidance> --m <multipliers for LoRA> --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
|
||||||
)
|
)
|
||||||
prompt = input()
|
prompt = input()
|
||||||
if prompt == "":
|
if prompt == "":
|
||||||
@@ -542,6 +555,7 @@ if __name__ == "__main__":
|
|||||||
prompt = options[0].strip()
|
prompt = options[0].strip()
|
||||||
seed = None
|
seed = None
|
||||||
negative_prompt = None
|
negative_prompt = None
|
||||||
|
image_path = None
|
||||||
for opt in options[1:]:
|
for opt in options[1:]:
|
||||||
try:
|
try:
|
||||||
opt = opt.strip()
|
opt = opt.strip()
|
||||||
@@ -553,6 +567,10 @@ if __name__ == "__main__":
|
|||||||
steps = int(opt[1:].strip())
|
steps = int(opt[1:].strip())
|
||||||
elif opt.startswith("d"):
|
elif opt.startswith("d"):
|
||||||
seed = int(opt[1:].strip())
|
seed = int(opt[1:].strip())
|
||||||
|
elif opt.startswith("i"):
|
||||||
|
image_path = opt[1:].strip()
|
||||||
|
elif opt.startswith("r"):
|
||||||
|
strength = float(opt[1:].strip())
|
||||||
elif opt.startswith("g"):
|
elif opt.startswith("g"):
|
||||||
guidance = float(opt[1:].strip())
|
guidance = float(opt[1:].strip())
|
||||||
elif opt.startswith("m"):
|
elif opt.startswith("m"):
|
||||||
@@ -571,6 +589,21 @@ if __name__ == "__main__":
|
|||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Invalid option: {opt}, {e}")
|
logger.error(f"Invalid option: {opt}, {e}")
|
||||||
|
|
||||||
generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)
|
generate_image(
|
||||||
|
model,
|
||||||
|
clip_l,
|
||||||
|
t5xxl,
|
||||||
|
ae,
|
||||||
|
prompt,
|
||||||
|
image_path,
|
||||||
|
seed,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
steps,
|
||||||
|
guidance,
|
||||||
|
negative_prompt,
|
||||||
|
cfg_scale,
|
||||||
|
strength,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Done!")
|
logger.info("Done!")
|
||||||
|
|||||||
Reference in New Issue
Block a user