Mirror of https://github.com/kohya-ss/sd-scripts.git
Synced 2026-04-15 08:36:41 +00:00

Merge pull request #1117 from kohya-ss/gradual_latent_hires_fix

Gradual latent hires fix

This commit is contained in:

README.md (32 lines changed)

@@ -1,3 +1,35 @@

# Gradual Latent について

latentのサイズを徐々に大きくしていくHires fixです。`sdxl_gen_img.py` に以下のオプションが追加されています。

- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。
- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。
- `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。
- `--gradual_latent_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。

それぞれのオプションは、プロンプトオプション、`--glt`、`--glr`、`--gls`、`--gle` でも指定できます。

サンプラーに手を加えているため、__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。

`gen_img_diffusers.py` にも同様のオプションが追加されていますが、試した範囲ではどうやっても乱れた画像しか生成できませんでした。

# About Gradual Latent

Gradual Latent is a Hires fix that gradually increases the size of the latent. The following options have been added to `sdxl_gen_img.py`.

- `--gradual_latent_timesteps`: Specifies the timestep at which to start increasing the size of the latent. The default is None, which disables Gradual Latent. Try starting around 750.
- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which starts at half the default latent size.
- `--gradual_latent_ratio_step`: Specifies the ratio by which to increase the size of the latent. The default is 0.125, which gradually increases the latent size through 0.625, 0.75, 0.875, and 1.0.
- `--gradual_latent_every_n_steps`: Specifies the interval at which to increase the size of the latent. The default is 3, which increases the latent size every 3 steps.

Each option can also be specified as a prompt option: `--glt`, `--glr`, `--gls`, `--gle`.

__Please specify `euler_a` as the sampler__, because the sampler source code has been modified. Other samplers will not work.

`gen_img_diffusers.py` has the same options, but in my tests it produced only distorted images no matter what I tried.

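To make the schedule concrete, here is a minimal sketch of how the latent size evolves under the defaults above. It is an illustration only, not part of the scripts; the 1024x1024 image (a 128x128 latent) and the rough 30-step timestep schedule are assumptions.

```python
# Illustration of the Gradual Latent schedule with the default settings described above.
# Assumptions: 1024x1024 image (128x128 latent), 30 sampling steps, euler_a-like timesteps.
height, width = 128, 128
start_timesteps = 750
ratio, ratio_step, every_n_steps = 0.5, 0.125, 3

timesteps = list(range(999, -1, -34))[:30]  # rough stand-in for the real schedule
step_elapsed = 1000
for i, t in enumerate(timesteps):
    if t < start_timesteps and ratio < 1.0 and step_elapsed >= every_n_steps:
        ratio = min(ratio + ratio_step, 1.0)
        # latent sizes are kept divisible by 8 for the bottom of the UNet
        h, w = int(height * ratio) // 8 * 8, int(width * ratio) // 8 * 8
        print(f"step {i:2d}, t={t:3d}: upscale latents to {h}x{w} (ratio {ratio})")
        step_elapsed = 0
    step_elapsed += 1
```

The same values can also be supplied per prompt with the prompt options, e.g. by appending `--glt 750 --glr 0.5 --gls 0.125 --gle 3` to a prompt line.
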
---

__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run `accelerate config` again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).

This repository contains training, generation, and utility scripts for Stable Diffusion.

gen_img.py (3326 lines changed)
@@ -102,11 +102,14 @@ import tools.original_control_net as original_control_net
|
||||
from tools.original_control_net import ControlNetInfo
|
||||
from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel
|
||||
from library.original_unet import FlashAttentionFunction
|
||||
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
|
||||
|
||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||
from library.utils import setup_logging
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# scheduler:
|
||||
@@ -453,6 +456,8 @@ class PipelineLike:
|
||||
self.control_nets: List[ControlNetInfo] = []
|
||||
self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない
|
||||
|
||||
self.gradual_latent: GradualLatent = None
|
||||
|
||||
# Textual Inversion
|
||||
def add_token_replacement(self, target_token_id, rep_token_ids):
|
||||
self.token_replacements[target_token_id] = rep_token_ids
|
||||
@@ -483,6 +488,14 @@ class PipelineLike:
|
||||
def set_control_nets(self, ctrl_nets):
|
||||
self.control_nets = ctrl_nets
|
||||
|
||||
def set_gradual_latent(self, gradual_latent):
|
||||
if gradual_latent is None:
|
||||
print("gradual_latent is disabled")
|
||||
self.gradual_latent = None
|
||||
else:
|
||||
print(f"gradual_latent is enabled: {gradual_latent}")
|
||||
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
|
||||
|
||||
# region xformersとか使う部分:独自に書き換えるので関係なし
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
@@ -957,7 +970,49 @@ class PipelineLike:
|
||||
else:
|
||||
text_emb_last = text_embeddings
|
||||
|
||||
enable_gradual_latent = False
|
||||
if self.gradual_latent:
|
||||
if not hasattr(self.scheduler, "set_gradual_latent_params"):
|
||||
print("gradual_latent is not supported for this scheduler. Ignoring.")
|
||||
print(self.scheduler.__class__.__name__)
|
||||
else:
|
||||
enable_gradual_latent = True
|
||||
step_elapsed = 1000
|
||||
current_ratio = self.gradual_latent.ratio
|
||||
|
||||
# first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする
|
||||
height, width = latents.shape[-2:]
|
||||
org_dtype = latents.dtype
|
||||
if org_dtype == torch.bfloat16:
|
||||
latents = latents.float()
|
||||
latents = torch.nn.functional.interpolate(
|
||||
latents, scale_factor=current_ratio, mode="bicubic", align_corners=False
|
||||
).to(org_dtype)
|
||||
|
||||
# apply unsharp mask / アンシャープマスクを適用する
|
||||
if self.gradual_latent.gaussian_blur_ksize:
|
||||
latents = self.gradual_latent.apply_unshark_mask(latents)
|
||||
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
resized_size = None
|
||||
if enable_gradual_latent:
|
||||
# gradually upscale the latents / latentsを徐々にアップスケールする
|
||||
if (
|
||||
t < self.gradual_latent.start_timesteps
|
||||
and current_ratio < 1.0
|
||||
and step_elapsed >= self.gradual_latent.every_n_steps
|
||||
):
|
||||
current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0)
|
||||
# make divisible by 8 because size of latents must be divisible at bottom of UNet
|
||||
h = int(height * current_ratio) // 8 * 8
|
||||
w = int(width * current_ratio) // 8 * 8
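# e.g. for a 1024x1024 image the latent is 128x128, so at ratio 0.625 this gives 80x80 (already
# a multiple of 8); an odd size such as int(96 * 0.625) = 60 would be rounded down to 56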
|
||||
resized_size = (h, w)
|
||||
self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent)
|
||||
step_elapsed = 0
|
||||
else:
|
||||
self.scheduler.set_gradual_latent_params(None, None)
|
||||
step_elapsed += 1
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
@@ -1539,7 +1594,9 @@ class PipelineLike:
|
||||
image_embeddings = self.vgg16_feat_model(image)["feat"]
|
||||
|
||||
# バッチサイズが複数だと正しく動くかわからない
|
||||
loss = ((image_embeddings - guide_embeddings) ** 2).mean() * guidance_scale # MSE style transferでコンテンツの損失はMSEなので
|
||||
loss = (
|
||||
(image_embeddings - guide_embeddings) ** 2
|
||||
).mean() * guidance_scale # MSE style transferでコンテンツの損失はMSEなので
|
||||
|
||||
grads = -torch.autograd.grad(loss, latents)[0]
|
||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||
@@ -2130,6 +2187,7 @@ class BatchDataBase(NamedTuple):
|
||||
mask_image: Any
|
||||
clip_prompt: str
|
||||
guide_image: Any
|
||||
raw_prompt: str
|
||||
|
||||
|
||||
class BatchDataExt(NamedTuple):
|
||||
@@ -2249,7 +2307,7 @@ def main(args):
|
||||
scheduler_cls = EulerDiscreteScheduler
|
||||
scheduler_module = diffusers.schedulers.scheduling_euler_discrete
|
||||
elif args.sampler == "euler_a" or args.sampler == "k_euler_a":
|
||||
scheduler_cls = EulerAncestralDiscreteScheduler
|
||||
scheduler_cls = EulerAncestralDiscreteSchedulerGL
|
||||
scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete
|
||||
elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++":
|
||||
scheduler_cls = DPMSolverMultistepScheduler
|
||||
@@ -2512,6 +2570,29 @@ def main(args):
|
||||
if args.ds_depth_1 is not None:
|
||||
unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio)
|
||||
|
||||
# Gradual Latent
|
||||
if args.gradual_latent_timesteps is not None:
|
||||
if args.gradual_latent_unsharp_params:
|
||||
us_params = args.gradual_latent_unsharp_params.split(",")
|
||||
us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]]
|
||||
us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3]))
|
||||
us_ksize = int(us_ksize)
|
||||
else:
|
||||
us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None
|
||||
|
||||
gradual_latent = GradualLatent(
|
||||
args.gradual_latent_ratio,
|
||||
args.gradual_latent_timesteps,
|
||||
args.gradual_latent_every_n_steps,
|
||||
args.gradual_latent_ratio_step,
|
||||
args.gradual_latent_s_noise,
|
||||
us_ksize,
|
||||
us_sigma,
|
||||
us_strength,
|
||||
us_target_x,
|
||||
)
|
||||
pipe.set_gradual_latent(gradual_latent)
|
||||
|
||||
# Extended Textual Inversion および Textual Inversionを処理する
|
||||
if args.XTI_embeddings:
|
||||
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
|
||||
@@ -2533,7 +2614,9 @@ def main(args):
|
||||
embeds = next(iter(data.values()))
|
||||
|
||||
if type(embeds) != torch.Tensor:
|
||||
raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}")
|
||||
raise ValueError(
|
||||
f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}"
|
||||
)
|
||||
|
||||
num_vectors_per_token = embeds.size()[0]
|
||||
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
|
||||
@@ -2633,7 +2716,7 @@ def main(args):
|
||||
logger.info(f"reading prompts from {args.from_file}")
|
||||
with open(args.from_file, "r", encoding="utf-8") as f:
|
||||
prompt_list = f.read().splitlines()
|
||||
prompt_list = [d for d in prompt_list if len(d.strip()) > 0]
|
||||
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
|
||||
elif args.prompt is not None:
|
||||
prompt_list = [args.prompt]
|
||||
else:
|
||||
@@ -2763,7 +2846,9 @@ def main(args):
|
||||
|
||||
logger.info(f"loaded {len(guide_images)} guide images for guidance")
|
||||
if len(guide_images) == 0:
|
||||
logger.info(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
|
||||
logger.info(
|
||||
f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}"
|
||||
)
|
||||
guide_images = None
|
||||
else:
|
||||
guide_images = None
|
||||
@@ -2877,13 +2962,14 @@ def main(args):
|
||||
# このバッチの情報を取り出す
|
||||
(
|
||||
return_latents,
|
||||
(step_first, _, _, _, init_image, mask_image, _, guide_image),
|
||||
(step_first, _, _, _, init_image, mask_image, _, guide_image, _),
|
||||
(width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts),
|
||||
) = batch[0]
|
||||
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
||||
|
||||
prompts = []
|
||||
negative_prompts = []
|
||||
raw_prompts = []
|
||||
start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
|
||||
noises = [
|
||||
torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
|
||||
@@ -2914,11 +3000,16 @@ def main(args):
|
||||
all_images_are_same = True
|
||||
all_masks_are_same = True
|
||||
all_guide_images_are_same = True
|
||||
for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
|
||||
for i, (
|
||||
_,
|
||||
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt),
|
||||
_,
|
||||
) in enumerate(batch):
|
||||
prompts.append(prompt)
|
||||
negative_prompts.append(negative_prompt)
|
||||
seeds.append(seed)
|
||||
clip_prompts.append(clip_prompt)
|
||||
raw_prompts.append(raw_prompt)
|
||||
|
||||
if init_image is not None:
|
||||
init_images.append(init_image)
|
||||
@@ -3010,8 +3101,8 @@ def main(args):
|
||||
# save image
|
||||
highres_prefix = ("0" if highres_1st else "1") if highres_fix else ""
|
||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||
for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate(
|
||||
zip(images, prompts, negative_prompts, seeds, clip_prompts)
|
||||
for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate(
|
||||
zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts)
|
||||
):
|
||||
if highres_fix:
|
||||
seed -= 1 # record original seed
|
||||
@@ -3027,6 +3118,8 @@ def main(args):
|
||||
metadata.add_text("negative-scale", str(negative_scale))
|
||||
if clip_prompt is not None:
|
||||
metadata.add_text("clip-prompt", clip_prompt)
|
||||
if raw_prompt is not None:
|
||||
metadata.add_text("raw-prompt", raw_prompt)
|
||||
|
||||
if args.use_original_file_name and init_images is not None:
|
||||
if type(init_images) is list:
|
||||
@@ -3049,7 +3142,9 @@ def main(args):
|
||||
cv2.waitKey()
|
||||
cv2.destroyAllWindows()
|
||||
except ImportError:
|
||||
logger.info("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません")
|
||||
logger.info(
|
||||
"opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません"
|
||||
)
|
||||
|
||||
return images
|
||||
|
||||
@@ -3104,6 +3199,14 @@ def main(args):
|
||||
ds_timesteps_2 = args.ds_timesteps_2
|
||||
ds_ratio = args.ds_ratio
|
||||
|
||||
# Gradual Latent
|
||||
gl_timesteps = None # means no override
|
||||
gl_ratio = args.gradual_latent_ratio
|
||||
gl_every_n_steps = args.gradual_latent_every_n_steps
|
||||
gl_ratio_step = args.gradual_latent_ratio_step
|
||||
gl_s_noise = args.gradual_latent_s_noise
|
||||
gl_unsharp_params = args.gradual_latent_unsharp_params
|
||||
|
||||
prompt_args = raw_prompt.strip().split(" --")
|
||||
prompt = prompt_args[0]
|
||||
logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
||||
@@ -3206,10 +3309,52 @@ def main(args):
|
||||
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # deep shrink ratio
|
||||
ds_ratio = float(m.group(1))
|
||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||
logger.info(f"deep shrink ratio: {ds_ratio}")
|
||||
continue
|
||||
|
||||
# Gradual Latent
|
||||
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # gradual latent timesteps
|
||||
gl_timesteps = int(m.group(1))
|
||||
print(f"gradual latent timesteps: {gl_timesteps}")
|
||||
continue
|
||||
|
||||
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # gradual latent ratio
|
||||
gl_ratio = float(m.group(1))
|
||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||
print(f"gradual latent ratio: {ds_ratio}")
|
||||
continue
|
||||
|
||||
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # gradual latent every n steps
|
||||
gl_every_n_steps = int(m.group(1))
|
||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||
print(f"gradual latent every n steps: {gl_every_n_steps}")
|
||||
continue
|
||||
|
||||
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # gradual latent ratio step
|
||||
gl_ratio_step = float(m.group(1))
|
||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||
print(f"gradual latent ratio step: {gl_ratio_step}")
|
||||
continue
|
||||
|
||||
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # gradual latent s noise
|
||||
gl_s_noise = float(m.group(1))
|
||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||
print(f"gradual latent s noise: {gl_s_noise}")
|
||||
continue
|
||||
|
||||
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
|
||||
if m: # gradual latent unsharp params
|
||||
gl_unsharp_params = m.group(1)
|
||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||
print(f"gradual latent unsharp params: {gl_unsharp_params}")
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
logger.info(f"Exception in parsing / 解析エラー: {parg}")
|
||||
logger.info(ex)
|
||||
@@ -3220,6 +3365,31 @@ def main(args):
|
||||
ds_depth_1 = args.ds_depth_1 or 3
|
||||
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
|
||||
|
||||
# override Gradual Latent
|
||||
if gl_timesteps is not None:
|
||||
if gl_timesteps < 0:
|
||||
gl_timesteps = args.gradual_latent_timesteps or 650
|
||||
if gl_unsharp_params is not None:
|
||||
unsharp_params = gl_unsharp_params.split(",")
|
||||
us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]]
|
||||
print(unsharp_params)
|
||||
us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3]))
|
||||
us_ksize = int(us_ksize)
|
||||
else:
|
||||
us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None
|
||||
gradual_latent = GradualLatent(
|
||||
gl_ratio,
|
||||
gl_timesteps,
|
||||
gl_every_n_steps,
|
||||
gl_ratio_step,
|
||||
gl_s_noise,
|
||||
us_ksize,
|
||||
us_sigma,
|
||||
us_strength,
|
||||
us_target_x,
|
||||
)
|
||||
pipe.set_gradual_latent(gradual_latent)
|
||||
|
||||
# prepare seed
|
||||
if seeds is not None: # given in prompt
|
||||
# 数が足りないなら前のをそのまま使う
|
||||
@@ -3287,7 +3457,9 @@ def main(args):
|
||||
|
||||
b1 = BatchData(
|
||||
False,
|
||||
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
||||
BatchDataBase(
|
||||
global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt
|
||||
),
|
||||
BatchDataExt(
|
||||
width,
|
||||
height,
|
||||
@@ -3322,16 +3494,25 @@ def main(args):
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む")
|
||||
add_logging_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--v2", action="store_true", help="load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--v_parameterization", action="store_true", help="enable v-parameterization training / v-parameterization学習を有効にする"
|
||||
)
|
||||
parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト")
|
||||
parser.add_argument(
|
||||
"--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む"
|
||||
"--from_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)"
|
||||
"--interactive",
|
||||
action="store_true",
|
||||
help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない"
|
||||
@@ -3343,7 +3524,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength")
|
||||
parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数")
|
||||
parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先")
|
||||
parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする")
|
||||
parser.add_argument(
|
||||
"--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_original_file_name",
|
||||
action="store_true",
|
||||
@@ -3397,9 +3580,14 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=7.5,
|
||||
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale",
|
||||
)
|
||||
parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
|
||||
parser.add_argument(
|
||||
"--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ"
|
||||
"--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_cache_dir",
|
||||
@@ -3435,25 +3623,46 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する"
|
||||
"--opt_channels_last",
|
||||
action="store_true",
|
||||
help="set channels last option to model / モデルにchannels lastを指定し最適化する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名"
|
||||
"--network_module",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="additional network module to use / 追加ネットワークを使う時そのモジュール名",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み"
|
||||
)
|
||||
parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率")
|
||||
parser.add_argument(
|
||||
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
|
||||
"--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率"
|
||||
)
|
||||
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
||||
parser.add_argument(
|
||||
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
|
||||
"--network_args",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
|
||||
)
|
||||
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
||||
parser.add_argument(
|
||||
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
|
||||
"--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_merge_n_models",
|
||||
type=int,
|
||||
default=None,
|
||||
help="merge this number of networks / この数だけネットワークをマージする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_pre_calc",
|
||||
action="store_true",
|
||||
help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_regional_mask_max_color_codes",
|
||||
@@ -3475,7 +3684,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
nargs="*",
|
||||
help="Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings",
|
||||
)
|
||||
parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う")
|
||||
parser.add_argument(
|
||||
"--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_embeddings_multiples",
|
||||
type=int,
|
||||
@@ -3516,7 +3727,10 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数"
|
||||
"--highres_fix_steps",
|
||||
type=int,
|
||||
default=28,
|
||||
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--highres_fix_strength",
|
||||
@@ -3525,7 +3739,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する"
|
||||
"--highres_fix_save_1st",
|
||||
action="store_true",
|
||||
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--highres_fix_latents_upscaling",
|
||||
@@ -3533,7 +3749,10 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="use latents upscaling for highres fix / highres fixでlatentで拡大する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名"
|
||||
"--highres_fix_upscaler",
|
||||
type=str,
|
||||
default=None,
|
||||
help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--highres_fix_upscaler_args",
|
||||
@@ -3548,14 +3767,21 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する"
|
||||
"--negative_scale",
|
||||
type=float,
|
||||
default=None,
|
||||
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名"
|
||||
"--control_net_preps",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名",
|
||||
)
|
||||
parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み")
|
||||
parser.add_argument(
|
||||
@@ -3593,6 +3819,45 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率"
|
||||
)
|
||||
|
||||
# gradual latent
|
||||
parser.add_argument(
|
||||
"--gradual_latent_timesteps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="enable Gradual Latent hires fix and apply upscaling from this timesteps / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradual_latent_ratio",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help=" this size ratio, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradual_latent_ratio_step",
|
||||
type=float,
|
||||
default=0.125,
|
||||
help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradual_latent_every_n_steps",
|
||||
type=int,
|
||||
default=3,
|
||||
help="steps to increase size of latents every this steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradual_latent_s_noise",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="s_noise for Gradual Latent / Gradual Latentのs_noise",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradual_latent_unsharp_params",
|
||||
type=str,
|
||||
default=None,
|
||||
help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /"
|
||||
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -3600,4 +3865,5 @@ if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
setup_logging(args, reset=True)
|
||||
main(args)
|
||||
|
||||
library/utils.py (192 lines changed)
@@ -1,7 +1,12 @@
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from typing import *
|
||||
from diffusers import EulerAncestralDiscreteScheduler
|
||||
import diffusers.schedulers.scheduling_euler_ancestral_discrete
|
||||
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
|
||||
|
||||
|
||||
def fire_in_thread(f, *args, **kwargs):
|
||||
@@ -72,3 +77,190 @@ def setup_logging(args=None, log_level=None, reset=False):
|
||||
if msg_init is not None:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(msg_init)
|
||||
|
||||
|
||||
|
||||
# TODO make inf_utils.py
|
||||
|
||||
|
||||
# region Gradual Latent hires fix
|
||||
|
||||
|
||||
class GradualLatent:
|
||||
def __init__(
|
||||
self,
|
||||
ratio,
|
||||
start_timesteps,
|
||||
every_n_steps,
|
||||
ratio_step,
|
||||
s_noise=1.0,
|
||||
gaussian_blur_ksize=None,
|
||||
gaussian_blur_sigma=0.5,
|
||||
gaussian_blur_strength=0.5,
|
||||
unsharp_target_x=True,
|
||||
):
|
||||
self.ratio = ratio
|
||||
self.start_timesteps = start_timesteps
|
||||
self.every_n_steps = every_n_steps
|
||||
self.ratio_step = ratio_step
|
||||
self.s_noise = s_noise
|
||||
self.gaussian_blur_ksize = gaussian_blur_ksize
|
||||
self.gaussian_blur_sigma = gaussian_blur_sigma
|
||||
self.gaussian_blur_strength = gaussian_blur_strength
|
||||
self.unsharp_target_x = unsharp_target_x
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, "
|
||||
+ f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, "
|
||||
+ f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, "
|
||||
+ f"unsharp_target_x={self.unsharp_target_x})"
|
||||
)
|
||||
|
||||
def apply_unshark_mask(self, x: torch.Tensor):
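# Classic unsharp masking: sharpened = x + strength * (x - gaussian_blur(x)),
# applied after bicubic interpolation to counteract the blur it introduces.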
|
||||
if self.gaussian_blur_ksize is None:
|
||||
return x
|
||||
blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
|
||||
# mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength)
|
||||
mask = (x - blurred) * self.gaussian_blur_strength
|
||||
sharpened = x + mask
|
||||
return sharpened
|
||||
|
||||
def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
|
||||
org_dtype = x.dtype
|
||||
if org_dtype == torch.bfloat16:
|
||||
x = x.float()
|
||||
|
||||
x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype)
|
||||
|
||||
# apply unsharp mask / アンシャープマスクを適用する
|
||||
if unsharp and self.gaussian_blur_ksize:
|
||||
x = self.apply_unshark_mask(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.resized_size = None
|
||||
self.gradual_latent = None
|
||||
|
||||
def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
|
||||
self.resized_size = size
|
||||
self.gradual_latent = gradual_latent
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
sample: torch.FloatTensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`float`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
return_dict (`bool`):
|
||||
Whether or not to return a
|
||||
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`,
|
||||
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
|
||||
otherwise a tuple is returned where the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
|
||||
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
|
||||
raise ValueError(
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
||||
" one of the `scheduler.timesteps` as a timestep."
|
||||
),
|
||||
)
|
||||
|
||||
if not self.is_scale_input_called:
|
||||
# logger.warning(
|
||||
print(
|
||||
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
||||
"See `StableDiffusionPipeline` for a usage example."
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = sample - sigma * model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
# * c_out + input * c_skip
|
||||
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
||||
elif self.config.prediction_type == "sample":
|
||||
raise NotImplementedError("prediction_type not implemented yet: sample")
|
||||
else:
|
||||
raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")
|
||||
|
||||
sigma_from = self.sigmas[self.step_index]
|
||||
sigma_to = self.sigmas[self.step_index + 1]
|
||||
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
|
||||
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
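# Ancestral split: sigma_down**2 + sigma_up**2 == sigma_to**2, so the deterministic step toward
# sigma_down plus fresh noise scaled by sigma_up lands the sample at sigma_to's noise level.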
|
||||
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - pred_original_sample) / sigma
|
||||
|
||||
dt = sigma_down - sigma
|
||||
|
||||
device = model_output.device
|
||||
if self.resized_size is None:
|
||||
prev_sample = sample + derivative * dt
|
||||
|
||||
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
|
||||
model_output.shape, dtype=model_output.dtype, device=device, generator=generator
|
||||
)
|
||||
s_noise = 1.0
|
||||
else:
|
||||
print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
|
||||
s_noise = self.gradual_latent.s_noise
|
||||
|
||||
if self.gradual_latent.unsharp_target_x:
|
||||
prev_sample = sample + derivative * dt
|
||||
prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
|
||||
else:
|
||||
sample = self.gradual_latent.interpolate(sample, self.resized_size)
|
||||
derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
|
||||
prev_sample = sample + derivative * dt
|
||||
|
||||
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
|
||||
(model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
|
||||
dtype=model_output.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
prev_sample = prev_sample + noise * sigma_up * s_noise
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
|
||||
# endregion
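# --- Hypothetical usage sketch (not part of this diff) -------------------------------------
# How GradualLatent and EulerAncestralDiscreteSchedulerGL are meant to be wired together,
# mirroring what gen_img.py / sdxl_gen_img.py do above; values shown are the documented defaults.
#
#   from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
#
#   scheduler = EulerAncestralDiscreteSchedulerGL()  # the scripts build it from the model's scheduler config
#   gradual_latent = GradualLatent(ratio=0.5, start_timesteps=750, every_n_steps=3, ratio_step=0.125)
#
# Per denoising step the generation loop calls either
#   scheduler.set_gradual_latent_params((h, w), gradual_latent)  # when the latent is upscaled this step
# or
#   scheduler.set_gradual_latent_params(None, None)              # otherwise (plain euler_a step)
# and scheduler.step() then resizes prev_sample and draws noise at the new resolution.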
|
||||
|
||||
@@ -12,8 +12,10 @@ import numpy as np
|
||||
import torch
|
||||
import re
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||
@@ -248,7 +250,8 @@ class LoRAInfModule(LoRAModule):
|
||||
if mask is None:
|
||||
# raise ValueError(f"mask is None for resolution {area}")
|
||||
# emb_layers in SDXL doesn't have mask
|
||||
# logger.info(f"mask is None for resolution {area}, {x.size()}")
|
||||
# if "emb" not in self.lora_name:
|
||||
# print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
|
||||
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
|
||||
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
|
||||
if len(x.size()) != 4:
|
||||
@@ -265,7 +268,9 @@ class LoRAInfModule(LoRAModule):
|
||||
# apply mask for LoRA result
|
||||
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
mask = self.get_mask_for_x(lx)
|
||||
# logger.info(f"regional {self.lora_name} {self.network.sub_prompt_index} {lx.size()} {mask.size()}")
|
||||
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
|
||||
# if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked)
|
||||
# mask = mask.squeeze(-1)
|
||||
lx = lx * mask
|
||||
|
||||
x = self.org_forward(x)
|
||||
@@ -514,7 +519,9 @@ def get_block_dims_and_alphas(
|
||||
len(block_dims) == num_total_blocks
|
||||
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
||||
else:
|
||||
logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
||||
logger.warning(
|
||||
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
|
||||
)
|
||||
block_dims = [network_dim] * num_total_blocks
|
||||
|
||||
if block_alphas is not None:
|
||||
@@ -792,7 +799,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
logger.info(f"create LoRA network from weights")
|
||||
elif block_dims is not None:
|
||||
logger.info(f"create LoRA network from block_dims")
|
||||
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||
logger.info(
|
||||
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
||||
)
|
||||
logger.info(f"block_dims: {block_dims}")
|
||||
logger.info(f"block_alphas: {block_alphas}")
|
||||
if conv_block_dims is not None:
|
||||
@@ -800,9 +809,13 @@ class LoRANetwork(torch.nn.Module):
|
||||
logger.info(f"conv_block_alphas: {conv_block_alphas}")
|
||||
else:
|
||||
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||
logger.info(
|
||||
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
||||
)
|
||||
if self.conv_lora_dim is not None:
|
||||
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||
logger.info(
|
||||
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
|
||||
)
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
@@ -929,6 +942,10 @@ class LoRANetwork(torch.nn.Module):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.multiplier = self.multiplier
|
||||
|
||||
def set_enabled(self, is_enabled):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.enabled = is_enabled
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
@@ -1116,7 +1133,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.set_network(self)
|
||||
|
||||
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
|
||||
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None):
|
||||
self.batch_size = batch_size
|
||||
self.num_sub_prompts = num_sub_prompts
|
||||
self.current_size = (height, width)
|
||||
@@ -1142,6 +1159,13 @@ class LoRANetwork(torch.nn.Module):
|
||||
resize_add(h, w)
|
||||
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
|
||||
resize_add(h + h % 2, w + w % 2)
|
||||
|
||||
# deep shrink
|
||||
if ds_ratio is not None:
|
||||
hd = int(h * ds_ratio)
|
||||
wd = int(w * ds_ratio)
|
||||
resize_add(hd, wd)
|
||||
|
||||
h = (h + 1) // 2
|
||||
w = (w + 1) // 2
|
||||
|
||||
|
||||
sdxl_gen_img.py (451 lines changed)
@@ -55,9 +55,12 @@ from networks.lora import LoRANetwork
|
||||
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
|
||||
from library.original_unet import FlashAttentionFunction
|
||||
from networks.control_net_lllite import ControlNetLLLite
|
||||
from library.utils import setup_logging
|
||||
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# scheduler:
|
||||
@@ -344,6 +347,8 @@ class PipelineLike:
|
||||
self.control_nets: List[ControlNetLLLite] = []
|
||||
self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない
|
||||
|
||||
self.gradual_latent: GradualLatent = None
|
||||
|
||||
# Textual Inversion
|
||||
def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids):
|
||||
self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids
|
||||
@@ -374,6 +379,14 @@ class PipelineLike:
|
||||
def set_control_nets(self, ctrl_nets):
|
||||
self.control_nets = ctrl_nets
|
||||
|
||||
def set_gradual_latent(self, gradual_latent):
|
||||
if gradual_latent is None:
|
||||
print("gradual_latent is disabled")
|
||||
self.gradual_latent = None
|
||||
else:
|
||||
print(f"gradual_latent is enabled: {gradual_latent}")
|
||||
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -708,7 +721,116 @@ class PipelineLike:
|
||||
control_net.set_cond_image(None)
|
||||
|
||||
each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets)
|
||||
|
||||
# # first, we downscale the latents to the half of the size
|
||||
# # 最初に1/2に縮小する
|
||||
# height, width = latents.shape[-2:]
|
||||
# # latents = torch.nn.functional.interpolate(latents.float(), scale_factor=0.5, mode="bicubic", align_corners=False).to(
|
||||
# # latents.dtype
|
||||
# # )
|
||||
# latents = latents[:, :, ::2, ::2]
|
||||
# current_scale = 0.5
|
||||
|
||||
# # how much to increase the scale at each step: .125 seems to work well (because it's 1/8?)
|
||||
# # 各ステップに拡大率をどのくらい増やすか:.125がよさそう(たぶん1/8なので)
|
||||
# scale_step = 0.125
|
||||
|
||||
# # timesteps at which to start increasing the scale: 1000 seems to be enough
|
||||
# # 拡大を開始するtimesteps: 1000で十分そうである
|
||||
# start_timesteps = 1000
|
||||
|
||||
# # how many steps to wait before increasing the scale again
|
||||
# # small values leads to blurry images (because the latents are blurry after the upscale, so some denoising might be needed)
|
||||
# # large values leads to flat images
|
||||
|
||||
# # 何ステップごとに拡大するか
|
||||
# # 小さいとボケる(拡大後のlatentsはボケた感じになるので、そこから数stepのdenoiseが必要と思われる)
|
||||
# # 大きすぎると細部が書き込まれずのっぺりした感じになる
|
||||
# every_n_steps = 5
|
||||
|
||||
# scale_step = input("scale step:")
|
||||
# scale_step = float(scale_step)
|
||||
# start_timesteps = input("start timesteps:")
|
||||
# start_timesteps = int(start_timesteps)
|
||||
# every_n_steps = input("every n steps:")
|
||||
# every_n_steps = int(every_n_steps)
|
||||
|
||||
# # for i, t in enumerate(tqdm(timesteps)):
|
||||
# i = 0
|
||||
# last_step = 0
|
||||
# while i < len(timesteps):
|
||||
# t = timesteps[i]
|
||||
# print(f"[{i}] t={t}")
|
||||
|
||||
# print(i, t, current_scale, latents.shape)
|
||||
# if t < start_timesteps and current_scale < 1.0 and i % every_n_steps == 0:
|
||||
# if i == last_step:
|
||||
# pass
|
||||
# else:
|
||||
# print("upscale")
|
||||
# current_scale = min(current_scale + scale_step, 1.0)
|
||||
|
||||
# h = int(height * current_scale) // 8 * 8
|
||||
# w = int(width * current_scale) // 8 * 8
|
||||
|
||||
# latents = torch.nn.functional.interpolate(latents.float(), size=(h, w), mode="bicubic", align_corners=False).to(
|
||||
# latents.dtype
|
||||
# )
|
||||
# last_step = i
|
||||
# i = max(0, i - every_n_steps + 1)
|
||||
|
||||
# diff = timesteps[i] - timesteps[last_step]
|
||||
# # resized_init_noise = torch.nn.functional.interpolate(
|
||||
# # init_noise.float(), size=(h, w), mode="bicubic", align_corners=False
|
||||
# # ).to(latents.dtype)
|
||||
# # latents = self.scheduler.add_noise(latents, resized_init_noise, diff)
|
||||
# latents = self.scheduler.add_noise(latents, torch.randn_like(latents), diff * 4)
|
||||
# # latents += torch.randn_like(latents) / 100 * diff
|
||||
# continue
|
||||
|
||||
enable_gradual_latent = False
|
||||
if self.gradual_latent:
|
||||
if not hasattr(self.scheduler, "set_gradual_latent_params"):
|
||||
print("gradual_latent is not supported for this scheduler. Ignoring.")
|
||||
print(self.scheduler.__class__.__name__)
|
||||
else:
|
||||
enable_gradual_latent = True
|
||||
step_elapsed = 1000
|
||||
current_ratio = self.gradual_latent.ratio
|
||||
|
||||
# first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする
|
||||
height, width = latents.shape[-2:]
|
||||
org_dtype = latents.dtype
|
||||
if org_dtype == torch.bfloat16:
|
||||
latents = latents.float()
|
||||
latents = torch.nn.functional.interpolate(
|
||||
latents, scale_factor=current_ratio, mode="bicubic", align_corners=False
|
||||
).to(org_dtype)
|
||||
|
||||
# apply unsharp mask / アンシャープマスクを適用する
|
||||
if self.gradual_latent.gaussian_blur_ksize:
|
||||
latents = self.gradual_latent.apply_unshark_mask(latents)
|
||||
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
resized_size = None
|
||||
if enable_gradual_latent:
|
||||
# gradually upscale the latents / latentsを徐々にアップスケールする
|
||||
if (
|
||||
t < self.gradual_latent.start_timesteps
|
||||
and current_ratio < 1.0
|
||||
and step_elapsed >= self.gradual_latent.every_n_steps
|
||||
):
|
||||
current_ratio = min(current_ratio + self.gradual_latent.ratio_step, 1.0)
|
||||
# make divisible by 8 because size of latents must be divisible at bottom of UNet
|
||||
h = int(height * current_ratio) // 8 * 8
|
||||
w = int(width * current_ratio) // 8 * 8
|
||||
resized_size = (h, w)
|
||||
self.scheduler.set_gradual_latent_params(resized_size, self.gradual_latent)
|
||||
step_elapsed = 0
|
||||
else:
|
||||
self.scheduler.set_gradual_latent_params(None, None)
|
||||
step_elapsed += 1
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
@@ -777,6 +899,8 @@ class PipelineLike:
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
|
||||
i += 1
|
||||
|
||||
if return_latents:
|
||||
return latents
|
||||
|
||||
@@ -1310,7 +1434,6 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# def load_clip_l14_336(dtype):
|
||||
# logger.info(f"loading CLIP: {CLIP_ID_L14_336}")
|
||||
# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype)
|
||||
@@ -1327,6 +1450,7 @@ class BatchDataBase(NamedTuple):
|
||||
mask_image: Any
|
||||
clip_prompt: str
|
||||
guide_image: Any
|
||||
raw_prompt: str
|
||||
|
||||
|
||||
class BatchDataExt(NamedTuple):
|
||||
@@ -1410,7 +1534,7 @@ def main(args):
|
||||
scheduler_module = diffusers.schedulers.scheduling_euler_discrete
|
||||
has_clip_sample = False
|
||||
elif args.sampler == "euler_a" or args.sampler == "k_euler_a":
|
||||
scheduler_cls = EulerAncestralDiscreteScheduler
|
||||
scheduler_cls = EulerAncestralDiscreteSchedulerGL
|
||||
scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete
|
||||
has_clip_sample = False
|
||||
elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++":
|
||||
@@ -1707,6 +1831,29 @@ def main(args):
|
||||
if args.ds_depth_1 is not None:
|
||||
unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio)
|
||||
|
||||
# Gradual Latent
|
||||
if args.gradual_latent_timesteps is not None:
|
||||
if args.gradual_latent_unsharp_params:
|
||||
us_params = args.gradual_latent_unsharp_params.split(",")
|
||||
us_ksize, us_sigma, us_strength = [float(v) for v in us_params[:3]]
|
||||
us_target_x = True if len(us_params) <= 3 else bool(int(us_params[3]))
|
||||
us_ksize = int(us_ksize)
|
||||
else:
|
||||
us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None
|
||||
|
||||
gradual_latent = GradualLatent(
|
||||
args.gradual_latent_ratio,
|
||||
args.gradual_latent_timesteps,
|
||||
args.gradual_latent_every_n_steps,
|
||||
args.gradual_latent_ratio_step,
|
||||
args.gradual_latent_s_noise,
|
||||
us_ksize,
|
||||
us_sigma,
|
||||
us_strength,
|
||||
us_target_x,
|
||||
)
|
||||
pipe.set_gradual_latent(gradual_latent)
|
||||
|
||||
# Textual Inversionを処理する
|
||||
if args.textual_inversion_embeddings:
|
||||
token_ids_embeds1 = []
|
||||
@@ -1773,7 +1920,7 @@ def main(args):
|
||||
logger.info(f"reading prompts from {args.from_file}")
|
||||
with open(args.from_file, "r", encoding="utf-8") as f:
|
||||
prompt_list = f.read().splitlines()
|
||||
prompt_list = [d for d in prompt_list if len(d.strip()) > 0]
|
||||
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
|
||||
elif args.prompt is not None:
|
||||
prompt_list = [args.prompt]
|
||||
else:
|
||||
@@ -1916,7 +2063,9 @@ def main(args):
|
||||
|
||||
logger.info(f"loaded {len(guide_images)} guide images for guidance")
|
||||
if len(guide_images) == 0:
|
||||
logger.warning(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
|
||||
logger.warning(
|
||||
f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}"
|
||||
)
|
||||
guide_images = None
|
||||
else:
|
||||
guide_images = None
|
||||
@@ -2045,7 +2194,7 @@ def main(args):
|
||||
# このバッチの情報を取り出す
|
||||
(
|
||||
return_latents,
|
||||
(step_first, _, _, _, init_image, mask_image, _, guide_image),
|
||||
(step_first, _, _, _, init_image, mask_image, _, guide_image, _),
|
||||
(
|
||||
width,
|
||||
height,
|
||||
@@ -2067,6 +2216,7 @@ def main(args):
|
||||
|
||||
prompts = []
|
||||
negative_prompts = []
|
||||
raw_prompts = []
|
||||
start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
|
||||
noises = [
|
||||
torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
|
||||
@@ -2097,11 +2247,16 @@ def main(args):
|
||||
all_images_are_same = True
|
||||
all_masks_are_same = True
|
||||
all_guide_images_are_same = True
|
||||
for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch):
|
||||
for i, (
|
||||
_,
|
||||
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt),
|
||||
_,
|
||||
) in enumerate(batch):
|
||||
prompts.append(prompt)
|
||||
negative_prompts.append(negative_prompt)
|
||||
seeds.append(seed)
|
||||
clip_prompts.append(clip_prompt)
|
||||
raw_prompts.append(raw_prompt)
|
||||
|
||||
if init_image is not None:
|
||||
init_images.append(init_image)
|
||||
@@ -2199,8 +2354,8 @@ def main(args):
|
||||
# save image
|
||||
highres_prefix = ("0" if highres_1st else "1") if highres_fix else ""
|
||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||
for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate(
|
||||
zip(images, prompts, negative_prompts, seeds, clip_prompts)
|
||||
for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate(
|
||||
zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts)
|
||||
):
|
||||
if highres_fix:
|
||||
seed -= 1 # record original seed
|
||||
@@ -2216,6 +2371,8 @@ def main(args):
|
||||
metadata.add_text("negative-scale", str(negative_scale))
|
||||
if clip_prompt is not None:
|
||||
metadata.add_text("clip-prompt", clip_prompt)
|
||||
if raw_prompt is not None:
|
||||
metadata.add_text("raw-prompt", raw_prompt)
|
||||
metadata.add_text("original-height", str(original_height))
|
||||
metadata.add_text("original-width", str(original_width))
|
||||
metadata.add_text("original-height-negative", str(original_height_negative))
|
||||
@@ -2244,7 +2401,9 @@ def main(args):
|
||||
cv2.waitKey()
|
||||
cv2.destroyAllWindows()
|
||||
except ImportError:
|
||||
logger.error("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません")
|
||||
logger.error(
|
||||
"opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません"
|
||||
)
|
||||
|
||||
return images
|
||||
|
||||
@@ -2305,6 +2464,14 @@ def main(args):
|
||||
ds_timesteps_2 = args.ds_timesteps_2
|
||||
ds_ratio = args.ds_ratio
|
||||
|
||||
# Gradual Latent
|
||||
gl_timesteps = None # means no override
|
||||
gl_ratio = args.gradual_latent_ratio
|
||||
gl_every_n_steps = args.gradual_latent_every_n_steps
|
||||
gl_ratio_step = args.gradual_latent_ratio_step
|
||||
gl_s_noise = args.gradual_latent_s_noise
|
||||
gl_unsharp_params = args.gradual_latent_unsharp_params
|
||||
|
||||
prompt_args = raw_prompt.strip().split(" --")
|
||||
prompt = prompt_args[0]
|
||||
logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
||||
@@ -2447,6 +2614,90 @@ def main(args):
|
||||
logger.info(f"deep shrink ratio: {ds_ratio}")
|
||||
continue
|
||||
|
||||
# Gradual Latent
|
||||
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # gradual latent timesteps
|
||||
gl_timesteps = int(m.group(1))
|
||||
print(f"gradual latent timesteps: {gl_timesteps}")
|
||||
continue
|
||||
|
||||
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # gradual latent ratio
|
||||
gl_ratio = float(m.group(1))
|
||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||
print(f"gradual latent ratio: {ds_ratio}")
|
||||
continue
|
||||
|
||||
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # gradual latent every n steps
|
||||
gl_every_n_steps = int(m.group(1))
|
||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||
print(f"gradual latent every n steps: {gl_every_n_steps}")
|
||||
continue
|
||||
|
||||
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # gradual latent ratio step
|
||||
gl_ratio_step = float(m.group(1))
|
||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||
print(f"gradual latent ratio step: {gl_ratio_step}")
|
||||
continue
|
||||
|
||||
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # gradual latent s noise
|
||||
gl_s_noise = float(m.group(1))
|
||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||
print(f"gradual latent s noise: {gl_s_noise}")
|
||||
continue
|
||||
|
||||
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
|
||||
if m: # gradual latent unsharp params
|
||||
gl_unsharp_params = m.group(1)
|
||||
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
|
||||
print(f"gradual latent unsharp params: {gl_unsharp_params}")
|
||||
continue
|
||||
except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(f"{ex}")
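As a quick illustration of how the per-prompt options above are consumed: the prompt line is split on `" --"` and each fragment is checked against the `re.match` patterns shown in the hunk. A minimal sketch, with a made-up prompt line; only `glt` and `glr` are shown, the other options follow the same pattern:

```python
import re

# hypothetical prompt line; "--glt 750 --glr 0.5" are the Gradual Latent prompt options
raw_prompt = "a castle on a hill --glt 750 --glr 0.5 --gle 3 --gls 0.125"

prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]  # "a castle on a hill"

gl_timesteps = None
gl_ratio = None
for parg in prompt_args[1:]:
    m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
    if m:  # gradual latent timesteps
        gl_timesteps = int(m.group(1))
        continue
    m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
    if m:  # gradual latent ratio
        gl_ratio = float(m.group(1))
        gl_timesteps = gl_timesteps if gl_timesteps is not None else -1  # -1 means: fall back to the CLI value
        continue
    # gle / gls / glsn / glus are handled the same way

print(prompt, gl_timesteps, gl_ratio)  # a castle on a hill 750 0.5
```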
@@ -2457,6 +2708,30 @@ def main(args):
ds_depth_1 = args.ds_depth_1 or 3
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)

# override Gradual Latent
if gl_timesteps is not None:
if gl_timesteps < 0:
gl_timesteps = args.gradual_latent_timesteps or 650
if gl_unsharp_params is not None:
unsharp_params = gl_unsharp_params.split(",")
us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]]
us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3]))
us_ksize = int(us_ksize)
else:
us_ksize, us_sigma, us_strength, us_target_x = None, None, None, None
gradual_latent = GradualLatent(
gl_ratio,
gl_timesteps,
gl_every_n_steps,
gl_ratio_step,
gl_s_noise,
us_ksize,
us_sigma,
us_strength,
us_target_x,
)
pipe.set_gradual_latent(gradual_latent)

# prepare seed
if seeds is not None:  # given in prompt
# 数が足りないなら前のをそのまま使う
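For reference, the unsharp-mask string is split exactly as in the block above: `ksize,sigma,strength,target-x`, with `target-x` defaulting to true when omitted. A standalone sketch of the same parsing, using one of the values recommended in the help text further down:

```python
# parse "ksize,sigma,strength,target-x" the same way as the override block above
gl_unsharp_params = "3,0.5,0.5,1"

unsharp_params = gl_unsharp_params.split(",")
us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]]
us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3]))
us_ksize = int(us_ksize)  # kernel size must be an integer

print(us_ksize, us_sigma, us_strength, us_target_x)  # 3 0.5 0.5 True
```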
@@ -2518,7 +2793,9 @@ def main(args):

b1 = BatchData(
False,
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
BatchDataBase(
global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt
),
BatchDataExt(
width,
height,
@@ -2559,12 +2836,19 @@ def main(args):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()

add_logging_arguments(parser)

parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト")
parser.add_argument(
"--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む"
"--from_file",
type=str,
default=None,
help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む",
)
parser.add_argument(
"--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)"
"--interactive",
action="store_true",
help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)",
)
parser.add_argument(
"--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない"
@@ -2576,7 +2860,9 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength")
parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数")
parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先")
parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする")
parser.add_argument(
"--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする"
)
parser.add_argument(
"--use_original_file_name",
action="store_true",
@@ -2587,10 +2873,16 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ")
parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅")
parser.add_argument(
"--original_height", type=int, default=None, help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値"
"--original_height",
type=int,
default=None,
help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値",
)
parser.add_argument(
"--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値"
"--original_width",
type=int,
default=None,
help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値",
)
parser.add_argument(
"--original_height_negative",
@@ -2604,8 +2896,12 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値",
)
parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値")
parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値")
parser.add_argument(
"--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値"
)
parser.add_argument(
"--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")
parser.add_argument(
"--vae_batch_size",
@@ -2619,7 +2915,9 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨",
)
parser.add_argument("--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない")
parser.add_argument(
"--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない"
)
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
parser.add_argument(
"--sampler",
@@ -2651,9 +2949,14 @@ def setup_parser() -> argparse.ArgumentParser:
default=7.5,
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale",
)
parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ")
parser.add_argument(
"--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ"
"--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ"
)
parser.add_argument(
"--vae",
type=str,
default=None,
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ",
)
parser.add_argument(
"--tokenizer_cache_dir",
@@ -2684,25 +2987,46 @@ def setup_parser() -> argparse.ArgumentParser:
help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)",
)
parser.add_argument(
"--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する"
"--opt_channels_last",
action="store_true",
help="set channels last option to model / モデルにchannels lastを指定し最適化する",
)
parser.add_argument(
"--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名"
"--network_module",
type=str,
default=None,
nargs="*",
help="additional network module to use / 追加ネットワークを使う時そのモジュール名",
)
parser.add_argument(
"--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み"
)
parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率")
parser.add_argument(
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
"--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率"
)
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument(
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
"--network_args",
type=str,
default=None,
nargs="*",
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
)
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument(
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
"--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する"
)
parser.add_argument(
"--network_merge_n_models",
type=int,
default=None,
help="merge this number of networks / この数だけネットワークをマージする",
)
parser.add_argument(
"--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする"
)
parser.add_argument(
"--network_pre_calc",
action="store_true",
help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する",
)
parser.add_argument(
"--network_regional_mask_max_color_codes",
@@ -2717,7 +3041,9 @@ def setup_parser() -> argparse.ArgumentParser:
nargs="*",
help="Embeddings files of Textual Inversion / Textual Inversionのembeddings",
)
parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う")
parser.add_argument(
"--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う"
)
parser.add_argument(
"--max_embeddings_multiples",
type=int,
@@ -2734,7 +3060,10 @@ def setup_parser() -> argparse.ArgumentParser:
help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする",
)
parser.add_argument(
"--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数"
"--highres_fix_steps",
type=int,
default=28,
help="1st stage steps for highres fix / highres fixの最初のステージのステップ数",
)
parser.add_argument(
"--highres_fix_strength",
@@ -2743,7 +3072,9 @@ def setup_parser() -> argparse.ArgumentParser:
help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ",
)
parser.add_argument(
"--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する"
"--highres_fix_save_1st",
action="store_true",
help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する",
)
parser.add_argument(
"--highres_fix_latents_upscaling",
@@ -2751,7 +3082,10 @@ def setup_parser() -> argparse.ArgumentParser:
help="use latents upscaling for highres fix / highres fixでlatentで拡大する",
)
parser.add_argument(
"--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名"
"--highres_fix_upscaler",
type=str,
default=None,
help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名",
)
parser.add_argument(
"--highres_fix_upscaler_args",
@@ -2766,11 +3100,18 @@ def setup_parser() -> argparse.ArgumentParser:
)

parser.add_argument(
"--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する"
"--negative_scale",
type=float,
default=None,
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する",
)

parser.add_argument(
"--control_net_lllite_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名"
"--control_net_lllite_models",
type=str,
default=None,
nargs="*",
help="ControlNet models to use / 使用するControlNetのモデル名",
)
# parser.add_argument(
#     "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名"
@@ -2819,6 +3160,45 @@ def setup_parser() -> argparse.ArgumentParser:
"--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率"
)

# gradual latent
parser.add_argument(
"--gradual_latent_timesteps",
type=int,
default=None,
help="enable Gradual Latent hires fix and apply upscaling from this timestep / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する",
)
parser.add_argument(
"--gradual_latent_ratio",
type=float,
default=0.5,
help="initial size ratio for Gradual Latent hires fix, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する",
)
parser.add_argument(
"--gradual_latent_ratio_step",
type=float,
default=0.125,
help="step to increase ratio for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか",
)
parser.add_argument(
"--gradual_latent_every_n_steps",
type=int,
default=3,
help="increase the size of latents every this many steps for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる",
)
parser.add_argument(
"--gradual_latent_s_noise",
type=float,
default=1.0,
help="s_noise for Gradual Latent / Gradual Latentのs_noise",
)
parser.add_argument(
"--gradual_latent_unsharp_params",
type=str,
default=None,
help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /"
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨",
)

# parser.add_argument(
#     "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
# )
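With the defaults above (ratio 0.5, ratio_step 0.125, every_n_steps 3), the latent starts at half size and grows back to full size in four increments. The sketch below only illustrates how the ratio would progress if it is raised by `ratio_step` every `every_n_steps` steps, which is what the option names and defaults suggest; the actual stepping is performed inside `GradualLatent` and the modified scheduler, not by this code:

```python
# illustrative ratio schedule for the default Gradual Latent settings (not code from this repository)
ratio, ratio_step, every_n_steps = 0.5, 0.125, 3

schedule = []
step = 0
while ratio < 1.0:
    schedule.append((step, ratio))
    step += every_n_steps
    ratio = min(1.0, ratio + ratio_step)
schedule.append((step, ratio))

print(schedule)  # [(0, 0.5), (3, 0.625), (6, 0.75), (9, 0.875), (12, 1.0)]
```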
@@ -2830,4 +3210,5 @@ if __name__ == "__main__":
parser = setup_parser()

args = parser.parse_args()
setup_logging(args, reset=True)
main(args)

@@ -182,8 +182,21 @@ def call_unet_and_control_net(
return original_unet(sample, timestep, encoder_hidden_states)

guided_hint = guided_hints[cnet_idx]

# gradual latent support: match the size of guided_hint to the size of sample
if guided_hint.shape[-2:] != sample.shape[-2:]:
# print(f"guided_hint.shape={guided_hint.shape}, sample.shape={sample.shape}")
org_dtype = guided_hint.dtype
if org_dtype == torch.bfloat16:
guided_hint = guided_hint.to(torch.float32)
guided_hint = torch.nn.functional.interpolate(guided_hint, size=sample.shape[-2:], mode="bicubic")
if org_dtype == torch.bfloat16:
guided_hint = guided_hint.to(org_dtype)

guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net)
outs = unet_forward(
True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net
)
outs = [o * cnet_info.weight for o in outs]

# U-Net
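The hunk above resizes the ControlNet hint to the (temporarily smaller) latent and casts bfloat16 tensors to float32 around the interpolation, presumably because bicubic interpolation has not been implemented for bfloat16 in some PyTorch builds. A standalone sketch of the same pattern, with arbitrary shapes; `resize_like` is a hypothetical helper, not a function from this repository:

```python
import torch
import torch.nn.functional as F

def resize_like(hint: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
    """Bicubic-resize `hint` to the spatial size of `sample`, round-tripping bfloat16 through float32."""
    if hint.shape[-2:] == sample.shape[-2:]:
        return hint
    org_dtype = hint.dtype
    if org_dtype == torch.bfloat16:
        hint = hint.to(torch.float32)
    hint = F.interpolate(hint, size=sample.shape[-2:], mode="bicubic")
    return hint.to(org_dtype) if org_dtype == torch.bfloat16 else hint

# example: full-size hint vs. a half-size latent as produced by Gradual Latent
hint = torch.randn(1, 3, 128, 128, dtype=torch.bfloat16)
sample = torch.randn(1, 4, 64, 64)
print(resize_like(hint, sample).shape)  # torch.Size([1, 3, 64, 64])
```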