mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix network_merge, add regional mask as color code
This commit is contained in:
@@ -2372,7 +2372,7 @@ def main(args):
|
|||||||
elif args.network_merge_n_models:
|
elif args.network_merge_n_models:
|
||||||
network_merge = args.network_merge_n_models
|
network_merge = args.network_merge_n_models
|
||||||
else:
|
else:
|
||||||
network_merge = None
|
network_merge = 0
|
||||||
|
|
||||||
for i, network_module in enumerate(args.network_module):
|
for i, network_module in enumerate(args.network_module):
|
||||||
print("import network module:", network_module)
|
print("import network module:", network_module)
|
||||||
@@ -2724,8 +2724,17 @@ def main(args):
|
|||||||
|
|
||||||
size = None
|
size = None
|
||||||
for i, network in enumerate(networks):
|
for i, network in enumerate(networks):
|
||||||
if i < 3:
|
if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes:
|
||||||
np_mask = np.array(mask_images[0])
|
np_mask = np.array(mask_images[0])
|
||||||
|
|
||||||
|
if args.network_regional_mask_max_color_codes:
|
||||||
|
# カラーコードでマスクを指定する
|
||||||
|
ch0 = (i + 1) & 1
|
||||||
|
ch1 = ((i + 1) >> 1) & 1
|
||||||
|
ch2 = ((i + 1) >> 2) & 1
|
||||||
|
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
|
||||||
|
np_mask = np_mask.astype(np.uint8) * 255
|
||||||
|
else:
|
||||||
np_mask = np_mask[:, :, i]
|
np_mask = np_mask[:, :, i]
|
||||||
size = np_mask.shape
|
size = np_mask.shape
|
||||||
else:
|
else:
|
||||||
@@ -3386,6 +3395,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
|
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--network_regional_mask_max_color_codes",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--textual_inversion_embeddings",
|
"--textual_inversion_embeddings",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -1543,7 +1543,7 @@ def main(args):
|
|||||||
elif args.network_merge_n_models:
|
elif args.network_merge_n_models:
|
||||||
network_merge = args.network_merge_n_models
|
network_merge = args.network_merge_n_models
|
||||||
else:
|
else:
|
||||||
network_merge = None
|
network_merge = 0
|
||||||
print(f"network_merge: {network_merge}")
|
print(f"network_merge: {network_merge}")
|
||||||
|
|
||||||
for i, network_module in enumerate(args.network_module):
|
for i, network_module in enumerate(args.network_module):
|
||||||
@@ -1877,8 +1877,17 @@ def main(args):
|
|||||||
|
|
||||||
size = None
|
size = None
|
||||||
for i, network in enumerate(networks):
|
for i, network in enumerate(networks):
|
||||||
if i < 3:
|
if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes:
|
||||||
np_mask = np.array(mask_images[0])
|
np_mask = np.array(mask_images[0])
|
||||||
|
|
||||||
|
if args.network_regional_mask_max_color_codes:
|
||||||
|
# カラーコードでマスクを指定する
|
||||||
|
ch0 = (i + 1) & 1
|
||||||
|
ch1 = ((i + 1) >> 1) & 1
|
||||||
|
ch2 = ((i + 1) >> 2) & 1
|
||||||
|
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
|
||||||
|
np_mask = np_mask.astype(np.uint8) * 255
|
||||||
|
else:
|
||||||
np_mask = np_mask[:, :, i]
|
np_mask = np_mask[:, :, i]
|
||||||
size = np_mask.shape
|
size = np_mask.shape
|
||||||
else:
|
else:
|
||||||
@@ -2635,6 +2644,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
|
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--network_regional_mask_max_color_codes",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--textual_inversion_embeddings",
|
"--textual_inversion_embeddings",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
Reference in New Issue
Block a user