This commit is contained in:
Grigory Reznikov
2026-02-19 00:02:21 +01:00
committed by GitHub

View File

@@ -145,10 +145,9 @@ def do_sample(
l_pooled: Optional[torch.Tensor], l_pooled: Optional[torch.Tensor],
t5_out: torch.Tensor, t5_out: torch.Tensor,
txt_ids: torch.Tensor, txt_ids: torch.Tensor,
num_steps: int, timesteps: list[float],
guidance: float, guidance: float,
t5_attn_mask: Optional[torch.Tensor], t5_attn_mask: Optional[torch.Tensor],
is_schnell: bool,
device: torch.device, device: torch.device,
flux_dtype: torch.dtype, flux_dtype: torch.dtype,
neg_l_pooled: Optional[torch.Tensor] = None, neg_l_pooled: Optional[torch.Tensor] = None,
@@ -156,8 +155,7 @@ def do_sample(
neg_t5_attn_mask: Optional[torch.Tensor] = None, neg_t5_attn_mask: Optional[torch.Tensor] = None,
cfg_scale: Optional[float] = None, cfg_scale: Optional[float] = None,
): ):
logger.info(f"num_steps: {num_steps}") logger.info(f"num_steps: {len(timesteps)}")
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
# denoise initial noise # denoise initial noise
if accelerator: if accelerator:
@@ -204,6 +202,7 @@ def generate_image(
t5xxl, t5xxl,
ae, ae,
prompt: str, prompt: str,
image_path: Optional[str],
seed: Optional[int], seed: Optional[int],
image_width: int, image_width: int,
image_height: int, image_height: int,
@@ -211,13 +210,18 @@ def generate_image(
guidance: float, guidance: float,
negative_prompt: Optional[str], negative_prompt: Optional[str],
cfg_scale: float, cfg_scale: float,
strength: float,
): ):
seed = seed if seed is not None else random.randint(0, 2**32 - 1) seed = seed if seed is not None else random.randint(0, 2**32 - 1)
logger.info(f"Seed: {seed}") logger.info(f"Seed: {seed}")
if steps is None:
steps = 4 if is_schnell else 50
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
timesteps = get_schedule(steps, packed_latent_height * packed_latent_width, shift=not is_schnell)
# make first noise with packed shape # make first noise with packed shape
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2 # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
noise_dtype = torch.float32 if is_fp8(dtype) else dtype noise_dtype = torch.float32 if is_fp8(dtype) else dtype
noise = torch.randn( noise = torch.randn(
1, 1,
@@ -228,14 +232,21 @@ def generate_image(
generator=torch.Generator(device=device).manual_seed(seed), generator=torch.Generator(device=device).manual_seed(seed),
) )
# prepare img and img ids if image_path:
image = Image.open(image_path).convert("RGB")
image = torch.tensor(np.array(image), device=device).permute(2, 0, 1).unsqueeze(0)
image = torch.nn.functional.interpolate(image, (image_height, image_width))
image = image / 255.0 * 2.0 - 1.0
image = image.to(device)
latents = ae.encode(image)
latents = flux_utils.pack_latents(latents)
# this is needed only for img2img t_idx = int((1 - strength) * steps)
# img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) t = timesteps[t_idx]
# if img.shape[0] == 1 and bs > 1: timesteps = timesteps[t_idx:]
# img = repeat(img, "1 ... -> bs ...", bs=bs)
noise = noise * t + latents * (1 - t)
# txt2img only needs img_ids
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width) img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
# prepare fp8 models # prepare fp8 models
@@ -326,8 +337,6 @@ def generate_image(
# generate image # generate image
logger.info("Generating image...") logger.info("Generating image...")
model = model.to(device) model = model.to(device)
if steps is None:
steps = 4 if is_schnell else 50
img_ids = img_ids.to(device) img_ids = img_ids.to(device)
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
@@ -341,10 +350,9 @@ def generate_image(
l_pooled, l_pooled,
t5_out, t5_out,
txt_ids, txt_ids,
steps, timesteps,
guidance, guidance,
t5_attn_mask, t5_attn_mask,
is_schnell,
device, device,
flux_dtype, flux_dtype,
neg_l_pooled, neg_l_pooled,
@@ -376,13 +384,13 @@ def generate_image(
x = x.clamp(-1, 1) x = x.clamp(-1, 1)
x = x.permute(0, 2, 3, 1) x = x.permute(0, 2, 3, 1)
img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
# save image # save image
output_dir = args.output_dir output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
img.save(output_path) image.save(output_path)
logger.info(f"Saved image to {output_path}") logger.info(f"Saved image to {output_path}")
@@ -405,6 +413,7 @@ if __name__ == "__main__":
parser.add_argument("--ae", type=str, required=False) parser.add_argument("--ae", type=str, required=False)
parser.add_argument("--apply_t5_attn_mask", action="store_true") parser.add_argument("--apply_t5_attn_mask", action="store_true")
parser.add_argument("--prompt", type=str, default="A photo of a cat") parser.add_argument("--prompt", type=str, default="A photo of a cat")
parser.add_argument("--image_path", type=str, default=None)
parser.add_argument("--output_dir", type=str, default=".") parser.add_argument("--output_dir", type=str, default=".")
parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype") parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l") parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
@@ -416,6 +425,7 @@ if __name__ == "__main__":
parser.add_argument("--guidance", type=float, default=3.5) parser.add_argument("--guidance", type=float, default=3.5)
parser.add_argument("--negative_prompt", type=str, default=None) parser.add_argument("--negative_prompt", type=str, default=None)
parser.add_argument("--cfg_scale", type=float, default=1.0) parser.add_argument("--cfg_scale", type=float, default=1.0)
parser.add_argument("--strength", type=float, default=0.8)
parser.add_argument("--offload", action="store_true", help="Offload to CPU") parser.add_argument("--offload", action="store_true", help="Offload to CPU")
parser.add_argument( parser.add_argument(
"--lora_weights", "--lora_weights",
@@ -532,6 +542,7 @@ if __name__ == "__main__":
t5xxl, t5xxl,
ae, ae,
args.prompt, args.prompt,
args.image_path,
args.seed, args.seed,
args.width, args.width,
args.height, args.height,
@@ -539,6 +550,7 @@ if __name__ == "__main__":
args.guidance, args.guidance,
args.negative_prompt, args.negative_prompt,
args.cfg_scale, args.cfg_scale,
args.strength,
) )
else: else:
# loop for interactive # loop for interactive
@@ -547,11 +559,12 @@ if __name__ == "__main__":
steps = None steps = None
guidance = args.guidance guidance = args.guidance
cfg_scale = args.cfg_scale cfg_scale = args.cfg_scale
strength = args.strength
while True: while True:
print( print(
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>" "Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --i <image_path> --r <strength> "
" --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>" "--g <guidance> --m <multipliers for LoRA> --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
) )
prompt = input() prompt = input()
if prompt == "": if prompt == "":
@@ -562,6 +575,7 @@ if __name__ == "__main__":
prompt = options[0].strip() prompt = options[0].strip()
seed = None seed = None
negative_prompt = None negative_prompt = None
image_path = None
for opt in options[1:]: for opt in options[1:]:
try: try:
opt = opt.strip() opt = opt.strip()
@@ -573,6 +587,10 @@ if __name__ == "__main__":
steps = int(opt[1:].strip()) steps = int(opt[1:].strip())
elif opt.startswith("d"): elif opt.startswith("d"):
seed = int(opt[1:].strip()) seed = int(opt[1:].strip())
elif opt.startswith("i"):
image_path = opt[1:].strip()
elif opt.startswith("r"):
strength = float(opt[1:].strip())
elif opt.startswith("g"): elif opt.startswith("g"):
guidance = float(opt[1:].strip()) guidance = float(opt[1:].strip())
elif opt.startswith("m"): elif opt.startswith("m"):
@@ -591,6 +609,21 @@ if __name__ == "__main__":
except ValueError as e: except ValueError as e:
logger.error(f"Invalid option: {opt}, {e}") logger.error(f"Invalid option: {opt}, {e}")
generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale) generate_image(
model,
clip_l,
t5xxl,
ae,
prompt,
image_path,
seed,
width,
height,
steps,
guidance,
negative_prompt,
cfg_scale,
strength,
)
logger.info("Done!") logger.info("Done!")