support for controlnet in sample output

This commit is contained in:
ddPn08
2023-06-01 09:47:37 +09:00
parent 62d00b4520
commit 3bd00b88c2
4 changed files with 159 additions and 28 deletions

View File

@@ -1674,7 +1674,6 @@ class ControlNetDataset(BaseDataset):
cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size)
cond_img = self.conditioning_image_transforms(cond_img)
conditioning_images.append(cond_img)
conditioning_images = torch.stack(conditioning_images)
example = {}
example["loss_weights"] = torch.FloatTensor(loss_weights)
@@ -1699,7 +1698,7 @@ class ControlNetDataset(BaseDataset):
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
example["conditioning_images"] = conditioning_images.to(memory_format=torch.contiguous_format).float()
example["conditioning_images"] = torch.stack(conditioning_images).to(memory_format=torch.contiguous_format).float()
return example
@@ -3138,13 +3137,13 @@ def prepare_dtype(args: argparse.Namespace):
return weight_dtype, save_dtype
def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", unet_use_linear_projection_in_v2=False):
name_or_path = args.pretrained_model_name_or_path
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
if load_stable_diffusion_format:
print(f"load StableDiffusion checkpoint: {name_or_path}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device)
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2)
else:
# Diffusers model is loaded to CPU
print(f"load Diffusers pretrained models: {name_or_path}")
@@ -3172,14 +3171,14 @@ def transform_if_model_is_DDP(text_encoder, unet, network=None):
return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)
def load_target_model(args, weight_dtype, accelerator):
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
# load models for each process
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
args, weight_dtype, accelerator.device if args.lowram else "cpu"
args, weight_dtype, accelerator.device if args.lowram else "cpu", unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2
)
# work on low-ram device
@@ -3493,7 +3492,7 @@ SCHEDLER_SCHEDULE = "scaled_linear"
def sample_images(
accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None
accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None, controlnet=None
):
"""
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
@@ -3609,6 +3608,7 @@ def sample_images(
height = prompt.get("height", 512)
scale = prompt.get("scale", 7.5)
seed = prompt.get("seed")
controlnet_image = prompt.get("controlnet_image")
prompt = prompt.get("prompt")
else:
# prompt = prompt.strip()
@@ -3623,6 +3623,7 @@ def sample_images(
width = height = 512
scale = 7.5
seed = None
controlnet_image = None
for parg in prompt_args:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
@@ -3655,6 +3656,12 @@ def sample_images(
negative_prompt = m.group(1)
continue
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
if m: # negative prompt
controlnet_image = m.group(1)
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
@@ -3668,6 +3675,10 @@ def sample_images(
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if controlnet_image is not None:
controlnet_image = Image.open(controlnet_image).convert("RGB")
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
print(f"prompt: {prompt}")
@@ -3683,6 +3694,8 @@ def sample_images(
num_inference_steps=sample_steps,
guidance_scale=scale,
negative_prompt=negative_prompt,
controlnet=controlnet,
controlnet_image=controlnet_image,
).images[0]
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())