mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Replaced print with logger
This commit is contained in:
194
gen_img.py
194
gen_img.py
@@ -61,6 +61,12 @@ from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
|
|||||||
from library.original_unet import FlashAttentionFunction
|
from library.original_unet import FlashAttentionFunction
|
||||||
from networks.control_net_lllite import ControlNetLLLite
|
from networks.control_net_lllite import ControlNetLLLite
|
||||||
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
|
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
|
||||||
|
from library.utils import setup_logging, add_logging_arguments
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# scheduler:
|
# scheduler:
|
||||||
SCHEDULER_LINEAR_START = 0.00085
|
SCHEDULER_LINEAR_START = 0.00085
|
||||||
@@ -82,12 +88,12 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
|||||||
|
|
||||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
|
||||||
if mem_eff_attn:
|
if mem_eff_attn:
|
||||||
print("Enable memory efficient attention for U-Net")
|
logger.info("Enable memory efficient attention for U-Net")
|
||||||
|
|
||||||
# これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い
|
# これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い
|
||||||
unet.set_use_memory_efficient_attention(False, True)
|
unet.set_use_memory_efficient_attention(False, True)
|
||||||
elif xformers:
|
elif xformers:
|
||||||
print("Enable xformers for U-Net")
|
logger.info("Enable xformers for U-Net")
|
||||||
try:
|
try:
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -95,7 +101,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
|
|||||||
|
|
||||||
unet.set_use_memory_efficient_attention(True, False)
|
unet.set_use_memory_efficient_attention(True, False)
|
||||||
elif sdpa:
|
elif sdpa:
|
||||||
print("Enable SDPA for U-Net")
|
logger.info("Enable SDPA for U-Net")
|
||||||
unet.set_use_memory_efficient_attention(False, False)
|
unet.set_use_memory_efficient_attention(False, False)
|
||||||
unet.set_use_sdpa(True)
|
unet.set_use_sdpa(True)
|
||||||
|
|
||||||
@@ -112,7 +118,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
|
|||||||
|
|
||||||
|
|
||||||
def replace_vae_attn_to_memory_efficient():
|
def replace_vae_attn_to_memory_efficient():
|
||||||
print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
|
logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
|
||||||
flash_func = FlashAttentionFunction
|
flash_func = FlashAttentionFunction
|
||||||
|
|
||||||
def forward_flash_attn(self, hidden_states, **kwargs):
|
def forward_flash_attn(self, hidden_states, **kwargs):
|
||||||
@@ -168,7 +174,7 @@ def replace_vae_attn_to_memory_efficient():
|
|||||||
|
|
||||||
|
|
||||||
def replace_vae_attn_to_xformers():
|
def replace_vae_attn_to_xformers():
|
||||||
print("VAE: Attention.forward has been replaced to xformers")
|
logger.info("VAE: Attention.forward has been replaced to xformers")
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
|
|
||||||
def forward_xformers(self, hidden_states, **kwargs):
|
def forward_xformers(self, hidden_states, **kwargs):
|
||||||
@@ -224,7 +230,7 @@ def replace_vae_attn_to_xformers():
|
|||||||
|
|
||||||
|
|
||||||
def replace_vae_attn_to_sdpa():
|
def replace_vae_attn_to_sdpa():
|
||||||
print("VAE: Attention.forward has been replaced to sdpa")
|
logger.info("VAE: Attention.forward has been replaced to sdpa")
|
||||||
|
|
||||||
def forward_sdpa(self, hidden_states, **kwargs):
|
def forward_sdpa(self, hidden_states, **kwargs):
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -386,10 +392,10 @@ class PipelineLike:
|
|||||||
|
|
||||||
def set_gradual_latent(self, gradual_latent):
|
def set_gradual_latent(self, gradual_latent):
|
||||||
if gradual_latent is None:
|
if gradual_latent is None:
|
||||||
print("gradual_latent is disabled")
|
logger.info("gradual_latent is disabled")
|
||||||
self.gradual_latent = None
|
self.gradual_latent = None
|
||||||
else:
|
else:
|
||||||
print(f"gradual_latent is enabled: {gradual_latent}")
|
logger.info(f"gradual_latent is enabled: {gradual_latent}")
|
||||||
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
|
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -467,7 +473,7 @@ class PipelineLike:
|
|||||||
do_classifier_free_guidance = guidance_scale > 1.0
|
do_classifier_free_guidance = guidance_scale > 1.0
|
||||||
|
|
||||||
if not do_classifier_free_guidance and negative_scale is not None:
|
if not do_classifier_free_guidance and negative_scale is not None:
|
||||||
print(f"negative_scale is ignored if guidance scalle <= 1.0")
|
logger.warning(f"negative_scale is ignored if guidance scalle <= 1.0")
|
||||||
negative_scale = None
|
negative_scale = None
|
||||||
|
|
||||||
# get unconditional embeddings for classifier free guidance
|
# get unconditional embeddings for classifier free guidance
|
||||||
@@ -576,7 +582,7 @@ class PipelineLike:
|
|||||||
text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt
|
text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt
|
||||||
|
|
||||||
if init_image is not None and self.clip_vision_model is not None:
|
if init_image is not None and self.clip_vision_model is not None:
|
||||||
print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
|
logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
|
||||||
vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device)
|
vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device)
|
||||||
pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype)
|
pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype)
|
||||||
|
|
||||||
@@ -742,8 +748,8 @@ class PipelineLike:
|
|||||||
enable_gradual_latent = False
|
enable_gradual_latent = False
|
||||||
if self.gradual_latent:
|
if self.gradual_latent:
|
||||||
if not hasattr(self.scheduler, "set_gradual_latent_params"):
|
if not hasattr(self.scheduler, "set_gradual_latent_params"):
|
||||||
print("gradual_latent is not supported for this scheduler. Ignoring.")
|
logger.warning("gradual_latent is not supported for this scheduler. Ignoring.")
|
||||||
print(self.scheduler.__class__.__name__)
|
logger.warning(f"{self.scheduler.__class__.__name__}")
|
||||||
else:
|
else:
|
||||||
enable_gradual_latent = True
|
enable_gradual_latent = True
|
||||||
step_elapsed = 1000
|
step_elapsed = 1000
|
||||||
@@ -792,7 +798,7 @@ class PipelineLike:
|
|||||||
if not enabled or ratio >= 1.0:
|
if not enabled or ratio >= 1.0:
|
||||||
continue
|
continue
|
||||||
if ratio < i / len(timesteps):
|
if ratio < i / len(timesteps):
|
||||||
print(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
|
logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
|
||||||
control_net.set_cond_image(None)
|
control_net.set_cond_image(None)
|
||||||
each_control_net_enabled[j] = False
|
each_control_net_enabled[j] = False
|
||||||
|
|
||||||
@@ -1013,7 +1019,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L
|
|||||||
if word.strip() == "BREAK":
|
if word.strip() == "BREAK":
|
||||||
# pad until next multiple of tokenizer's max token length
|
# pad until next multiple of tokenizer's max token length
|
||||||
pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length)
|
pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length)
|
||||||
print(f"BREAK pad_len: {pad_len}")
|
logger.info(f"BREAK pad_len: {pad_len}")
|
||||||
for i in range(pad_len):
|
for i in range(pad_len):
|
||||||
# v2のときEOSをつけるべきかどうかわからないぜ
|
# v2のときEOSをつけるべきかどうかわからないぜ
|
||||||
# if i == 0:
|
# if i == 0:
|
||||||
@@ -1043,7 +1049,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L
|
|||||||
tokens.append(text_token)
|
tokens.append(text_token)
|
||||||
weights.append(text_weight)
|
weights.append(text_weight)
|
||||||
if truncated:
|
if truncated:
|
||||||
print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
||||||
return tokens, weights
|
return tokens, weights
|
||||||
|
|
||||||
|
|
||||||
@@ -1344,7 +1350,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
|
|||||||
elif len(count_range) == 2:
|
elif len(count_range) == 2:
|
||||||
count_range = [int(count_range[0]), int(count_range[1])]
|
count_range = [int(count_range[0]), int(count_range[1])]
|
||||||
else:
|
else:
|
||||||
print(f"invalid count range: {count_range}")
|
logger.warning(f"invalid count range: {count_range}")
|
||||||
count_range = [1, 1]
|
count_range = [1, 1]
|
||||||
if count_range[0] > count_range[1]:
|
if count_range[0] > count_range[1]:
|
||||||
count_range = [count_range[1], count_range[0]]
|
count_range = [count_range[1], count_range[0]]
|
||||||
@@ -1488,9 +1494,9 @@ def main(args):
|
|||||||
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
|
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
|
||||||
|
|
||||||
if args.v_parameterization and not args.v2:
|
if args.v_parameterization and not args.v2:
|
||||||
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
|
logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
|
||||||
if args.v2 and args.clip_skip is not None:
|
if args.v2 and args.clip_skip is not None:
|
||||||
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
if not os.path.exists(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う
|
if not os.path.exists(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う
|
||||||
@@ -1510,7 +1516,7 @@ def main(args):
|
|||||||
else:
|
else:
|
||||||
# if `text_encoder_2` subdirectory exists, sdxl
|
# if `text_encoder_2` subdirectory exists, sdxl
|
||||||
is_sdxl = os.path.isdir(os.path.join(name_or_path, "text_encoder_2"))
|
is_sdxl = os.path.isdir(os.path.join(name_or_path, "text_encoder_2"))
|
||||||
print(f"SDXL: {is_sdxl}")
|
logger.info(f"SDXL: {is_sdxl}")
|
||||||
|
|
||||||
if is_sdxl:
|
if is_sdxl:
|
||||||
if args.clip_skip is None:
|
if args.clip_skip is None:
|
||||||
@@ -1526,10 +1532,10 @@ def main(args):
|
|||||||
args.clip_skip = 2 if args.v2 else 1
|
args.clip_skip = 2 if args.v2 else 1
|
||||||
|
|
||||||
if use_stable_diffusion_format:
|
if use_stable_diffusion_format:
|
||||||
print("load StableDiffusion checkpoint")
|
logger.info("load StableDiffusion checkpoint")
|
||||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
|
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
|
||||||
else:
|
else:
|
||||||
print("load Diffusers pretrained models")
|
logger.info("load Diffusers pretrained models")
|
||||||
loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
|
loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
|
||||||
text_encoder = loading_pipe.text_encoder
|
text_encoder = loading_pipe.text_encoder
|
||||||
vae = loading_pipe.vae
|
vae = loading_pipe.vae
|
||||||
@@ -1553,7 +1559,7 @@ def main(args):
|
|||||||
# VAEを読み込む
|
# VAEを読み込む
|
||||||
if args.vae is not None:
|
if args.vae is not None:
|
||||||
vae = model_util.load_vae(args.vae, dtype)
|
vae = model_util.load_vae(args.vae, dtype)
|
||||||
print("additional VAE loaded")
|
logger.info("additional VAE loaded")
|
||||||
|
|
||||||
# xformers、Hypernetwork対応
|
# xformers、Hypernetwork対応
|
||||||
if not args.diffusers_xformers:
|
if not args.diffusers_xformers:
|
||||||
@@ -1562,7 +1568,7 @@ def main(args):
|
|||||||
replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)
|
replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)
|
||||||
|
|
||||||
# tokenizerを読み込む
|
# tokenizerを読み込む
|
||||||
print("loading tokenizer")
|
logger.info("loading tokenizer")
|
||||||
if is_sdxl:
|
if is_sdxl:
|
||||||
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
||||||
tokenizers = [tokenizer1, tokenizer2]
|
tokenizers = [tokenizer1, tokenizer2]
|
||||||
@@ -1654,7 +1660,7 @@ def main(args):
|
|||||||
noise = None
|
noise = None
|
||||||
|
|
||||||
if noise == None:
|
if noise == None:
|
||||||
print(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
|
logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
|
||||||
noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)
|
noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)
|
||||||
|
|
||||||
self.sampler_noise_index += 1
|
self.sampler_noise_index += 1
|
||||||
@@ -1715,7 +1721,7 @@ def main(args):
|
|||||||
|
|
||||||
vae_dtype = dtype
|
vae_dtype = dtype
|
||||||
if args.no_half_vae:
|
if args.no_half_vae:
|
||||||
print("set vae_dtype to float32")
|
logger.info("set vae_dtype to float32")
|
||||||
vae_dtype = torch.float32
|
vae_dtype = torch.float32
|
||||||
vae.to(vae_dtype).to(device)
|
vae.to(vae_dtype).to(device)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
@@ -1739,10 +1745,10 @@ def main(args):
|
|||||||
network_merge = args.network_merge_n_models
|
network_merge = args.network_merge_n_models
|
||||||
else:
|
else:
|
||||||
network_merge = 0
|
network_merge = 0
|
||||||
print(f"network_merge: {network_merge}")
|
logger.info(f"network_merge: {network_merge}")
|
||||||
|
|
||||||
for i, network_module in enumerate(args.network_module):
|
for i, network_module in enumerate(args.network_module):
|
||||||
print("import network module:", network_module)
|
logger.info("import network module: {network_module}")
|
||||||
imported_module = importlib.import_module(network_module)
|
imported_module = importlib.import_module(network_module)
|
||||||
|
|
||||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||||
@@ -1760,7 +1766,7 @@ def main(args):
|
|||||||
raise ValueError("No weight. Weight is required.")
|
raise ValueError("No weight. Weight is required.")
|
||||||
|
|
||||||
network_weight = args.network_weights[i]
|
network_weight = args.network_weights[i]
|
||||||
print("load network weights from:", network_weight)
|
logger.info(f"load network weights from: {network_weight}")
|
||||||
|
|
||||||
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||||
from safetensors.torch import safe_open
|
from safetensors.torch import safe_open
|
||||||
@@ -1768,7 +1774,7 @@ def main(args):
|
|||||||
with safe_open(network_weight, framework="pt") as f:
|
with safe_open(network_weight, framework="pt") as f:
|
||||||
metadata = f.metadata()
|
metadata = f.metadata()
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
print(f"metadata for: {network_weight}: {metadata}")
|
logger.info(f"metadata for: {network_weight}: {metadata}")
|
||||||
|
|
||||||
network, weights_sd = imported_module.create_network_from_weights(
|
network, weights_sd = imported_module.create_network_from_weights(
|
||||||
network_mul, network_weight, vae, text_encoders, unet, for_inference=True, **net_kwargs
|
network_mul, network_weight, vae, text_encoders, unet, for_inference=True, **net_kwargs
|
||||||
@@ -1778,20 +1784,20 @@ def main(args):
|
|||||||
|
|
||||||
mergeable = network.is_mergeable()
|
mergeable = network.is_mergeable()
|
||||||
if network_merge and not mergeable:
|
if network_merge and not mergeable:
|
||||||
print("network is not mergiable. ignore merge option.")
|
logger.warning("network is not mergiable. ignore merge option.")
|
||||||
|
|
||||||
if not mergeable or i >= network_merge:
|
if not mergeable or i >= network_merge:
|
||||||
# not merging
|
# not merging
|
||||||
network.apply_to(text_encoders, unet)
|
network.apply_to(text_encoders, unet)
|
||||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||||
print(f"weights are loaded: {info}")
|
logger.info(f"weights are loaded: {info}")
|
||||||
|
|
||||||
if args.opt_channels_last:
|
if args.opt_channels_last:
|
||||||
network.to(memory_format=torch.channels_last)
|
network.to(memory_format=torch.channels_last)
|
||||||
network.to(dtype).to(device)
|
network.to(dtype).to(device)
|
||||||
|
|
||||||
if network_pre_calc:
|
if network_pre_calc:
|
||||||
print("backup original weights")
|
logger.info("backup original weights")
|
||||||
network.backup_weights()
|
network.backup_weights()
|
||||||
|
|
||||||
networks.append(network)
|
networks.append(network)
|
||||||
@@ -1805,7 +1811,7 @@ def main(args):
|
|||||||
# upscalerの指定があれば取得する
|
# upscalerの指定があれば取得する
|
||||||
upscaler = None
|
upscaler = None
|
||||||
if args.highres_fix_upscaler:
|
if args.highres_fix_upscaler:
|
||||||
print("import upscaler module:", args.highres_fix_upscaler)
|
logger.info("import upscaler module: {args.highres_fix_upscaler}")
|
||||||
imported_module = importlib.import_module(args.highres_fix_upscaler)
|
imported_module = importlib.import_module(args.highres_fix_upscaler)
|
||||||
|
|
||||||
us_kwargs = {}
|
us_kwargs = {}
|
||||||
@@ -1814,7 +1820,7 @@ def main(args):
|
|||||||
key, value = net_arg.split("=")
|
key, value = net_arg.split("=")
|
||||||
us_kwargs[key] = value
|
us_kwargs[key] = value
|
||||||
|
|
||||||
print("create upscaler")
|
logger.info("create upscaler")
|
||||||
upscaler = imported_module.create_upscaler(**us_kwargs)
|
upscaler = imported_module.create_upscaler(**us_kwargs)
|
||||||
upscaler.to(dtype).to(device)
|
upscaler.to(dtype).to(device)
|
||||||
|
|
||||||
@@ -1833,7 +1839,7 @@ def main(args):
|
|||||||
control_net_lllites: List[Tuple[ControlNetLLLite, float]] = []
|
control_net_lllites: List[Tuple[ControlNetLLLite, float]] = []
|
||||||
if args.control_net_lllite_models:
|
if args.control_net_lllite_models:
|
||||||
for i, model_file in enumerate(args.control_net_lllite_models):
|
for i, model_file in enumerate(args.control_net_lllite_models):
|
||||||
print(f"loading ControlNet-LLLite: {model_file}")
|
logger.info(f"loading ControlNet-LLLite: {model_file}")
|
||||||
|
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
@@ -1867,7 +1873,7 @@ def main(args):
|
|||||||
), "ControlNet and ControlNet-LLLite cannot be used at the same time"
|
), "ControlNet and ControlNet-LLLite cannot be used at the same time"
|
||||||
|
|
||||||
if args.opt_channels_last:
|
if args.opt_channels_last:
|
||||||
print(f"set optimizing: channels last")
|
logger.info(f"set optimizing: channels last")
|
||||||
for text_encoder in text_encoders:
|
for text_encoder in text_encoders:
|
||||||
text_encoder.to(memory_format=torch.channels_last)
|
text_encoder.to(memory_format=torch.channels_last)
|
||||||
vae.to(memory_format=torch.channels_last)
|
vae.to(memory_format=torch.channels_last)
|
||||||
@@ -1894,7 +1900,7 @@ def main(args):
|
|||||||
)
|
)
|
||||||
pipe.set_control_nets(control_nets)
|
pipe.set_control_nets(control_nets)
|
||||||
pipe.set_control_net_lllites(control_net_lllites)
|
pipe.set_control_net_lllites(control_net_lllites)
|
||||||
print("pipeline is ready.")
|
logger.info("pipeline is ready.")
|
||||||
|
|
||||||
if args.diffusers_xformers:
|
if args.diffusers_xformers:
|
||||||
pipe.enable_xformers_memory_efficient_attention()
|
pipe.enable_xformers_memory_efficient_attention()
|
||||||
@@ -1965,7 +1971,7 @@ def main(args):
|
|||||||
|
|
||||||
token_ids1 = tokenizers[0].convert_tokens_to_ids(token_strings)
|
token_ids1 = tokenizers[0].convert_tokens_to_ids(token_strings)
|
||||||
token_ids2 = tokenizers[1].convert_tokens_to_ids(token_strings) if is_sdxl else None
|
token_ids2 = tokenizers[1].convert_tokens_to_ids(token_strings) if is_sdxl else None
|
||||||
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}")
|
logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}")
|
||||||
assert (
|
assert (
|
||||||
min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1
|
min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1
|
||||||
), f"token ids1 is not ordered"
|
), f"token ids1 is not ordered"
|
||||||
@@ -2002,7 +2008,7 @@ def main(args):
|
|||||||
# promptを取得する
|
# promptを取得する
|
||||||
prompt_list = None
|
prompt_list = None
|
||||||
if args.from_file is not None:
|
if args.from_file is not None:
|
||||||
print(f"reading prompts from {args.from_file}")
|
logger.info(f"reading prompts from {args.from_file}")
|
||||||
with open(args.from_file, "r", encoding="utf-8") as f:
|
with open(args.from_file, "r", encoding="utf-8") as f:
|
||||||
prompt_list = f.read().splitlines()
|
prompt_list = f.read().splitlines()
|
||||||
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
|
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
|
||||||
@@ -2019,7 +2025,7 @@ def main(args):
|
|||||||
spec.loader.exec_module(module)
|
spec.loader.exec_module(module)
|
||||||
return module
|
return module
|
||||||
|
|
||||||
print(f"reading prompts from module: {args.from_module}")
|
logger.info(f"reading prompts from module: {args.from_module}")
|
||||||
prompt_module = load_module_from_path("prompt_module", args.from_module)
|
prompt_module = load_module_from_path("prompt_module", args.from_module)
|
||||||
|
|
||||||
prompter = prompt_module.get_prompter(args, pipe, networks)
|
prompter = prompt_module.get_prompter(args, pipe, networks)
|
||||||
@@ -2050,7 +2056,7 @@ def main(args):
|
|||||||
for p in paths:
|
for p in paths:
|
||||||
image = Image.open(p)
|
image = Image.open(p)
|
||||||
if image.mode != "RGB":
|
if image.mode != "RGB":
|
||||||
print(f"convert image to RGB from {image.mode}: {p}")
|
logger.info(f"convert image to RGB from {image.mode}: {p}")
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
images.append(image)
|
images.append(image)
|
||||||
|
|
||||||
@@ -2066,14 +2072,14 @@ def main(args):
|
|||||||
return resized
|
return resized
|
||||||
|
|
||||||
if args.image_path is not None:
|
if args.image_path is not None:
|
||||||
print(f"load image for img2img: {args.image_path}")
|
logger.info(f"load image for img2img: {args.image_path}")
|
||||||
init_images = load_images(args.image_path)
|
init_images = load_images(args.image_path)
|
||||||
assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
|
assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
|
||||||
print(f"loaded {len(init_images)} images for img2img")
|
logger.info(f"loaded {len(init_images)} images for img2img")
|
||||||
|
|
||||||
# CLIP Vision
|
# CLIP Vision
|
||||||
if args.clip_vision_strength is not None:
|
if args.clip_vision_strength is not None:
|
||||||
print(f"load CLIP Vision model: {CLIP_VISION_MODEL}")
|
logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}")
|
||||||
vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280)
|
vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280)
|
||||||
vision_model.to(device, dtype)
|
vision_model.to(device, dtype)
|
||||||
processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL)
|
processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL)
|
||||||
@@ -2081,22 +2087,22 @@ def main(args):
|
|||||||
pipe.clip_vision_model = vision_model
|
pipe.clip_vision_model = vision_model
|
||||||
pipe.clip_vision_processor = processor
|
pipe.clip_vision_processor = processor
|
||||||
pipe.clip_vision_strength = args.clip_vision_strength
|
pipe.clip_vision_strength = args.clip_vision_strength
|
||||||
print(f"CLIP Vision model loaded.")
|
logger.info(f"CLIP Vision model loaded.")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
init_images = None
|
init_images = None
|
||||||
|
|
||||||
if args.mask_path is not None:
|
if args.mask_path is not None:
|
||||||
print(f"load mask for inpainting: {args.mask_path}")
|
logger.info(f"load mask for inpainting: {args.mask_path}")
|
||||||
mask_images = load_images(args.mask_path)
|
mask_images = load_images(args.mask_path)
|
||||||
assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}"
|
assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}"
|
||||||
print(f"loaded {len(mask_images)} mask images for inpainting")
|
logger.info(f"loaded {len(mask_images)} mask images for inpainting")
|
||||||
else:
|
else:
|
||||||
mask_images = None
|
mask_images = None
|
||||||
|
|
||||||
# promptがないとき、画像のPngInfoから取得する
|
# promptがないとき、画像のPngInfoから取得する
|
||||||
if init_images is not None and prompter is None and not args.interactive:
|
if init_images is not None and prompter is None and not args.interactive:
|
||||||
print("get prompts from images' metadata")
|
logger.info("get prompts from images' metadata")
|
||||||
prompt_list = []
|
prompt_list = []
|
||||||
for img in init_images:
|
for img in init_images:
|
||||||
if "prompt" in img.text:
|
if "prompt" in img.text:
|
||||||
@@ -2127,17 +2133,17 @@ def main(args):
|
|||||||
h = int(h * args.highres_fix_scale + 0.5)
|
h = int(h * args.highres_fix_scale + 0.5)
|
||||||
|
|
||||||
if init_images is not None:
|
if init_images is not None:
|
||||||
print(f"resize img2img source images to {w}*{h}")
|
logger.info(f"resize img2img source images to {w}*{h}")
|
||||||
init_images = resize_images(init_images, (w, h))
|
init_images = resize_images(init_images, (w, h))
|
||||||
if mask_images is not None:
|
if mask_images is not None:
|
||||||
print(f"resize img2img mask images to {w}*{h}")
|
logger.info(f"resize img2img mask images to {w}*{h}")
|
||||||
mask_images = resize_images(mask_images, (w, h))
|
mask_images = resize_images(mask_images, (w, h))
|
||||||
|
|
||||||
regional_network = False
|
regional_network = False
|
||||||
if networks and mask_images:
|
if networks and mask_images:
|
||||||
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
|
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
|
||||||
regional_network = True
|
regional_network = True
|
||||||
print("use mask as region")
|
logger.info("use mask as region")
|
||||||
|
|
||||||
size = None
|
size = None
|
||||||
for i, network in enumerate(networks):
|
for i, network in enumerate(networks):
|
||||||
@@ -2162,14 +2168,14 @@ def main(args):
|
|||||||
|
|
||||||
prev_image = None # for VGG16 guided
|
prev_image = None # for VGG16 guided
|
||||||
if args.guide_image_path is not None:
|
if args.guide_image_path is not None:
|
||||||
print(f"load image for ControlNet guidance: {args.guide_image_path}")
|
logger.info(f"load image for ControlNet guidance: {args.guide_image_path}")
|
||||||
guide_images = []
|
guide_images = []
|
||||||
for p in args.guide_image_path:
|
for p in args.guide_image_path:
|
||||||
guide_images.extend(load_images(p))
|
guide_images.extend(load_images(p))
|
||||||
|
|
||||||
print(f"loaded {len(guide_images)} guide images for guidance")
|
logger.info(f"loaded {len(guide_images)} guide images for guidance")
|
||||||
if len(guide_images) == 0:
|
if len(guide_images) == 0:
|
||||||
print(
|
logger.warning(
|
||||||
f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}"
|
f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}"
|
||||||
)
|
)
|
||||||
guide_images = None
|
guide_images = None
|
||||||
@@ -2200,7 +2206,7 @@ def main(args):
|
|||||||
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
|
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
|
||||||
|
|
||||||
for gen_iter in range(args.n_iter):
|
for gen_iter in range(args.n_iter):
|
||||||
print(f"iteration {gen_iter+1}/{args.n_iter}")
|
logger.info(f"iteration {gen_iter+1}/{args.n_iter}")
|
||||||
if args.iter_same_seed:
|
if args.iter_same_seed:
|
||||||
iter_seed = seed_random.randint(0, 2**32 - 1)
|
iter_seed = seed_random.randint(0, 2**32 - 1)
|
||||||
else:
|
else:
|
||||||
@@ -2219,7 +2225,7 @@ def main(args):
|
|||||||
# 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
|
# 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
|
||||||
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
|
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
|
||||||
|
|
||||||
print("process 1st stage")
|
logger.info("process 1st stage")
|
||||||
batch_1st = []
|
batch_1st = []
|
||||||
for _, base, ext in batch:
|
for _, base, ext in batch:
|
||||||
|
|
||||||
@@ -2264,7 +2270,7 @@ def main(args):
|
|||||||
images_1st = process_batch(batch_1st, True, True)
|
images_1st = process_batch(batch_1st, True, True)
|
||||||
|
|
||||||
# 2nd stageのバッチを作成して以下処理する
|
# 2nd stageのバッチを作成して以下処理する
|
||||||
print("process 2nd stage")
|
logger.info("process 2nd stage")
|
||||||
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
|
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
|
||||||
|
|
||||||
if upscaler:
|
if upscaler:
|
||||||
@@ -2437,7 +2443,7 @@ def main(args):
|
|||||||
n.restore_weights()
|
n.restore_weights()
|
||||||
for n in networks:
|
for n in networks:
|
||||||
n.pre_calculation()
|
n.pre_calculation()
|
||||||
print("pre-calculation... done")
|
logger.info("pre-calculation... done")
|
||||||
|
|
||||||
images = pipe(
|
images = pipe(
|
||||||
prompts,
|
prompts,
|
||||||
@@ -2520,7 +2526,7 @@ def main(args):
|
|||||||
cv2.waitKey()
|
cv2.waitKey()
|
||||||
cv2.destroyAllWindows()
|
cv2.destroyAllWindows()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print(
|
logger.warning(
|
||||||
"opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません"
|
"opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2535,7 +2541,7 @@ def main(args):
|
|||||||
# interactive
|
# interactive
|
||||||
valid = False
|
valid = False
|
||||||
while not valid:
|
while not valid:
|
||||||
print("\nType prompt:")
|
logger.info("\nType prompt:")
|
||||||
try:
|
try:
|
||||||
raw_prompt = input()
|
raw_prompt = input()
|
||||||
except EOFError:
|
except EOFError:
|
||||||
@@ -2595,74 +2601,74 @@ def main(args):
|
|||||||
prompt_args = raw_prompt.strip().split(" --")
|
prompt_args = raw_prompt.strip().split(" --")
|
||||||
prompt = prompt_args[0]
|
prompt = prompt_args[0]
|
||||||
length = len(prompter) if hasattr(prompter, "__len__") else 0
|
length = len(prompter) if hasattr(prompter, "__len__") else 0
|
||||||
print(f"prompt {prompt_index+1}/{length}: {prompt}")
|
logger.info(f"prompt {prompt_index+1}/{length}: {prompt}")
|
||||||
|
|
||||||
for parg in prompt_args[1:]:
|
for parg in prompt_args[1:]:
|
||||||
try:
|
try:
|
||||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
width = int(m.group(1))
|
width = int(m.group(1))
|
||||||
print(f"width: {width}")
|
logger.info(f"width: {width}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
height = int(m.group(1))
|
height = int(m.group(1))
|
||||||
print(f"height: {height}")
|
logger.info(f"height: {height}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"ow (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"ow (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
original_width = int(m.group(1))
|
original_width = int(m.group(1))
|
||||||
print(f"original width: {original_width}")
|
logger.info(f"original width: {original_width}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"oh (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"oh (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
original_height = int(m.group(1))
|
original_height = int(m.group(1))
|
||||||
print(f"original height: {original_height}")
|
logger.info(f"original height: {original_height}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"nw (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"nw (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
original_width_negative = int(m.group(1))
|
original_width_negative = int(m.group(1))
|
||||||
print(f"original width negative: {original_width_negative}")
|
logger.info(f"original width negative: {original_width_negative}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"nh (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"nh (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
original_height_negative = int(m.group(1))
|
original_height_negative = int(m.group(1))
|
||||||
print(f"original height negative: {original_height_negative}")
|
logger.info(f"original height negative: {original_height_negative}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"ct (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"ct (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
crop_top = int(m.group(1))
|
crop_top = int(m.group(1))
|
||||||
print(f"crop top: {crop_top}")
|
logger.info(f"crop top: {crop_top}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"cl (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"cl (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
crop_left = int(m.group(1))
|
crop_left = int(m.group(1))
|
||||||
print(f"crop left: {crop_left}")
|
logger.info(f"crop left: {crop_left}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||||
if m: # steps
|
if m: # steps
|
||||||
steps = max(1, min(1000, int(m.group(1))))
|
steps = max(1, min(1000, int(m.group(1))))
|
||||||
print(f"steps: {steps}")
|
logger.info(f"steps: {steps}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
||||||
if m: # seed
|
if m: # seed
|
||||||
seeds = [int(d) for d in m.group(1).split(",")]
|
seeds = [int(d) for d in m.group(1).split(",")]
|
||||||
print(f"seeds: {seeds}")
|
logger.info(f"seeds: {seeds}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # scale
|
if m: # scale
|
||||||
scale = float(m.group(1))
|
scale = float(m.group(1))
|
||||||
print(f"scale: {scale}")
|
logger.info(f"scale: {scale}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
|
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
|
||||||
@@ -2671,25 +2677,25 @@ def main(args):
|
|||||||
negative_scale = None
|
negative_scale = None
|
||||||
else:
|
else:
|
||||||
negative_scale = float(m.group(1))
|
negative_scale = float(m.group(1))
|
||||||
print(f"negative scale: {negative_scale}")
|
logger.info(f"negative scale: {negative_scale}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # strength
|
if m: # strength
|
||||||
strength = float(m.group(1))
|
strength = float(m.group(1))
|
||||||
print(f"strength: {strength}")
|
logger.info(f"strength: {strength}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||||
if m: # negative prompt
|
if m: # negative prompt
|
||||||
negative_prompt = m.group(1)
|
negative_prompt = m.group(1)
|
||||||
print(f"negative prompt: {negative_prompt}")
|
logger.info(f"negative prompt: {negative_prompt}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"c (.+)", parg, re.IGNORECASE)
|
m = re.match(r"c (.+)", parg, re.IGNORECASE)
|
||||||
if m: # clip prompt
|
if m: # clip prompt
|
||||||
clip_prompt = m.group(1)
|
clip_prompt = m.group(1)
|
||||||
print(f"clip prompt: {clip_prompt}")
|
logger.info(f"clip prompt: {clip_prompt}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
|
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
|
||||||
@@ -2697,89 +2703,89 @@ def main(args):
|
|||||||
network_muls = [float(v) for v in m.group(1).split(",")]
|
network_muls = [float(v) for v in m.group(1).split(",")]
|
||||||
while len(network_muls) < len(networks):
|
while len(network_muls) < len(networks):
|
||||||
network_muls.append(network_muls[-1])
|
network_muls.append(network_muls[-1])
|
||||||
print(f"network mul: {network_muls}")
|
logger.info(f"network mul: {network_muls}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Deep Shrink
|
# Deep Shrink
|
||||||
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # deep shrink depth 1
|
if m: # deep shrink depth 1
|
||||||
ds_depth_1 = int(m.group(1))
|
ds_depth_1 = int(m.group(1))
|
||||||
print(f"deep shrink depth 1: {ds_depth_1}")
|
logger.info(f"deep shrink depth 1: {ds_depth_1}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # deep shrink timesteps 1
|
if m: # deep shrink timesteps 1
|
||||||
ds_timesteps_1 = int(m.group(1))
|
ds_timesteps_1 = int(m.group(1))
|
||||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
|
logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # deep shrink depth 2
|
if m: # deep shrink depth 2
|
||||||
ds_depth_2 = int(m.group(1))
|
ds_depth_2 = int(m.group(1))
|
||||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
print(f"deep shrink depth 2: {ds_depth_2}")
|
logger.info(f"deep shrink depth 2: {ds_depth_2}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # deep shrink timesteps 2
|
if m: # deep shrink timesteps 2
|
||||||
ds_timesteps_2 = int(m.group(1))
|
ds_timesteps_2 = int(m.group(1))
|
||||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
|
logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # deep shrink ratio
|
if m: # deep shrink ratio
|
||||||
ds_ratio = float(m.group(1))
|
ds_ratio = float(m.group(1))
|
||||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||||
print(f"deep shrink ratio: {ds_ratio}")
|
logger.info(f"deep shrink ratio: {ds_ratio}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Gradual Latent
|
# Gradual Latent
|
||||||
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent timesteps
|
if m: # gradual latent timesteps
|
||||||
gl_timesteps = int(m.group(1))
|
gl_timesteps = int(m.group(1))
|
||||||
print(f"gradual latent timesteps: {gl_timesteps}")
|
logger.info(f"gradual latent timesteps: {gl_timesteps}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent ratio
|
if m: # gradual latent ratio
|
||||||
gl_ratio = float(m.group(1))
|
gl_ratio = float(m.group(1))
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent ratio: {ds_ratio}")
|
logger.info(f"gradual latent ratio: {ds_ratio}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent every n steps
|
if m: # gradual latent every n steps
|
||||||
gl_every_n_steps = int(m.group(1))
|
gl_every_n_steps = int(m.group(1))
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent every n steps: {gl_every_n_steps}")
|
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent ratio step
|
if m: # gradual latent ratio step
|
||||||
gl_ratio_step = float(m.group(1))
|
gl_ratio_step = float(m.group(1))
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent ratio step: {gl_ratio_step}")
|
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent s noise
|
if m: # gradual latent s noise
|
||||||
gl_s_noise = float(m.group(1))
|
gl_s_noise = float(m.group(1))
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent s noise: {gl_s_noise}")
|
logger.info(f"gradual latent s noise: {gl_s_noise}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
|
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent unsharp params
|
if m: # gradual latent unsharp params
|
||||||
gl_unsharp_params = m.group(1)
|
gl_unsharp_params = m.group(1)
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent unsharp params: {gl_unsharp_params}")
|
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except ValueError as ex:
|
except ValueError as ex:
|
||||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
logger.error(f"Exception in parsing / 解析エラー: {parg}")
|
||||||
print(ex)
|
logger.error(f"{ex}")
|
||||||
|
|
||||||
# override Deep Shrink
|
# override Deep Shrink
|
||||||
if ds_depth_1 is not None:
|
if ds_depth_1 is not None:
|
||||||
@@ -2825,7 +2831,7 @@ def main(args):
|
|||||||
if seed is None:
|
if seed is None:
|
||||||
seed = seed_random.randint(0, 2**32 - 1)
|
seed = seed_random.randint(0, 2**32 - 1)
|
||||||
if args.interactive:
|
if args.interactive:
|
||||||
print(f"seed: {seed}")
|
logger.info(f"seed: {seed}")
|
||||||
|
|
||||||
# prepare init image, guide image and mask
|
# prepare init image, guide image and mask
|
||||||
init_image = mask_image = guide_image = None
|
init_image = mask_image = guide_image = None
|
||||||
@@ -2841,7 +2847,7 @@ def main(args):
|
|||||||
width = width - width % 32
|
width = width - width % 32
|
||||||
height = height - height % 32
|
height = height - height % 32
|
||||||
if width != init_image.size[0] or height != init_image.size[1]:
|
if width != init_image.size[0] or height != init_image.size[1]:
|
||||||
print(
|
logger.warning(
|
||||||
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
|
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2903,12 +2909,14 @@ def main(args):
|
|||||||
process_batch(batch_data, highres_fix)
|
process_batch(batch_data, highres_fix)
|
||||||
batch_data.clear()
|
batch_data.clear()
|
||||||
|
|
||||||
print("done!")
|
logger.info("done!")
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
def setup_parser() -> argparse.ArgumentParser:
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
add_logging_arguments(parser)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む"
|
"--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -489,10 +489,10 @@ class PipelineLike:
|
|||||||
|
|
||||||
def set_gradual_latent(self, gradual_latent):
|
def set_gradual_latent(self, gradual_latent):
|
||||||
if gradual_latent is None:
|
if gradual_latent is None:
|
||||||
print("gradual_latent is disabled")
|
logger.info("gradual_latent is disabled")
|
||||||
self.gradual_latent = None
|
self.gradual_latent = None
|
||||||
else:
|
else:
|
||||||
print(f"gradual_latent is enabled: {gradual_latent}")
|
logger.info(f"gradual_latent is enabled: {gradual_latent}")
|
||||||
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
|
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
|
||||||
|
|
||||||
# region xformersとか使う部分:独自に書き換えるので関係なし
|
# region xformersとか使う部分:独自に書き換えるので関係なし
|
||||||
@@ -971,8 +971,8 @@ class PipelineLike:
|
|||||||
enable_gradual_latent = False
|
enable_gradual_latent = False
|
||||||
if self.gradual_latent:
|
if self.gradual_latent:
|
||||||
if not hasattr(self.scheduler, "set_gradual_latent_params"):
|
if not hasattr(self.scheduler, "set_gradual_latent_params"):
|
||||||
print("gradual_latent is not supported for this scheduler. Ignoring.")
|
logger.info("gradual_latent is not supported for this scheduler. Ignoring.")
|
||||||
print(self.scheduler.__class__.__name__)
|
logger.info(f'{self.scheduler.__class__.__name__}')
|
||||||
else:
|
else:
|
||||||
enable_gradual_latent = True
|
enable_gradual_latent = True
|
||||||
step_elapsed = 1000
|
step_elapsed = 1000
|
||||||
@@ -3314,42 +3314,42 @@ def main(args):
|
|||||||
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent timesteps
|
if m: # gradual latent timesteps
|
||||||
gl_timesteps = int(m.group(1))
|
gl_timesteps = int(m.group(1))
|
||||||
print(f"gradual latent timesteps: {gl_timesteps}")
|
logger.info(f"gradual latent timesteps: {gl_timesteps}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent ratio
|
if m: # gradual latent ratio
|
||||||
gl_ratio = float(m.group(1))
|
gl_ratio = float(m.group(1))
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent ratio: {ds_ratio}")
|
logger.info(f"gradual latent ratio: {ds_ratio}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent every n steps
|
if m: # gradual latent every n steps
|
||||||
gl_every_n_steps = int(m.group(1))
|
gl_every_n_steps = int(m.group(1))
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent every n steps: {gl_every_n_steps}")
|
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent ratio step
|
if m: # gradual latent ratio step
|
||||||
gl_ratio_step = float(m.group(1))
|
gl_ratio_step = float(m.group(1))
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent ratio step: {gl_ratio_step}")
|
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent s noise
|
if m: # gradual latent s noise
|
||||||
gl_s_noise = float(m.group(1))
|
gl_s_noise = float(m.group(1))
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent s noise: {gl_s_noise}")
|
logger.info(f"gradual latent s noise: {gl_s_noise}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
|
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent unsharp params
|
if m: # gradual latent unsharp params
|
||||||
gl_unsharp_params = m.group(1)
|
gl_unsharp_params = m.group(1)
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent unsharp params: {gl_unsharp_params}")
|
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except ValueError as ex:
|
except ValueError as ex:
|
||||||
@@ -3369,7 +3369,7 @@ def main(args):
|
|||||||
if gl_unsharp_params is not None:
|
if gl_unsharp_params is not None:
|
||||||
unsharp_params = gl_unsharp_params.split(",")
|
unsharp_params = gl_unsharp_params.split(",")
|
||||||
us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]]
|
us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]]
|
||||||
print(unsharp_params)
|
logger.info(f'{unsharp_params}')
|
||||||
us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3]))
|
us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3]))
|
||||||
us_ksize = int(us_ksize)
|
us_ksize = int(us_ksize)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -3,6 +3,11 @@ import gc
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from .utils import setup_logging
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
HAS_CUDA = torch.cuda.is_available()
|
HAS_CUDA = torch.cuda.is_available()
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -59,7 +64,7 @@ def get_preferred_device() -> torch.device:
|
|||||||
device = torch.device("mps")
|
device = torch.device("mps")
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
print(f"get_preferred_device() -> {device}")
|
logger.info(f"get_preferred_device() -> {device}")
|
||||||
return device
|
return device
|
||||||
|
|
||||||
|
|
||||||
@@ -77,8 +82,8 @@ def init_ipex():
|
|||||||
|
|
||||||
is_initialized, error_message = ipex_init()
|
is_initialized, error_message = ipex_init()
|
||||||
if not is_initialized:
|
if not is_initialized:
|
||||||
print("failed to initialize ipex:", error_message)
|
logger.error("failed to initialize ipex: {error_message}")
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("failed to initialize ipex:", e)
|
logger.error("failed to initialize ipex: {e}")
|
||||||
|
|||||||
@@ -327,10 +327,10 @@ class DyLoRANetwork(torch.nn.Module):
|
|||||||
for i, text_encoder in enumerate(text_encoders):
|
for i, text_encoder in enumerate(text_encoders):
|
||||||
if len(text_encoders) > 1:
|
if len(text_encoders) > 1:
|
||||||
index = i + 1
|
index = i + 1
|
||||||
print(f"create LoRA for Text Encoder {index}")
|
logger.info(f"create LoRA for Text Encoder {index}")
|
||||||
else:
|
else:
|
||||||
index = None
|
index = None
|
||||||
print(f"create LoRA for Text Encoder")
|
logger.info("create LoRA for Text Encoder")
|
||||||
|
|
||||||
text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||||
self.text_encoder_loras.extend(text_encoder_loras)
|
self.text_encoder_loras.extend(text_encoder_loras)
|
||||||
|
|||||||
@@ -380,10 +380,10 @@ class PipelineLike:
|
|||||||
|
|
||||||
def set_gradual_latent(self, gradual_latent):
|
def set_gradual_latent(self, gradual_latent):
|
||||||
if gradual_latent is None:
|
if gradual_latent is None:
|
||||||
print("gradual_latent is disabled")
|
logger.info("gradual_latent is disabled")
|
||||||
self.gradual_latent = None
|
self.gradual_latent = None
|
||||||
else:
|
else:
|
||||||
print(f"gradual_latent is enabled: {gradual_latent}")
|
logger.info(f"gradual_latent is enabled: {gradual_latent}")
|
||||||
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
|
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -789,8 +789,8 @@ class PipelineLike:
|
|||||||
enable_gradual_latent = False
|
enable_gradual_latent = False
|
||||||
if self.gradual_latent:
|
if self.gradual_latent:
|
||||||
if not hasattr(self.scheduler, "set_gradual_latent_params"):
|
if not hasattr(self.scheduler, "set_gradual_latent_params"):
|
||||||
print("gradual_latent is not supported for this scheduler. Ignoring.")
|
logger.info("gradual_latent is not supported for this scheduler. Ignoring.")
|
||||||
print(self.scheduler.__class__.__name__)
|
logger.info(f'{self.scheduler.__class__.__name__}')
|
||||||
else:
|
else:
|
||||||
enable_gradual_latent = True
|
enable_gradual_latent = True
|
||||||
step_elapsed = 1000
|
step_elapsed = 1000
|
||||||
@@ -2614,84 +2614,84 @@ def main(args):
|
|||||||
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent timesteps
|
if m: # gradual latent timesteps
|
||||||
gl_timesteps = int(m.group(1))
|
gl_timesteps = int(m.group(1))
|
||||||
print(f"gradual latent timesteps: {gl_timesteps}")
|
logger.info(f"gradual latent timesteps: {gl_timesteps}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent ratio
|
if m: # gradual latent ratio
|
||||||
gl_ratio = float(m.group(1))
|
gl_ratio = float(m.group(1))
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent ratio: {ds_ratio}")
|
logger.info(f"gradual latent ratio: {ds_ratio}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent every n steps
|
if m: # gradual latent every n steps
|
||||||
gl_every_n_steps = int(m.group(1))
|
gl_every_n_steps = int(m.group(1))
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent every n steps: {gl_every_n_steps}")
|
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent ratio step
|
if m: # gradual latent ratio step
|
||||||
gl_ratio_step = float(m.group(1))
|
gl_ratio_step = float(m.group(1))
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent ratio step: {gl_ratio_step}")
|
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent s noise
|
if m: # gradual latent s noise
|
||||||
gl_s_noise = float(m.group(1))
|
gl_s_noise = float(m.group(1))
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent s noise: {gl_s_noise}")
|
logger.info(f"gradual latent s noise: {gl_s_noise}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
|
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent unsharp params
|
if m: # gradual latent unsharp params
|
||||||
gl_unsharp_params = m.group(1)
|
gl_unsharp_params = m.group(1)
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent unsharp params: {gl_unsharp_params}")
|
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Gradual Latent
|
# Gradual Latent
|
||||||
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent timesteps
|
if m: # gradual latent timesteps
|
||||||
gl_timesteps = int(m.group(1))
|
gl_timesteps = int(m.group(1))
|
||||||
print(f"gradual latent timesteps: {gl_timesteps}")
|
logger.info(f"gradual latent timesteps: {gl_timesteps}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent ratio
|
if m: # gradual latent ratio
|
||||||
gl_ratio = float(m.group(1))
|
gl_ratio = float(m.group(1))
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent ratio: {ds_ratio}")
|
logger.info(f"gradual latent ratio: {ds_ratio}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent every n steps
|
if m: # gradual latent every n steps
|
||||||
gl_every_n_steps = int(m.group(1))
|
gl_every_n_steps = int(m.group(1))
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent every n steps: {gl_every_n_steps}")
|
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent ratio step
|
if m: # gradual latent ratio step
|
||||||
gl_ratio_step = float(m.group(1))
|
gl_ratio_step = float(m.group(1))
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent ratio step: {gl_ratio_step}")
|
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent s noise
|
if m: # gradual latent s noise
|
||||||
gl_s_noise = float(m.group(1))
|
gl_s_noise = float(m.group(1))
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent s noise: {gl_s_noise}")
|
logger.info(f"gradual latent s noise: {gl_s_noise}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
|
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
|
||||||
if m: # gradual latent unsharp params
|
if m: # gradual latent unsharp params
|
||||||
gl_unsharp_params = m.group(1)
|
gl_unsharp_params = m.group(1)
|
||||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||||
print(f"gradual latent unsharp params: {gl_unsharp_params}")
|
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except ValueError as ex:
|
except ValueError as ex:
|
||||||
|
|||||||
Reference in New Issue
Block a user