mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
impl flexible zero slicing
This commit is contained in:
50
gen_img.py
50
gen_img.py
@@ -1,5 +1,6 @@
|
||||
import itertools
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
|
||||
import glob
|
||||
import importlib
|
||||
@@ -20,7 +21,7 @@ import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
from library.device_utils import init_ipex, clean_memory, get_preferred_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -338,7 +339,7 @@ class PipelineLike:
|
||||
self.clip_vision_model: CLIPVisionModelWithProjection = None
|
||||
self.clip_vision_processor: CLIPImageProcessor = None
|
||||
self.clip_vision_strength = 0.0
|
||||
|
||||
|
||||
# Textual Inversion
|
||||
self.token_replacements_list = []
|
||||
for _ in range(len(self.text_encoders)):
|
||||
@@ -1926,6 +1927,18 @@ def main(args):
|
||||
)
|
||||
pipe.set_gradual_latent(gradual_latent)
|
||||
|
||||
# Flexible Zero Slicing
|
||||
if args.flexible_zero_slicing_mask:
|
||||
# mask image convention: background is 255, regions to be zeroed are 0
|
||||
print(f"loading Flexible Zero Slicing mask")
|
||||
fz_mask = Image.open(args.flexible_zero_slicing_mask).convert("RGB")
|
||||
fz_mask = np.array(fz_mask).astype(np.float32) / 255.0
|
||||
fz_mask = fz_mask[:, :, 0]
|
||||
fz_mask = torch.from_numpy(fz_mask).to(dtype).to(device)
|
||||
|
||||
# only for sdxl
|
||||
unet.set_flexible_zero_slicing(fz_mask, args.flexible_zero_slicing_depth, args.flexible_zero_slicing_timesteps)
|
||||
|
||||
# Process Textual Inversion embeddings
|
||||
if args.textual_inversion_embeddings:
|
||||
token_ids_embeds1 = []
|
||||
@@ -2146,10 +2159,17 @@ def main(args):
|
||||
|
||||
if args.network_regional_mask_max_color_codes:
|
||||
# Select the mask region by color code
|
||||
ch0 = (i + 1) & 1
|
||||
ch1 = ((i + 1) >> 1) & 1
|
||||
ch2 = ((i + 1) >> 2) & 1
|
||||
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
|
||||
# i < 7: 3-bit RGB code 1-7 (black is never used), channel values 0/255
|
||||
# i >= 7: the same 3-bit RGB codes 1-7, channel values 0/127
|
||||
code = (i % 7) + 1
|
||||
r = code & 1
|
||||
g = (code & 2) >> 1
|
||||
b = (code & 4) >> 2
|
||||
if i < 7:
|
||||
color = (r * 255, g * 255, b * 255)
|
||||
else:
|
||||
color = (r * 127, g * 127, b * 127)
|
||||
np_mask = np.all(np_mask == color, axis=2)
|
||||
np_mask = np_mask.astype(np.uint8) * 255
|
||||
else:
|
||||
np_mask = np_mask[:, :, i]
|
||||
@@ -3312,6 +3332,24 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--flexible_zero_slicing_mask",
|
||||
type=str,
|
||||
default=None,
|
||||
help="mask for flexible zero slicing / flexible zero slicingのマスク",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flexible_zero_slicing_depth",
|
||||
type=int,
|
||||
default=None,
|
||||
help="depth for flexible zero slicing / flexible zero slicingのdepth",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flexible_zero_slicing_timesteps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="timesteps for flexible zero slicing / flexible zero slicingのtimesteps",
|
||||
)
|
||||
# # parser.add_argument(
|
||||
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||
# )
|
||||
|
||||
Reference in New Issue
Block a user