fix network_merge, add regional mask as color code

This commit is contained in:
Kohya S
2023-10-09 23:07:14 +09:00
parent 23ae358e0f
commit 3e81bd6b67
2 changed files with 36 additions and 6 deletions

View File

@@ -2372,7 +2372,7 @@ def main(args):
elif args.network_merge_n_models:
network_merge = args.network_merge_n_models
else:
network_merge = None
network_merge = 0
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
@@ -2724,9 +2724,18 @@ def main(args):
size = None
for i, network in enumerate(networks):
if i < 3:
if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes:
np_mask = np.array(mask_images[0])
np_mask = np_mask[:, :, i]
if args.network_regional_mask_max_color_codes:
# カラーコードでマスクを指定する
ch0 = (i + 1) & 1
ch1 = ((i + 1) >> 1) & 1
ch2 = ((i + 1) >> 2) & 1
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
np_mask = np_mask.astype(np.uint8) * 255
else:
np_mask = np_mask[:, :, i]
size = np_mask.shape
else:
np_mask = np.full(size, 255, dtype=np.uint8)
@@ -3386,6 +3395,12 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
)
parser.add_argument(
"--network_regional_mask_max_color_codes",
type=int,
default=None,
help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数デフォルトはNoneでチャンネルごとのマスク",
)
parser.add_argument(
"--textual_inversion_embeddings",
type=str,