support deep shrink with regional lora, add prompter module

This commit is contained in:
Kohya S
2024-02-12 14:17:27 +09:00
parent d3745db764
commit cbe9c5dc06
2 changed files with 161 additions and 45 deletions

View File

@@ -3,6 +3,8 @@ import json
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob import glob
import importlib import importlib
import importlib.util
import sys
import inspect import inspect
import time import time
import zipfile import zipfile
@@ -333,6 +335,10 @@ class PipelineLike:
self.scheduler = scheduler self.scheduler = scheduler
self.safety_checker = None self.safety_checker = None
self.clip_vision_model: CLIPVisionModelWithProjection = None
self.clip_vision_processor: CLIPImageProcessor = None
self.clip_vision_strength = 0.0
# Textual Inversion # Textual Inversion
self.token_replacements_list = [] self.token_replacements_list = []
for _ in range(len(self.text_encoders)): for _ in range(len(self.text_encoders)):
@@ -419,6 +425,7 @@ class PipelineLike:
callback_steps: Optional[int] = 1, callback_steps: Optional[int] = 1,
img2img_noise=None, img2img_noise=None,
clip_guide_images=None, clip_guide_images=None,
emb_normalize_mode: str = "original",
**kwargs, **kwargs,
): ):
# TODO support secondary prompt # TODO support secondary prompt
@@ -493,6 +500,7 @@ class PipelineLike:
clip_skip=self.clip_skip, clip_skip=self.clip_skip,
token_replacer=token_replacer, token_replacer=token_replacer,
device=self.device, device=self.device,
emb_normalize_mode=emb_normalize_mode,
**kwargs, **kwargs,
) )
tes_text_embs.append(text_embeddings) tes_text_embs.append(text_embeddings)
@@ -508,6 +516,7 @@ class PipelineLike:
clip_skip=self.clip_skip, clip_skip=self.clip_skip,
token_replacer=token_replacer, token_replacer=token_replacer,
device=self.device, device=self.device,
emb_normalize_mode=emb_normalize_mode,
**kwargs, **kwargs,
) )
tes_real_uncond_embs.append(real_uncond_embeddings) tes_real_uncond_embs.append(real_uncond_embeddings)
@@ -1099,7 +1108,7 @@ def get_unweighted_text_embeddings(
# in sdxl, value of clip_skip is same for Text Encoder 1 and 2 # in sdxl, value of clip_skip is same for Text Encoder 1 and 2
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
text_embedding = enc_out["hidden_states"][-clip_skip] text_embedding = enc_out["hidden_states"][-clip_skip]
if not is_sdxl: # SD 1.5 requires final_layer_norm if not is_sdxl: # SD 1.5 requires final_layer_norm
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding) text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
if pool is None: if pool is None:
pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided
@@ -1122,7 +1131,7 @@ def get_unweighted_text_embeddings(
else: else:
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
text_embeddings = enc_out["hidden_states"][-clip_skip] text_embeddings = enc_out["hidden_states"][-clip_skip]
if not is_sdxl: # SD 1.5 requires final_layer_norm if not is_sdxl: # SD 1.5 requires final_layer_norm
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings) text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this
if pool is not None: if pool is not None:
@@ -1143,6 +1152,7 @@ def get_weighted_text_embeddings(
clip_skip: int = 1, clip_skip: int = 1,
token_replacer=None, token_replacer=None,
device=None, device=None,
emb_normalize_mode: Optional[str] = "original", # "original", "abs", "none"
**kwargs, **kwargs,
): ):
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
@@ -1239,16 +1249,34 @@ def get_weighted_text_embeddings(
# assign weights to the prompts and normalize in the sense of mean # assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)? # TODO: should we normalize by chunk or in a whole (current implementation)?
# →全体でいいんじゃないかな # →全体でいいんじゃないかな
if (not skip_parsing) and (not skip_weighting): if (not skip_parsing) and (not skip_weighting):
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) if emb_normalize_mode == "abs":
text_embeddings *= prompt_weights.unsqueeze(-1) previous_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype)
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) text_embeddings *= prompt_weights.unsqueeze(-1)
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) current_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype)
if uncond_prompt is not None: text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) if uncond_prompt is not None:
uncond_embeddings *= uncond_weights.unsqueeze(-1) previous_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) uncond_embeddings *= uncond_weights.unsqueeze(-1)
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) current_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
elif emb_normalize_mode == "none":
text_embeddings *= prompt_weights.unsqueeze(-1)
if uncond_prompt is not None:
uncond_embeddings *= uncond_weights.unsqueeze(-1)
else: # "original"
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None: if uncond_prompt is not None:
return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens
@@ -1427,6 +1455,27 @@ class BatchData(NamedTuple):
ext: BatchDataExt ext: BatchDataExt
class ListPrompter:
    """Serve prompts from a fixed list, one per call.

    Calling the instance yields the next prompt in order. After the last
    prompt it returns None once and rewinds to the beginning, letting the
    caller detect the end of a pass and optionally iterate again.
    """

    def __init__(self, prompts: List[str]):
        self.prompts = prompts
        self.index = 0

    def shuffle(self):
        """Shuffle the prompt order in place (uses the global random state)."""
        random.shuffle(self.prompts)

    def __len__(self):
        return len(self.prompts)

    def __call__(self, *args, **kwargs):
        # Extra arguments are accepted to match the prompter-module call
        # interface, but this implementation ignores them.
        exhausted = self.index >= len(self.prompts)
        if exhausted:
            self.index = 0  # rewind so a subsequent pass starts over
            return None
        chosen = self.prompts[self.index]
        self.index += 1
        return chosen
def main(args): def main(args):
if args.fp16: if args.fp16:
dtype = torch.float16 dtype = torch.float16
@@ -1951,15 +2000,35 @@ def main(args):
token_embeds2[token_id] = embed token_embeds2[token_id] = embed
# promptを取得する # promptを取得する
prompt_list = None
if args.from_file is not None: if args.from_file is not None:
print(f"reading prompts from {args.from_file}") print(f"reading prompts from {args.from_file}")
with open(args.from_file, "r", encoding="utf-8") as f: with open(args.from_file, "r", encoding="utf-8") as f:
prompt_list = f.read().splitlines() prompt_list = f.read().splitlines()
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
prompter = ListPrompter(prompt_list)
elif args.from_module is not None:
def load_module_from_path(module_name, file_path):
    """Import a Python source file from an arbitrary filesystem path.

    The loaded module is registered in sys.modules under module_name and
    returned. Raises ImportError when no import spec can be constructed
    for the given path.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"Module '{module_name}' cannot be loaded from '{file_path}'")
    loaded = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = loaded
    spec.loader.exec_module(loaded)
    return loaded
print(f"reading prompts from module: {args.from_module}")
prompt_module = load_module_from_path("prompt_module", args.from_module)
prompter = prompt_module.get_prompter(args, pipe, networks)
elif args.prompt is not None: elif args.prompt is not None:
prompt_list = [args.prompt] prompter = ListPrompter([args.prompt])
else: else:
prompt_list = [] prompter = None # interactive mode
if args.interactive: if args.interactive:
args.n_iter = 1 args.n_iter = 1
@@ -2026,14 +2095,16 @@ def main(args):
mask_images = None mask_images = None
# promptがないとき、画像のPngInfoから取得する # promptがないとき、画像のPngInfoから取得する
if init_images is not None and len(prompt_list) == 0 and not args.interactive: if init_images is not None and prompter is None and not args.interactive:
print("get prompts from images' metadata") print("get prompts from images' metadata")
prompt_list = []
for img in init_images: for img in init_images:
if "prompt" in img.text: if "prompt" in img.text:
prompt = img.text["prompt"] prompt = img.text["prompt"]
if "negative-prompt" in img.text: if "negative-prompt" in img.text:
prompt += " --n " + img.text["negative-prompt"] prompt += " --n " + img.text["negative-prompt"]
prompt_list.append(prompt) prompt_list.append(prompt)
prompter = ListPrompter(prompt_list)
# プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する)
l = [] l = []
@@ -2105,15 +2176,18 @@ def main(args):
else: else:
guide_images = None guide_images = None
# seed指定時はseedを決めておく # 新しい乱数生成器を作成する
if args.seed is not None: if args.seed is not None:
# dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう if prompt_list and len(prompt_list) == 1 and args.images_per_prompt == 1:
random.seed(args.seed) # 引数のseedをそのまま使う
predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] def fixed_seed(*args, **kwargs):
if len(predefined_seeds) == 1: return args.seed
predefined_seeds[0] = args.seed
seed_random = SimpleNamespace(randint=fixed_seed)
else:
seed_random = random.Random(args.seed)
else: else:
predefined_seeds = None seed_random = random.Random()
# デフォルト画像サイズを設定するimg2imgではこれらの値は無視されるまたはW*Hにリサイズ済み # デフォルト画像サイズを設定するimg2imgではこれらの値は無視されるまたはW*Hにリサイズ済み
if args.W is None: if args.W is None:
@@ -2127,11 +2201,14 @@ def main(args):
for gen_iter in range(args.n_iter): for gen_iter in range(args.n_iter):
print(f"iteration {gen_iter+1}/{args.n_iter}") print(f"iteration {gen_iter+1}/{args.n_iter}")
iter_seed = random.randint(0, 0x7FFFFFFF) if args.iter_same_seed:
iter_seed = seed_random.randint(0, 2**32 - 1)
else:
iter_seed = None
# shuffle prompt list # shuffle prompt list
if args.shuffle_prompts: if args.shuffle_prompts:
random.shuffle(prompt_list) prompter.shuffle()
# バッチ処理の関数 # バッチ処理の関数
def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
@@ -2352,7 +2429,8 @@ def main(args):
for n, m in zip(networks, network_muls if network_muls else network_default_muls): for n, m in zip(networks, network_muls if network_muls else network_default_muls):
n.set_multiplier(m) n.set_multiplier(m)
if regional_network: if regional_network:
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared) # TODO バッチから ds_ratio を取り出すべき
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared, unet.ds_ratio)
if not regional_network and network_pre_calc: if not regional_network and network_pre_calc:
for n in networks: for n in networks:
@@ -2386,6 +2464,7 @@ def main(args):
return_latents=return_latents, return_latents=return_latents,
clip_prompts=clip_prompts, clip_prompts=clip_prompts,
clip_guide_images=guide_images, clip_guide_images=guide_images,
emb_normalize_mode=args.emb_normalize_mode,
) )
if highres_1st and not args.highres_fix_save_1st: # return images or latents if highres_1st and not args.highres_fix_save_1st: # return images or latents
return images return images
@@ -2451,8 +2530,8 @@ def main(args):
prompt_index = 0 prompt_index = 0
global_step = 0 global_step = 0
batch_data = [] batch_data = []
while args.interactive or prompt_index < len(prompt_list): while True:
if len(prompt_list) == 0: if args.interactive:
# interactive # interactive
valid = False valid = False
while not valid: while not valid:
@@ -2466,7 +2545,9 @@ def main(args):
if not valid: # EOF, end app if not valid: # EOF, end app
break break
else: else:
raw_prompt = prompt_list[prompt_index] raw_prompt = prompter(args, pipe, seed_random, iter_seed, prompt_index, global_step)
if raw_prompt is None:
break
# sd-dynamic-prompts like variants: # sd-dynamic-prompts like variants:
# count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
@@ -2513,7 +2594,8 @@ def main(args):
prompt_args = raw_prompt.strip().split(" --") prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0] prompt = prompt_args[0]
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") length = len(prompter) if hasattr(prompter, "__len__") else 0
print(f"prompt {prompt_index+1}/{length}: {prompt}")
for parg in prompt_args[1:]: for parg in prompt_args[1:]:
try: try:
@@ -2731,23 +2813,17 @@ def main(args):
# prepare seed # prepare seed
if seeds is not None: # given in prompt if seeds is not None: # given in prompt
# 数が足りないなら前のをそのまま使う # num_images_per_promptが多い場合は足りなくなるので、足りない分は前のを使う
if len(seeds) > 0: if len(seeds) > 0:
seed = seeds.pop(0) seed = seeds.pop(0)
else: else:
if predefined_seeds is not None: if args.iter_same_seed:
if len(predefined_seeds) > 0: seed = iter_seed
seed = predefined_seeds.pop(0)
else:
print("predefined seeds are exhausted")
seed = None
elif args.iter_same_seed:
seeds = iter_seed
else: else:
seed = None # 前のを消す seed = None # 前のを消す
if seed is None: if seed is None:
seed = random.randint(0, 0x7FFFFFFF) seed = seed_random.randint(0, 2**32 - 1)
if args.interactive: if args.interactive:
print(f"seed: {seed}") print(f"seed: {seed}")
@@ -2853,6 +2929,15 @@ def setup_parser() -> argparse.ArgumentParser:
default=None, default=None,
help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む",
) )
parser.add_argument(
"--from_module",
type=str,
default=None,
help="if specified, load prompts from this module / 指定時はプロンプトをモジュールから読み込む",
)
parser.add_argument(
"--prompter_module_args", type=str, default=None, help="args for prompter module / prompterモジュールの引数"
)
parser.add_argument( parser.add_argument(
"--interactive", "--interactive",
action="store_true", action="store_true",
@@ -3067,6 +3152,13 @@ def setup_parser() -> argparse.ArgumentParser:
default=None, default=None,
help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
) )
parser.add_argument(
"--emb_normalize_mode",
type=str,
default="original",
choices=["original", "none", "abs"],
help="embedding normalization mode / embeddingの正規化モード",
)
parser.add_argument( parser.add_argument(
"--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像" "--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像"
) )

View File

@@ -12,8 +12,10 @@ import numpy as np
import torch import torch
import re import re
from library.utils import setup_logging from library.utils import setup_logging
setup_logging() setup_logging()
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -248,7 +250,8 @@ class LoRAInfModule(LoRAModule):
if mask is None: if mask is None:
# raise ValueError(f"mask is None for resolution {area}") # raise ValueError(f"mask is None for resolution {area}")
# emb_layers in SDXL doesn't have mask # emb_layers in SDXL doesn't have mask
# logger.info(f"mask is None for resolution {area}, {x.size()}") # if "emb" not in self.lora_name:
# print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1) mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
if len(x.size()) != 4: if len(x.size()) != 4:
@@ -265,7 +268,9 @@ class LoRAInfModule(LoRAModule):
# apply mask for LoRA result # apply mask for LoRA result
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
mask = self.get_mask_for_x(lx) mask = self.get_mask_for_x(lx)
# logger.info(f"regional {self.lora_name} {self.network.sub_prompt_index} {lx.size()} {mask.size()}") # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
# if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked)
# mask = mask.squeeze(-1)
lx = lx * mask lx = lx * mask
x = self.org_forward(x) x = self.org_forward(x)
@@ -514,7 +519,9 @@ def get_block_dims_and_alphas(
len(block_dims) == num_total_blocks len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
else: else:
logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") logger.warning(
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
)
block_dims = [network_dim] * num_total_blocks block_dims = [network_dim] * num_total_blocks
if block_alphas is not None: if block_alphas is not None:
@@ -792,7 +799,9 @@ class LoRANetwork(torch.nn.Module):
logger.info(f"create LoRA network from weights") logger.info(f"create LoRA network from weights")
elif block_dims is not None: elif block_dims is not None:
logger.info(f"create LoRA network from block_dims") logger.info(f"create LoRA network from block_dims")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
logger.info(f"block_dims: {block_dims}") logger.info(f"block_dims: {block_dims}")
logger.info(f"block_alphas: {block_alphas}") logger.info(f"block_alphas: {block_alphas}")
if conv_block_dims is not None: if conv_block_dims is not None:
@@ -800,9 +809,13 @@ class LoRANetwork(torch.nn.Module):
logger.info(f"conv_block_alphas: {conv_block_alphas}") logger.info(f"conv_block_alphas: {conv_block_alphas}")
else: else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
if self.conv_lora_dim is not None: if self.conv_lora_dim is not None:
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") logger.info(
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
)
# create module instances # create module instances
def create_modules( def create_modules(
@@ -929,6 +942,10 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras: for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier lora.multiplier = self.multiplier
def set_enabled(self, is_enabled):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def load_weights(self, file): def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors": if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file from safetensors.torch import load_file
@@ -1116,7 +1133,7 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras: for lora in self.text_encoder_loras + self.unet_loras:
lora.set_network(self) lora.set_network(self)
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared): def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None):
self.batch_size = batch_size self.batch_size = batch_size
self.num_sub_prompts = num_sub_prompts self.num_sub_prompts = num_sub_prompts
self.current_size = (height, width) self.current_size = (height, width)
@@ -1142,6 +1159,13 @@ class LoRANetwork(torch.nn.Module):
resize_add(h, w) resize_add(h, w)
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2 if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
resize_add(h + h % 2, w + w % 2) resize_add(h + h % 2, w + w % 2)
# deep shrink
if ds_ratio is not None:
hd = int(h * ds_ratio)
wd = int(w * ds_ratio)
resize_add(hd, wd)
h = (h + 1) // 2 h = (h + 1) // 2
w = (w + 1) // 2 w = (w + 1) // 2