diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py
index 84e1ab15..88317e30 100644
--- a/library/lpw_stable_diffusion.py
+++ b/library/lpw_stable_diffusion.py
@@ -6,7 +6,7 @@ import re
 from typing import Callable, List, Optional, Union
 
 import numpy as np
-import PIL
+import PIL.Image
 import torch
 from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -426,6 +426,59 @@ def preprocess_mask(mask, scale_factor=8):
     return mask
 
 
+def prepare_controlnet_image(
+    image: PIL.Image.Image,
+    width: int,
+    height: int,
+    batch_size: int,
+    num_images_per_prompt: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    do_classifier_free_guidance: bool = False,
+    guess_mode: bool = False,
+):
+    if not isinstance(image, torch.Tensor):
+        if isinstance(image, PIL.Image.Image):
+            image = [image]
+
+        if isinstance(image[0], PIL.Image.Image):
+            images = []
+
+            for image_ in image:
+                image_ = image_.convert("RGB")
+                image_ = image_.resize(
+                    (width, height), resample=PIL_INTERPOLATION["lanczos"]
+                )
+                image_ = np.array(image_)
+                image_ = image_[None, :]
+                images.append(image_)
+
+            image = images
+
+            image = np.concatenate(image, axis=0)
+            image = np.array(image).astype(np.float32) / 255.0
+            image = image.transpose(0, 3, 1, 2)
+            image = torch.from_numpy(image)
+        elif isinstance(image[0], torch.Tensor):
+            image = torch.cat(image, dim=0)
+
+    image_batch_size = image.shape[0]
+
+    if image_batch_size == 1:
+        repeat_by = batch_size
+    else:
+        # image batch size is the same as prompt batch size
+        repeat_by = num_images_per_prompt
+
+    image = image.repeat_interleave(repeat_by, dim=0)
+
+    image = image.to(device=device, dtype=dtype)
+
+    if do_classifier_free_guidance and not guess_mode:
+        image = torch.cat([image] * 2)
+
+    return image
+
 class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
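
Note: the helper above mirrors the conditioning-image preprocessing of diffusers' ControlNet pipeline: PIL input is resized to the target resolution, scaled to [0, 1] (unlike VAE inputs, which use [-1, 1]), repeated to match the batch, and duplicated when classifier-free guidance is active. A minimal shape check, assuming the helper is imported from library.lpw_stable_diffusion:

    import PIL.Image
    import torch

    from library.lpw_stable_diffusion import prepare_controlnet_image

    cond = PIL.Image.new("RGB", (768, 768))  # any size; resized to width x height
    out = prepare_controlnet_image(
        cond, width=512, height=512, batch_size=1, num_images_per_prompt=1,
        device=torch.device("cpu"), dtype=torch.float32,
        do_classifier_free_guidance=True, guess_mode=False,
    )
    print(out.shape)  # torch.Size([2, 3, 512, 512]): batch doubled for CFG
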
""" + if controlnet is not None and controlnet_image is None: + raise ValueError("controlnet_image must be provided if controlnet is not None.") + # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor @@ -824,6 +887,10 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): else: mask = None + if controlnet_image is not None: + controlnet_image = prepare_controlnet_image(controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False) + + # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None) @@ -851,8 +918,22 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + unet_additional_args = {} + if controlnet is not None: + down_block_res_samples, mid_block_res_sample = controlnet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + controlnet_cond=controlnet_image, + conditioning_scale=1.0, + guess_mode=False, + return_dict=False, + ) + unet_additional_args['down_block_additional_residuals'] = down_block_res_samples + unet_additional_args['mid_block_additional_residual'] = mid_block_res_sample + # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample # perform guidance if do_classifier_free_guidance: diff --git a/library/model_util.py b/library/model_util.py index bb168653..0764a881 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -731,8 +731,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict): return new_state_dict - -def convert_controlnet_state_dict_to_sd(controlnet_state_dict): +def controlnet_conversion_map(): unet_conversion_map = [ ("time_embed.0.weight", "time_embedding.linear_1.weight"), ("time_embed.0.bias", "time_embedding.linear_1.bias"), @@ -792,6 +791,12 @@ def convert_controlnet_state_dict_to_sd(controlnet_state_dict): sd_prefix = f"zero_convs.{i}.0." 
diff --git a/library/model_util.py b/library/model_util.py
index bb168653..0764a881 100644
--- a/library/model_util.py
+++ b/library/model_util.py
@@ -731,8 +731,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
 
     return new_state_dict
 
-
-def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
+def controlnet_conversion_map():
     unet_conversion_map = [
         ("time_embed.0.weight", "time_embedding.linear_1.weight"),
         ("time_embed.0.bias", "time_embedding.linear_1.bias"),
@@ -792,6 +791,12 @@ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
             sd_prefix = f"zero_convs.{i}.0."
             unet_conversion_map_layer.append((sd_prefix, hf_prefix))
 
+    return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
+
+
+def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
+    unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
+
     mapping = {k: k for k in controlnet_state_dict.keys()}
     for sd_name, diffusers_name in unet_conversion_map:
         mapping[diffusers_name] = sd_name
@@ -807,6 +812,23 @@ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
     new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
     return new_state_dict
 
+def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
+    unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
+
+    mapping = {k: k for k in controlnet_state_dict.keys()}
+    for sd_name, diffusers_name in unet_conversion_map:
+        mapping[sd_name] = diffusers_name
+    for k, v in mapping.items():
+        for sd_part, diffusers_part in unet_conversion_map_layer:
+            v = v.replace(sd_part, diffusers_part)
+        mapping[k] = v
+    for k, v in mapping.items():
+        if "resnets" in v:
+            for sd_part, diffusers_part in unet_conversion_map_resnet:
+                v = v.replace(sd_part, diffusers_part)
+            mapping[k] = v
+    new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
+    return new_state_dict
 
 # ================#
 # VAE Conversion #
@@ -928,7 +950,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
 
 
 # TODO the behavior with an explicit dtype looks unreliable and needs checking; creating the text_encoder in the specified dtype is unverified
-def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False):
+def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
     _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
 
     # Convert the UNet2DConditionModel model.
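
Note: factoring the key tables out into controlnet_conversion_map() means the new SD-to-Diffusers direction reuses exactly the maps the SD-saving path already used, so the two conversions should invert each other. A hedged round-trip check (the path is a placeholder for a ControlNet checkpoint saved in SD layout by this repo):

    from safetensors.torch import load_file

    import library.model_util as model_util

    sd_state_dict = load_file("my_controlnet.safetensors")  # placeholder path
    diffusers_state_dict = model_util.convert_controlnet_state_dict_to_diffusers(sd_state_dict)
    roundtrip = model_util.convert_controlnet_state_dict_to_sd(diffusers_state_dict)
    # should hold as long as every key is covered by the conversion maps
    assert set(roundtrip.keys()) == set(sd_state_dict.keys())
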
diff --git a/library/train_util.py b/library/train_util.py
index 1921c2a4..81dffb1d 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1674,7 +1674,6 @@ class ControlNetDataset(BaseDataset):
             cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size)
             cond_img = self.conditioning_image_transforms(cond_img)
             conditioning_images.append(cond_img)
-        conditioning_images = torch.stack(conditioning_images)
 
         example = {}
         example["loss_weights"] = torch.FloatTensor(loss_weights)
@@ -1699,7 +1698,7 @@ class ControlNetDataset(BaseDataset):
 
         if self.debug_dataset:
             example["image_keys"] = bucket[image_index : image_index + self.batch_size]
-        example["conditioning_images"] = conditioning_images.to(memory_format=torch.contiguous_format).float()
+        example["conditioning_images"] = torch.stack(conditioning_images).to(memory_format=torch.contiguous_format).float()
 
         return example
 
@@ -3138,13 +3137,13 @@ def prepare_dtype(args: argparse.Namespace):
     return weight_dtype, save_dtype
 
 
-def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
+def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", unet_use_linear_projection_in_v2=False):
     name_or_path = args.pretrained_model_name_or_path
     name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
     load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
     if load_stable_diffusion_format:
         print(f"load StableDiffusion checkpoint: {name_or_path}")
-        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device)
+        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2)
     else:
         # Diffusers model is loaded to CPU
         print(f"load Diffusers pretrained models: {name_or_path}")
@@ -3172,14 +3171,14 @@ def transform_if_model_is_DDP(text_encoder, unet, network=None):
     return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)
 
 
-def load_target_model(args, weight_dtype, accelerator):
+def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
     # load models for each process
     for pi in range(accelerator.state.num_processes):
         if pi == accelerator.state.local_process_index:
             print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
 
             text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
-                args, weight_dtype, accelerator.device if args.lowram else "cpu"
+                args, weight_dtype, accelerator.device if args.lowram else "cpu", unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2
             )
 
             # work on low-ram device
@@ -3493,7 +3492,7 @@ SCHEDLER_SCHEDULE = "scaled_linear"
 
 
 def sample_images(
-    accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None
+    accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None, controlnet=None
 ):
     """
     Uses a modified StableDiffusionLongPromptWeightingPipeline, so clip skip and prompt weighting are supported
@@ -3609,6 +3608,7 @@ def sample_images(
                 height = prompt.get("height", 512)
                 scale = prompt.get("scale", 7.5)
                 seed = prompt.get("seed")
+                controlnet_image = prompt.get("controlnet_image")
                 prompt = prompt.get("prompt")
             else:
                 # prompt = prompt.strip()
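
Note: when the prompt file parses into dicts (the isinstance(prompt, dict) branch above), the new controlnet_image key supplies the conditioning image path per prompt. A hypothetical entry once parsed, using only keys visible in the prompt.get() calls above (values and the image path are placeholders):

    sample_prompt = {
        "prompt": "a castle on a hill",
        "height": 512,
        "scale": 7.5,
        "seed": 1,
        "controlnet_image": "cond/castle_canny.png",  # opened and resized further below
    }
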
@@ -3623,6 +3623,7 @@ def sample_images(
                 width = height = 512
                 scale = 7.5
                 seed = None
+                controlnet_image = None
                 for parg in prompt_args:
                     try:
                         m = re.match(r"w (\d+)", parg, re.IGNORECASE)
@@ -3655,6 +3656,12 @@ def sample_images(
                             negative_prompt = m.group(1)
                             continue
 
+                        m = re.match(r"cn (.+)", parg, re.IGNORECASE)
+                        if m:  # controlnet image
+                            controlnet_image = m.group(1)
+                            continue
+
+
                     except ValueError as ex:
                         print(f"Exception in parsing / 解析エラー: {parg}")
                         print(ex)
@@ -3668,6 +3675,10 @@ def sample_images(
                 if negative_prompt is not None:
                     negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
 
+            if controlnet_image is not None:
+                controlnet_image = Image.open(controlnet_image).convert("RGB")
+                controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
+
             height = max(64, height - height % 8)  # round to divisible by 8
             width = max(64, width - width % 8)  # round to divisible by 8
             print(f"prompt: {prompt}")
@@ -3683,6 +3694,8 @@ def sample_images(
                     num_inference_steps=sample_steps,
                     guidance_scale=scale,
                     negative_prompt=negative_prompt,
+                    controlnet=controlnet,
+                    controlnet_image=controlnet_image,
                 ).images[0]
 
             ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
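
Note: in plain-text prompt files the same image is selected with the new cn option, matched by the regex above after the line is split into " --"-prefixed arguments. A hypothetical line (only the w, n, and cn options appear in these hunks; the image path is a placeholder):

    # hypothetical line in a plain-text --sample_prompts file:
    #   a castle on a hill --w 512 --n low quality --cn cond/castle_canny.png
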
diff --git a/train_controlnet.py b/train_controlnet.py
index 7bcaf03a..263e8813 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -1,5 +1,6 @@
 import argparse
 import gc
+import json
 import math
 import os
 import random
@@ -11,6 +12,7 @@ import torch
 from torch.nn.parallel import DistributedDataParallel as DDP
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler, ControlNetModel
+from safetensors.torch import load_file
 
 import library.model_util as model_util
 import library.train_util as train_util
@@ -26,9 +28,6 @@ from library.custom_train_functions import (
     pyramid_noise_like,
     apply_noise_offset,
 )
-from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
-    download_controlnet_from_original_ckpt,
-)
 
 
 # TODO unify this with the other scripts
@@ -124,19 +123,24 @@ def train(args):
     # load the models
     text_encoder, vae, unet, _ = train_util.load_target_model(
-        args, weight_dtype, accelerator
+        args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True
     )
+
+    controlnet = ControlNetModel.from_unet(unet)
+
     if args.controlnet_model_name_or_path:
-        if os.path.isfile(args.controlnet_model_name_or_path):
-            controlnet = download_controlnet_from_original_ckpt(
-                args.controlnet_model_name_or_path
-            )
-        else:
-            controlnet = ControlNetModel.from_pretrained(
-                args.controlnet_model_name_or_path
-            )
-    else:
-        controlnet = ControlNetModel.from_unet(unet)
+        filename = args.controlnet_model_name_or_path
+        if os.path.isfile(filename):
+            if os.path.splitext(filename)[1] == ".safetensors":
+                state_dict = load_file(filename)
+            else:
+                state_dict = torch.load(filename)
+            state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict)
+            controlnet.load_state_dict(state_dict)
+        elif os.path.isdir(filename):
+            controlnet = ControlNetModel.from_pretrained(filename)
+
+
     # incorporate xformers or memory efficient attention into the models
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -289,7 +293,9 @@ def train(args):
         )
     if accelerator.is_main_process:
         accelerator.init_trackers(
-            "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name
+            "controlnet_train"
+            if args.log_tracker_name is None
+            else args.log_tracker_name
         )
 
     loss_list = []
@@ -350,7 +356,7 @@ def train(args):
                 b_size = latents.shape[0]
 
                 input_ids = batch["input_ids"].to(accelerator.device)
-                encoder_hidden_states = train_util.gethidden_states(
+                encoder_hidden_states = train_util.get_hidden_states(
                     args, input_ids, tokenizer, text_encoder, weight_dtype
                 )
 
@@ -450,6 +456,7 @@ def train(args):
                         tokenizer,
                         text_encoder,
                         unet,
+                        controlnet=controlnet,
                     )
 
                     # save the model at each specified step count
@@ -537,6 +544,7 @@ def train(args):
                 tokenizer,
                 text_encoder,
                 unet,
+                controlnet=controlnet,
             )
 
         # end of epoch
@@ -569,6 +577,13 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
 
+    parser.add_argument(
+        "--save_model_as",
+        type=str,
+        default="safetensors",
+        choices=[None, "ckpt", "pt", "safetensors"],
+        help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
+    )
     parser.add_argument(
         "--controlnet_model_name_or_path",
         type=str,
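
Note: the new --save_model_as choices pair naturally with convert_controlnet_state_dict_to_sd for single-file output, but the save call itself is outside these hunks. The helper below is therefore a hypothetical sketch of that direction, not the script's actual save path:

    from safetensors.torch import save_file

    import library.model_util as model_util

    def save_controlnet_as_sd(controlnet, path):  # hypothetical helper, not part of this diff
        state_dict = model_util.convert_controlnet_state_dict_to_sd(controlnet.state_dict())
        # safetensors requires contiguous CPU tensors
        state_dict = {k: v.contiguous().to("cpu") for k, v in state_dict.items()}
        save_file(state_dict, path)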