Merge branch 'dev' into gradual_latent_hires_fix

This commit is contained in:
Kohya S
2024-02-12 12:59:25 +09:00
62 changed files with 1428 additions and 993 deletions

View File

@@ -105,6 +105,10 @@ from library.original_unet import FlashAttentionFunction
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# scheduler:
SCHEDULER_LINEAR_START = 0.00085
@@ -140,12 +144,12 @@ USE_CUTOUTS = False
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
print("Enable memory efficient attention for U-Net")
logger.info("Enable memory efficient attention for U-Net")
# これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い
unet.set_use_memory_efficient_attention(False, True)
elif xformers:
print("Enable xformers for U-Net")
logger.info("Enable xformers for U-Net")
try:
import xformers.ops
except ImportError:
@@ -153,7 +157,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
unet.set_use_memory_efficient_attention(True, False)
elif sdpa:
print("Enable SDPA for U-Net")
logger.info("Enable SDPA for U-Net")
unet.set_use_memory_efficient_attention(False, False)
unet.set_use_sdpa(True)
@@ -169,7 +173,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
def replace_vae_attn_to_memory_efficient():
print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
flash_func = FlashAttentionFunction
def forward_flash_attn(self, hidden_states, **kwargs):
@@ -225,7 +229,7 @@ def replace_vae_attn_to_memory_efficient():
def replace_vae_attn_to_xformers():
print("VAE: Attention.forward has been replaced to xformers")
logger.info("VAE: Attention.forward has been replaced to xformers")
import xformers.ops
def forward_xformers(self, hidden_states, **kwargs):
@@ -281,7 +285,7 @@ def replace_vae_attn_to_xformers():
def replace_vae_attn_to_sdpa():
print("VAE: Attention.forward has been replaced to sdpa")
logger.info("VAE: Attention.forward has been replaced to sdpa")
def forward_sdpa(self, hidden_states, **kwargs):
residual = hidden_states
@@ -695,7 +699,7 @@ class PipelineLike:
do_classifier_free_guidance = guidance_scale > 1.0
if not do_classifier_free_guidance and negative_scale is not None:
print(f"negative_scale is ignored if guidance scalle <= 1.0")
logger.warning(f"negative_scale is ignored if guidance scale <= 1.0")
negative_scale = None
# get unconditional embeddings for classifier free guidance
@@ -777,11 +781,11 @@ class PipelineLike:
clip_text_input = prompt_tokens
if clip_text_input.shape[1] > self.tokenizer.model_max_length:
# TODO 75文字を超えたら警告を出す
print("trim text input", clip_text_input.shape)
logger.info(f"trim text input {clip_text_input.shape}")
clip_text_input = torch.cat(
[clip_text_input[:, : self.tokenizer.model_max_length - 1], clip_text_input[:, -1].unsqueeze(1)], dim=1
)
print("trimmed", clip_text_input.shape)
logger.info(f"trimmed {clip_text_input.shape}")
for i, clip_prompt in enumerate(clip_prompts):
if clip_prompt is not None: # clip_promptがあれば上書きする
@@ -1752,7 +1756,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
if word.strip() == "BREAK":
# pad until next multiple of tokenizer's max token length
pad_len = pipe.tokenizer.model_max_length - (len(text_token) % pipe.tokenizer.model_max_length)
print(f"BREAK pad_len: {pad_len}")
logger.info(f"BREAK pad_len: {pad_len}")
for i in range(pad_len):
# v2のときEOSをつけるべきかどうかわからないぜ
# if i == 0:
@@ -1782,7 +1786,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights
@@ -2094,7 +2098,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
elif len(count_range) == 2:
count_range = [int(count_range[0]), int(count_range[1])]
else:
print(f"invalid count range: {count_range}")
logger.warning(f"invalid count range: {count_range}")
count_range = [1, 1]
if count_range[0] > count_range[1]:
count_range = [count_range[1], count_range[0]]
@@ -2164,7 +2168,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
# def load_clip_l14_336(dtype):
# print(f"loading CLIP: {CLIP_ID_L14_336}")
# logger.info(f"loading CLIP: {CLIP_ID_L14_336}")
# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype)
# return text_encoder
@@ -2212,9 +2216,9 @@ def main(args):
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
if args.v_parameterization and not args.v2:
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
if args.v2 and args.clip_skip is not None:
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
# モデルを読み込む
if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う
@@ -2224,10 +2228,10 @@ def main(args):
use_stable_diffusion_format = os.path.isfile(args.ckpt)
if use_stable_diffusion_format:
print("load StableDiffusion checkpoint")
logger.info("load StableDiffusion checkpoint")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
else:
print("load Diffusers pretrained models")
logger.info("load Diffusers pretrained models")
loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
text_encoder = loading_pipe.text_encoder
vae = loading_pipe.vae
@@ -2250,21 +2254,21 @@ def main(args):
# VAEを読み込む
if args.vae is not None:
vae = model_util.load_vae(args.vae, dtype)
print("additional VAE loaded")
logger.info("additional VAE loaded")
# # 置換するCLIPを読み込む
# if args.replace_clip_l14_336:
# text_encoder = load_clip_l14_336(dtype)
# print(f"large clip {CLIP_ID_L14_336} is loaded")
# logger.info(f"large clip {CLIP_ID_L14_336} is loaded")
if args.clip_guidance_scale > 0.0 or args.clip_image_guidance_scale:
print("prepare clip model")
logger.info("prepare clip model")
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_PATH, torch_dtype=dtype)
else:
clip_model = None
if args.vgg16_guidance_scale > 0.0:
print("prepare resnet model")
logger.info("prepare resnet model")
vgg16_model = torchvision.models.vgg16(torchvision.models.VGG16_Weights.IMAGENET1K_V1)
else:
vgg16_model = None
@@ -2276,7 +2280,7 @@ def main(args):
replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)
# tokenizerを読み込む
print("loading tokenizer")
logger.info("loading tokenizer")
if use_stable_diffusion_format:
tokenizer = train_util.load_tokenizer(args)
@@ -2335,7 +2339,7 @@ def main(args):
self.sampler_noises = noises
def randn(self, shape, device=None, dtype=None, layout=None, generator=None):
# print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index)
# logger.info(f"replacing {shape} {len(self.sampler_noises)} {self.sampler_noise_index}")
if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises):
noise = self.sampler_noises[self.sampler_noise_index]
if shape != noise.shape:
@@ -2344,7 +2348,7 @@ def main(args):
noise = None
if noise == None:
print(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)
self.sampler_noise_index += 1
@@ -2375,7 +2379,7 @@ def main(args):
# clip_sample=Trueにする
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
print("set clip_sample to True")
logger.info("set clip_sample to True")
scheduler.config.clip_sample = True
# deviceを決定する
@@ -2432,7 +2436,7 @@ def main(args):
network_merge = 0
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
logger.info(f"import network module: {network_module}")
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
@@ -2450,7 +2454,7 @@ def main(args):
raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
logger.info(f"load network weights from: {network_weight}")
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
@@ -2458,7 +2462,7 @@ def main(args):
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
logger.info(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
@@ -2468,20 +2472,20 @@ def main(args):
mergeable = network.is_mergeable()
if network_merge and not mergeable:
print("network is not mergiable. ignore merge option.")
logger.warning("network is not mergeable. ignore merge option.")
if not mergeable or i >= network_merge:
# not merging
network.apply_to(text_encoder, unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}")
logger.info(f"weights are loaded: {info}")
if args.opt_channels_last:
network.to(memory_format=torch.channels_last)
network.to(dtype).to(device)
if network_pre_calc:
print("backup original weights")
logger.info("backup original weights")
network.backup_weights()
networks.append(network)
@@ -2495,7 +2499,7 @@ def main(args):
# upscalerの指定があれば取得する
upscaler = None
if args.highres_fix_upscaler:
print("import upscaler module:", args.highres_fix_upscaler)
logger.info(f"import upscaler module: {args.highres_fix_upscaler}")
imported_module = importlib.import_module(args.highres_fix_upscaler)
us_kwargs = {}
@@ -2504,7 +2508,7 @@ def main(args):
key, value = net_arg.split("=")
us_kwargs[key] = value
print("create upscaler")
logger.info("create upscaler")
upscaler = imported_module.create_upscaler(**us_kwargs)
upscaler.to(dtype).to(device)
@@ -2521,7 +2525,7 @@ def main(args):
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
if args.opt_channels_last:
print(f"set optimizing: channels last")
logger.info(f"set optimizing: channels last")
text_encoder.to(memory_format=torch.channels_last)
vae.to(memory_format=torch.channels_last)
unet.to(memory_format=torch.channels_last)
@@ -2553,7 +2557,7 @@ def main(args):
args.vgg16_guidance_layer,
)
pipe.set_control_nets(control_nets)
print("pipeline is ready.")
logger.info("pipeline is ready.")
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()
@@ -2619,7 +2623,7 @@ def main(args):
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
assert (
min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1
), f"token ids is not ordered"
@@ -2678,7 +2682,7 @@ def main(args):
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
logger.info(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
# if num_vectors_per_token > 1:
pipe.add_token_replacement(token_ids[0], token_ids)
@@ -2703,7 +2707,7 @@ def main(args):
# promptを取得する
if args.from_file is not None:
print(f"reading prompts from {args.from_file}")
logger.info(f"reading prompts from {args.from_file}")
with open(args.from_file, "r", encoding="utf-8") as f:
prompt_list = f.read().splitlines()
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
@@ -2732,7 +2736,7 @@ def main(args):
for p in paths:
image = Image.open(p)
if image.mode != "RGB":
print(f"convert image to RGB from {image.mode}: {p}")
logger.info(f"convert image to RGB from {image.mode}: {p}")
image = image.convert("RGB")
images.append(image)
@@ -2748,24 +2752,24 @@ def main(args):
return resized
if args.image_path is not None:
print(f"load image for img2img: {args.image_path}")
logger.info(f"load image for img2img: {args.image_path}")
init_images = load_images(args.image_path)
assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
print(f"loaded {len(init_images)} images for img2img")
logger.info(f"loaded {len(init_images)} images for img2img")
else:
init_images = None
if args.mask_path is not None:
print(f"load mask for inpainting: {args.mask_path}")
logger.info(f"load mask for inpainting: {args.mask_path}")
mask_images = load_images(args.mask_path)
assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}"
print(f"loaded {len(mask_images)} mask images for inpainting")
logger.info(f"loaded {len(mask_images)} mask images for inpainting")
else:
mask_images = None
# promptがないとき、画像のPngInfoから取得する
if init_images is not None and len(prompt_list) == 0 and not args.interactive:
print("get prompts from images' meta data")
logger.info("get prompts from images' meta data")
for img in init_images:
if "prompt" in img.text:
prompt = img.text["prompt"]
@@ -2794,17 +2798,17 @@ def main(args):
h = int(h * args.highres_fix_scale + 0.5)
if init_images is not None:
print(f"resize img2img source images to {w}*{h}")
logger.info(f"resize img2img source images to {w}*{h}")
init_images = resize_images(init_images, (w, h))
if mask_images is not None:
print(f"resize img2img mask images to {w}*{h}")
logger.info(f"resize img2img mask images to {w}*{h}")
mask_images = resize_images(mask_images, (w, h))
regional_network = False
if networks and mask_images:
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
regional_network = True
print("use mask as region")
logger.info("use mask as region")
size = None
for i, network in enumerate(networks):
@@ -2829,14 +2833,14 @@ def main(args):
prev_image = None # for VGG16 guided
if args.guide_image_path is not None:
print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
logger.info(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
guide_images = []
for p in args.guide_image_path:
guide_images.extend(load_images(p))
print(f"loaded {len(guide_images)} guide images for guidance")
logger.info(f"loaded {len(guide_images)} guide images for guidance")
if len(guide_images) == 0:
print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
logger.info(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
guide_images = None
else:
guide_images = None
@@ -2862,7 +2866,7 @@ def main(args):
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
for gen_iter in range(args.n_iter):
print(f"iteration {gen_iter+1}/{args.n_iter}")
logger.info(f"iteration {gen_iter+1}/{args.n_iter}")
iter_seed = random.randint(0, 0x7FFFFFFF)
# shuffle prompt list
@@ -2878,7 +2882,7 @@ def main(args):
# 1st stageのバッチを作成して呼び出すサイズを小さくして呼び出す
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
print("process 1st stage")
logger.info("process 1st stage")
batch_1st = []
for _, base, ext in batch:
width_1st = int(ext.width * args.highres_fix_scale + 0.5)
@@ -2904,7 +2908,7 @@ def main(args):
images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する
print("process 2nd stage")
logger.info("process 2nd stage")
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
if upscaler:
@@ -3061,7 +3065,7 @@ def main(args):
n.restore_weights()
for n in networks:
n.pre_calculation()
print("pre-calculation... done")
logger.info("pre-calculation... done")
images = pipe(
prompts,
@@ -3130,7 +3134,7 @@ def main(args):
cv2.waitKey()
cv2.destroyAllWindows()
except ImportError:
print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません")
logger.info("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません")
return images
@@ -3143,7 +3147,8 @@ def main(args):
# interactive
valid = False
while not valid:
print("\nType prompt:")
logger.info("")
logger.info("Type prompt:")
try:
raw_prompt = input()
except EOFError:
@@ -3194,38 +3199,38 @@ def main(args):
prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
for parg in prompt_args[1:]:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
print(f"width: {width}")
logger.info(f"width: {width}")
continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
print(f"height: {height}")
logger.info(f"height: {height}")
continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
steps = max(1, min(1000, int(m.group(1))))
print(f"steps: {steps}")
logger.info(f"steps: {steps}")
continue
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
if m: # seed
seeds = [int(d) for d in m.group(1).split(",")]
print(f"seeds: {seeds}")
logger.info(f"seeds: {seeds}")
continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
scale = float(m.group(1))
print(f"scale: {scale}")
logger.info(f"scale: {scale}")
continue
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
@@ -3234,25 +3239,25 @@ def main(args):
negative_scale = None
else:
negative_scale = float(m.group(1))
print(f"negative scale: {negative_scale}")
logger.info(f"negative scale: {negative_scale}")
continue
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
if m: # strength
strength = float(m.group(1))
print(f"strength: {strength}")
logger.info(f"strength: {strength}")
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
print(f"negative prompt: {negative_prompt}")
logger.info(f"negative prompt: {negative_prompt}")
continue
m = re.match(r"c (.+)", parg, re.IGNORECASE)
if m: # clip prompt
clip_prompt = m.group(1)
print(f"clip prompt: {clip_prompt}")
logger.info(f"clip prompt: {clip_prompt}")
continue
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
@@ -3260,42 +3265,42 @@ def main(args):
network_muls = [float(v) for v in m.group(1).split(",")]
while len(network_muls) < len(networks):
network_muls.append(network_muls[-1])
print(f"network mul: {network_muls}")
logger.info(f"network mul: {network_muls}")
continue
# Deep Shrink
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 1
ds_depth_1 = int(m.group(1))
print(f"deep shrink depth 1: {ds_depth_1}")
logger.info(f"deep shrink depth 1: {ds_depth_1}")
continue
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 1
ds_timesteps_1 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}")
continue
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 2
ds_depth_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink depth 2: {ds_depth_2}")
logger.info(f"deep shrink depth 2: {ds_depth_2}")
continue
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 2
ds_timesteps_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}")
continue
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink ratio
ds_ratio = float(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink ratio: {ds_ratio}")
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
logger.info(f"deep shrink ratio: {ds_ratio}")
continue
# Gradual Latent
@@ -3341,8 +3346,8 @@ def main(args):
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
logger.info(f"Exception in parsing / 解析エラー: {parg}")
logger.info(ex)
# override Deep Shrink
if ds_depth_1 is not None:
@@ -3385,7 +3390,7 @@ def main(args):
if len(predefined_seeds) > 0:
seed = predefined_seeds.pop(0)
else:
print("predefined seeds are exhausted")
logger.info("predefined seeds are exhausted")
seed = None
elif args.iter_same_seed:
seed = iter_seed
@@ -3395,7 +3400,7 @@ def main(args):
if seed is None:
seed = random.randint(0, 0x7FFFFFFF)
if args.interactive:
print(f"seed: {seed}")
logger.info(f"seed: {seed}")
# prepare init image, guide image and mask
init_image = mask_image = guide_image = None
@@ -3411,7 +3416,7 @@ def main(args):
width = width - width % 32
height = height - height % 32
if width != init_image.size[0] or height != init_image.size[1]:
print(
logger.info(
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
)
@@ -3427,9 +3432,9 @@ def main(args):
guide_image = guide_images[global_step % len(guide_images)]
elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
if prev_image is None:
print("Generate 1st image without guide image.")
logger.info("Generate 1st image without guide image.")
else:
print("Use previous image as guide image.")
logger.info("Use previous image as guide image.")
guide_image = prev_image
if regional_network:
@@ -3473,7 +3478,7 @@ def main(args):
process_batch(batch_data, highres_fix)
batch_data.clear()
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser: