diff --git a/gen_img.py b/gen_img.py index 4395b790..a00f8124 100644 --- a/gen_img.py +++ b/gen_img.py @@ -1927,18 +1927,6 @@ def main(args): ) pipe.set_gradual_latent(gradual_latent) - # Flexible Zero Slicing - if args.flexible_zero_slicing_mask: - # mask 画像は背景 255、zero にする部分 0 とする - print(f"loading Flexible Zero Slicing mask") - fz_mask = Image.open(args.flexible_zero_slicing_mask).convert("RGB") - fz_mask = np.array(fz_mask).astype(np.float32) / 255.0 - fz_mask = fz_mask[:, :, 0] - fz_mask = torch.from_numpy(fz_mask).to(dtype).to(device) - - # only for sdxl - unet.set_flexible_zero_slicing(fz_mask, args.flexible_zero_slicing_depth, args.flexible_zero_slicing_timesteps) - # Textual Inversionを処理する if args.textual_inversion_embeddings: token_ids_embeds1 = [] @@ -2131,6 +2119,33 @@ def main(args): l.extend([im] * args.images_per_prompt) mask_images = l + # Flexible Zero Slicing + if args.flexible_zero_slicing_depth is not None: + # CV2 が必要 + import cv2 + + # mask 画像は背景 255、zero にする部分 0 とする + np_mask = np.array(mask_images[0].convert("RGB")) + fz_mask = np.full(np_mask.shape, 255, dtype=np.uint8) + + # 各チャンネルに対して処理 + for i in range(3): + # チャンネルを抽出 + channel = np_mask[:, :, i] + + # 輪郭を検出 + contours, _ = cv2.findContours(channel, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + # 輪郭を新しい配列に描画 + cv2.drawContours(fz_mask, contours, -1, (0, 0, 0), 1) + + fz_mask = fz_mask.astype(np.float32) / 255.0 + fz_mask = fz_mask[:, :, 0] + fz_mask = torch.from_numpy(fz_mask).to(dtype).to(device) + + # only for sdxl + unet.set_flexible_zero_slicing(fz_mask, args.flexible_zero_slicing_depth, args.flexible_zero_slicing_timesteps) + # 画像サイズにオプション指定があるときはリサイズする if args.W is not None and args.H is not None: # highres fix を考慮に入れる @@ -3332,12 +3347,12 @@ def setup_parser() -> argparse.ArgumentParser: + " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨", ) - parser.add_argument( - "--flexible_zero_slicing_mask", - type=str, - default=None, - help="mask for flexible zero slicing / flexible zero slicingのマスク", - ) + # parser.add_argument( + # "--flexible_zero_slicing_mask", + # type=str, + # default=None, + # help="mask for flexible zero slicing / flexible zero slicingのマスク", + # ) parser.add_argument( "--flexible_zero_slicing_depth", type=int,