mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 08:36:41 +00:00
make mask for flexible zero slicing from attncouple mask
This commit is contained in:
51
gen_img.py
51
gen_img.py
@@ -1927,18 +1927,6 @@ def main(args):
|
||||
)
|
||||
pipe.set_gradual_latent(gradual_latent)
|
||||
|
||||
# Flexible Zero Slicing
|
||||
if args.flexible_zero_slicing_mask:
|
||||
# mask 画像は背景 255、zero にする部分 0 とする
|
||||
print(f"loading Flexible Zero Slicing mask")
|
||||
fz_mask = Image.open(args.flexible_zero_slicing_mask).convert("RGB")
|
||||
fz_mask = np.array(fz_mask).astype(np.float32) / 255.0
|
||||
fz_mask = fz_mask[:, :, 0]
|
||||
fz_mask = torch.from_numpy(fz_mask).to(dtype).to(device)
|
||||
|
||||
# only for sdxl
|
||||
unet.set_flexible_zero_slicing(fz_mask, args.flexible_zero_slicing_depth, args.flexible_zero_slicing_timesteps)
|
||||
|
||||
# Textual Inversionを処理する
|
||||
if args.textual_inversion_embeddings:
|
||||
token_ids_embeds1 = []
|
||||
@@ -2131,6 +2119,33 @@ def main(args):
|
||||
l.extend([im] * args.images_per_prompt)
|
||||
mask_images = l
|
||||
|
||||
# Flexible Zero Slicing
|
||||
if args.flexible_zero_slicing_depth is not None:
|
||||
# CV2 が必要
|
||||
import cv2
|
||||
|
||||
# mask 画像は背景 255、zero にする部分 0 とする
|
||||
np_mask = np.array(mask_images[0].convert("RGB"))
|
||||
fz_mask = np.full(np_mask.shape, 255, dtype=np.uint8)
|
||||
|
||||
# 各チャンネルに対して処理
|
||||
for i in range(3):
|
||||
# チャンネルを抽出
|
||||
channel = np_mask[:, :, i]
|
||||
|
||||
# 輪郭を検出
|
||||
contours, _ = cv2.findContours(channel, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
# 輪郭を新しい配列に描画
|
||||
cv2.drawContours(fz_mask, contours, -1, (0, 0, 0), 1)
|
||||
|
||||
fz_mask = fz_mask.astype(np.float32) / 255.0
|
||||
fz_mask = fz_mask[:, :, 0]
|
||||
fz_mask = torch.from_numpy(fz_mask).to(dtype).to(device)
|
||||
|
||||
# only for sdxl
|
||||
unet.set_flexible_zero_slicing(fz_mask, args.flexible_zero_slicing_depth, args.flexible_zero_slicing_timesteps)
|
||||
|
||||
# 画像サイズにオプション指定があるときはリサイズする
|
||||
if args.W is not None and args.H is not None:
|
||||
# highres fix を考慮に入れる
|
||||
@@ -3332,12 +3347,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--flexible_zero_slicing_mask",
|
||||
type=str,
|
||||
default=None,
|
||||
help="mask for flexible zero slicing / flexible zero slicingのマスク",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--flexible_zero_slicing_mask",
|
||||
# type=str,
|
||||
# default=None,
|
||||
# help="mask for flexible zero slicing / flexible zero slicingのマスク",
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--flexible_zero_slicing_depth",
|
||||
type=int,
|
||||
|
||||
Reference in New Issue
Block a user