mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support for controlnet in sample output
This commit is contained in:
@@ -6,7 +6,7 @@ import re
|
|||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL
|
import PIL.Image
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
@@ -426,6 +426,59 @@ def preprocess_mask(mask, scale_factor=8):
|
|||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_controlnet_image(
    image: Union["PIL.Image.Image", torch.Tensor, List["PIL.Image.Image"], List[torch.Tensor]],
    width: int,
    height: int,
    batch_size: int,
    num_images_per_prompt: int,
    device: torch.device,
    dtype: torch.dtype,
    do_classifier_free_guidance: bool = False,
    guess_mode: bool = False,
) -> torch.Tensor:
    """Convert a ControlNet conditioning image into a batched tensor.

    Args:
        image: A single PIL image, a sequence of PIL images, a sequence of
            tensors (concatenated along dim 0), or an already batched tensor.
            PIL images are converted to RGB, resized to ``(width, height)``
            with Lanczos resampling, scaled to ``[0, 1]`` and laid out as
            ``(N, 3, height, width)``. Tensors are used as given (no resize).
        width: Target width used when resizing PIL images.
        height: Target height used when resizing PIL images.
        batch_size: Repeat count applied when the image batch has size 1.
        num_images_per_prompt: Repeat count applied when the image batch has
            more than one entry.
        device: Device the resulting tensor is moved to.
        dtype: Dtype the resulting tensor is cast to.
        do_classifier_free_guidance: When true (and ``guess_mode`` is false)
            the batch is duplicated so it lines up with the doubled latent
            batch used for classifier-free guidance.
        guess_mode: When true, the classifier-free-guidance duplication is
            skipped.

    Returns:
        The conditioning image batch as a tensor on ``device`` with ``dtype``.

    Raises:
        TypeError: If ``image`` is neither a tensor, a PIL image, nor a
            sequence whose first element is a PIL image or a tensor.
    """
    if not isinstance(image, torch.Tensor):
        if isinstance(image, PIL.Image.Image):
            image = [image]

        if isinstance(image[0], PIL.Image.Image):
            arrays = [
                np.array(
                    pil_image.convert("RGB").resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
                )
                for pil_image in image
            ]
            # (N, H, W, C) uint8 -> float32 in [0, 1]; stacking adds the batch axis.
            stacked = np.stack(arrays, axis=0).astype(np.float32) / 255.0
            # Channels-last -> channels-first, as expected by the model.
            image = torch.from_numpy(stacked.transpose(0, 3, 1, 2))
        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, dim=0)
        else:
            # Fail early with a clear message instead of an AttributeError
            # further down when `.shape` / `.repeat_interleave` are missing.
            raise TypeError(
                "image must be a PIL.Image.Image, a torch.Tensor, or a sequence of either, "
                f"but got a sequence of {type(image[0])}"
            )

    image_batch_size = image.shape[0]

    if image_batch_size == 1:
        repeat_by = batch_size
    else:
        # image batch size is the same as prompt batch size
        repeat_by = num_images_per_prompt

    image = image.repeat_interleave(repeat_by, dim=0)
    image = image.to(device=device, dtype=dtype)

    if do_classifier_free_guidance and not guess_mode:
        image = torch.cat([image] * 2)

    return image
|
||||||
|
|
||||||
class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||||
r"""
|
r"""
|
||||||
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
|
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
|
||||||
@@ -707,6 +760,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
max_embeddings_multiples: Optional[int] = 3,
|
max_embeddings_multiples: Optional[int] = 3,
|
||||||
output_type: Optional[str] = "pil",
|
output_type: Optional[str] = "pil",
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
|
controlnet=None,
|
||||||
|
controlnet_image=None,
|
||||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||||
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
||||||
callback_steps: int = 1,
|
callback_steps: int = 1,
|
||||||
@@ -767,6 +822,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
return_dict (`bool`, *optional*, defaults to `True`):
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||||
plain tuple.
|
plain tuple.
|
||||||
|
controlnet (`diffusers.ControlNetModel`, *optional*):
|
||||||
|
A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
|
||||||
|
controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
|
||||||
|
`Image`, or tensor representing an image batch, passed to the controlnet as its conditioning input (`controlnet_cond`) during
|
||||||
|
inference.
|
||||||
callback (`Callable`, *optional*):
|
callback (`Callable`, *optional*):
|
||||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||||
@@ -785,6 +845,9 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||||
(nsfw) content, according to the `safety_checker`.
|
(nsfw) content, according to the `safety_checker`.
|
||||||
"""
|
"""
|
||||||
|
if controlnet is not None and controlnet_image is None:
|
||||||
|
raise ValueError("controlnet_image must be provided if controlnet is not None.")
|
||||||
|
|
||||||
# 0. Default height and width to unet
|
# 0. Default height and width to unet
|
||||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||||
@@ -824,6 +887,10 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
else:
|
else:
|
||||||
mask = None
|
mask = None
|
||||||
|
|
||||||
|
if controlnet_image is not None:
|
||||||
|
controlnet_image = prepare_controlnet_image(controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False)
|
||||||
|
|
||||||
|
|
||||||
# 5. set timesteps
|
# 5. set timesteps
|
||||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
|
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
|
||||||
@@ -851,8 +918,22 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||||
|
|
||||||
|
unet_additional_args = {}
|
||||||
|
if controlnet is not None:
|
||||||
|
down_block_res_samples, mid_block_res_sample = controlnet(
|
||||||
|
latent_model_input,
|
||||||
|
t,
|
||||||
|
encoder_hidden_states=text_embeddings,
|
||||||
|
controlnet_cond=controlnet_image,
|
||||||
|
conditioning_scale=1.0,
|
||||||
|
guess_mode=False,
|
||||||
|
return_dict=False,
|
||||||
|
)
|
||||||
|
unet_additional_args['down_block_additional_residuals'] = down_block_res_samples
|
||||||
|
unet_additional_args['mid_block_additional_residual'] = mid_block_res_sample
|
||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
|
||||||
|
|
||||||
# perform guidance
|
# perform guidance
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
|
|||||||
@@ -731,8 +731,7 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
|||||||
|
|
||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|
||||||
|
def controlnet_conversion_map():
|
||||||
def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
|
||||||
unet_conversion_map = [
|
unet_conversion_map = [
|
||||||
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
||||||
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
||||||
@@ -792,6 +791,12 @@ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
|||||||
sd_prefix = f"zero_convs.{i}.0."
|
sd_prefix = f"zero_convs.{i}.0."
|
||||||
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
|
||||||
|
|
||||||
|
return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
|
||||||
|
|
||||||
|
|
||||||
|
def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
||||||
|
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
|
||||||
|
|
||||||
mapping = {k: k for k in controlnet_state_dict.keys()}
|
mapping = {k: k for k in controlnet_state_dict.keys()}
|
||||||
for sd_name, diffusers_name in unet_conversion_map:
|
for sd_name, diffusers_name in unet_conversion_map:
|
||||||
mapping[diffusers_name] = sd_name
|
mapping[diffusers_name] = sd_name
|
||||||
@@ -807,6 +812,23 @@ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
|
|||||||
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
|
||||||
return new_state_dict
|
return new_state_dict
|
||||||
|
|
||||||
|
def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
    """Rename the keys of an SD-format ControlNet state dict to Diffusers names.

    The rename tables come from ``controlnet_conversion_map()``. Each key is
    renamed in three steps, in this order:

    1. exact full-name renames (``unet_conversion_map``),
    2. substring replacements for layer prefixes (``unet_conversion_map_layer``),
    3. substring replacements for resnet internals
       (``unet_conversion_map_resnet``), applied only when the name produced
       so far contains ``"resnets"``.

    Args:
        controlnet_state_dict: Mapping from SD-format parameter names to values.

    Returns:
        A new dict with the same values under the converted names. The input
        mapping is not modified.
    """
    unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()

    # Looking renames up per existing key (instead of inserting every mapped
    # name into a working dict) keeps a checkpoint that lacks one of the
    # mapped names from raising KeyError when the values are gathered.
    exact_renames = dict(unet_conversion_map)

    new_state_dict = {}
    for sd_key, value in controlnet_state_dict.items():
        name = exact_renames.get(sd_key, sd_key)
        for sd_part, diffusers_part in unet_conversion_map_layer:
            name = name.replace(sd_part, diffusers_part)
        if "resnets" in name:
            for sd_part, diffusers_part in unet_conversion_map_resnet:
                name = name.replace(sd_part, diffusers_part)
        new_state_dict[name] = value
    return new_state_dict
|
||||||
|
|
||||||
# ================#
|
# ================#
|
||||||
# VAE Conversion #
|
# VAE Conversion #
|
||||||
@@ -928,7 +950,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
|||||||
|
|
||||||
|
|
||||||
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
||||||
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False):
|
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
|
||||||
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
|
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
|
||||||
|
|
||||||
# Convert the UNet2DConditionModel model.
|
# Convert the UNet2DConditionModel model.
|
||||||
|
|||||||
@@ -1674,7 +1674,6 @@ class ControlNetDataset(BaseDataset):
|
|||||||
cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size)
|
cond_img = self.trim_and_resize_if_required(subset, cond_img, image_info.bucket_reso, image_info.resized_size)
|
||||||
cond_img = self.conditioning_image_transforms(cond_img)
|
cond_img = self.conditioning_image_transforms(cond_img)
|
||||||
conditioning_images.append(cond_img)
|
conditioning_images.append(cond_img)
|
||||||
conditioning_images = torch.stack(conditioning_images)
|
|
||||||
|
|
||||||
example = {}
|
example = {}
|
||||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
||||||
@@ -1699,7 +1698,7 @@ class ControlNetDataset(BaseDataset):
|
|||||||
if self.debug_dataset:
|
if self.debug_dataset:
|
||||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||||
|
|
||||||
example["conditioning_images"] = conditioning_images.to(memory_format=torch.contiguous_format).float()
|
example["conditioning_images"] = torch.stack(conditioning_images).to(memory_format=torch.contiguous_format).float()
|
||||||
|
|
||||||
return example
|
return example
|
||||||
|
|
||||||
@@ -3138,13 +3137,13 @@ def prepare_dtype(args: argparse.Namespace):
|
|||||||
return weight_dtype, save_dtype
|
return weight_dtype, save_dtype
|
||||||
|
|
||||||
|
|
||||||
def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
|
def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", unet_use_linear_projection_in_v2=False):
|
||||||
name_or_path = args.pretrained_model_name_or_path
|
name_or_path = args.pretrained_model_name_or_path
|
||||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||||
if load_stable_diffusion_format:
|
if load_stable_diffusion_format:
|
||||||
print(f"load StableDiffusion checkpoint: {name_or_path}")
|
print(f"load StableDiffusion checkpoint: {name_or_path}")
|
||||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device)
|
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device, unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2)
|
||||||
else:
|
else:
|
||||||
# Diffusers model is loaded to CPU
|
# Diffusers model is loaded to CPU
|
||||||
print(f"load Diffusers pretrained models: {name_or_path}")
|
print(f"load Diffusers pretrained models: {name_or_path}")
|
||||||
@@ -3172,14 +3171,14 @@ def transform_if_model_is_DDP(text_encoder, unet, network=None):
|
|||||||
return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)
|
return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)
|
||||||
|
|
||||||
|
|
||||||
def load_target_model(args, weight_dtype, accelerator):
|
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
|
||||||
# load models for each process
|
# load models for each process
|
||||||
for pi in range(accelerator.state.num_processes):
|
for pi in range(accelerator.state.num_processes):
|
||||||
if pi == accelerator.state.local_process_index:
|
if pi == accelerator.state.local_process_index:
|
||||||
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||||
|
|
||||||
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
|
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
|
||||||
args, weight_dtype, accelerator.device if args.lowram else "cpu"
|
args, weight_dtype, accelerator.device if args.lowram else "cpu", unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2
|
||||||
)
|
)
|
||||||
|
|
||||||
# work on low-ram device
|
# work on low-ram device
|
||||||
@@ -3493,7 +3492,7 @@ SCHEDLER_SCHEDULE = "scaled_linear"
|
|||||||
|
|
||||||
|
|
||||||
def sample_images(
|
def sample_images(
|
||||||
accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None
|
accelerator, args: argparse.Namespace, epoch, steps, device, vae, tokenizer, text_encoder, unet, prompt_replacement=None, controlnet=None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
||||||
@@ -3609,6 +3608,7 @@ def sample_images(
|
|||||||
height = prompt.get("height", 512)
|
height = prompt.get("height", 512)
|
||||||
scale = prompt.get("scale", 7.5)
|
scale = prompt.get("scale", 7.5)
|
||||||
seed = prompt.get("seed")
|
seed = prompt.get("seed")
|
||||||
|
controlnet_image = prompt.get("controlnet_image")
|
||||||
prompt = prompt.get("prompt")
|
prompt = prompt.get("prompt")
|
||||||
else:
|
else:
|
||||||
# prompt = prompt.strip()
|
# prompt = prompt.strip()
|
||||||
@@ -3623,6 +3623,7 @@ def sample_images(
|
|||||||
width = height = 512
|
width = height = 512
|
||||||
scale = 7.5
|
scale = 7.5
|
||||||
seed = None
|
seed = None
|
||||||
|
controlnet_image = None
|
||||||
for parg in prompt_args:
|
for parg in prompt_args:
|
||||||
try:
|
try:
|
||||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||||
@@ -3655,6 +3656,12 @@ def sample_images(
|
|||||||
negative_prompt = m.group(1)
|
negative_prompt = m.group(1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
|
||||||
|
if m: # ControlNet conditioning image path
|
||||||
|
controlnet_image = m.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
except ValueError as ex:
|
except ValueError as ex:
|
||||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||||
print(ex)
|
print(ex)
|
||||||
@@ -3668,6 +3675,10 @@ def sample_images(
|
|||||||
if negative_prompt is not None:
|
if negative_prompt is not None:
|
||||||
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||||
|
|
||||||
|
if controlnet_image is not None:
|
||||||
|
controlnet_image = Image.open(controlnet_image).convert("RGB")
|
||||||
|
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
|
||||||
|
|
||||||
height = max(64, height - height % 8) # round to divisible by 8
|
height = max(64, height - height % 8) # round to divisible by 8
|
||||||
width = max(64, width - width % 8) # round to divisible by 8
|
width = max(64, width - width % 8) # round to divisible by 8
|
||||||
print(f"prompt: {prompt}")
|
print(f"prompt: {prompt}")
|
||||||
@@ -3683,6 +3694,8 @@ def sample_images(
|
|||||||
num_inference_steps=sample_steps,
|
num_inference_steps=sample_steps,
|
||||||
guidance_scale=scale,
|
guidance_scale=scale,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
|
controlnet=controlnet,
|
||||||
|
controlnet_image=controlnet_image,
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
@@ -11,6 +12,7 @@ import torch
|
|||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler, ControlNetModel
|
from diffusers import DDPMScheduler, ControlNetModel
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
@@ -26,9 +28,6 @@ from library.custom_train_functions import (
|
|||||||
pyramid_noise_like,
|
pyramid_noise_like,
|
||||||
apply_noise_offset,
|
apply_noise_offset,
|
||||||
)
|
)
|
||||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
|
||||||
download_controlnet_from_original_ckpt,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO 他のスクリプトと共通化する
|
# TODO 他のスクリプトと共通化する
|
||||||
@@ -124,19 +123,24 @@ def train(args):
|
|||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
text_encoder, vae, unet, _ = train_util.load_target_model(
|
text_encoder, vae, unet, _ = train_util.load_target_model(
|
||||||
args, weight_dtype, accelerator
|
args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
controlnet = ControlNetModel.from_unet(unet)
|
||||||
|
|
||||||
if args.controlnet_model_name_or_path:
|
if args.controlnet_model_name_or_path:
|
||||||
if os.path.isfile(args.controlnet_model_name_or_path):
|
filename = args.controlnet_model_name_or_path
|
||||||
controlnet = download_controlnet_from_original_ckpt(
|
if os.path.isfile(filename):
|
||||||
args.controlnet_model_name_or_path
|
if os.path.splitext(filename)[1] == ".safetensors":
|
||||||
)
|
state_dict = load_file(filename)
|
||||||
else:
|
else:
|
||||||
controlnet = ControlNetModel.from_pretrained(
|
state_dict = torch.load(filename)
|
||||||
args.controlnet_model_name_or_path
|
state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict)
|
||||||
)
|
controlnet.load_state_dict(state_dict)
|
||||||
else:
|
elif os.path.isdir(filename):
|
||||||
controlnet = ControlNetModel.from_unet(unet)
|
controlnet = ControlNetModel.from_pretrained(filename)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# モデルに xformers とか memory efficient attention を組み込む
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||||
@@ -289,7 +293,9 @@ def train(args):
|
|||||||
)
|
)
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
accelerator.init_trackers(
|
accelerator.init_trackers(
|
||||||
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name
|
"controlnet_train"
|
||||||
|
if args.log_tracker_name is None
|
||||||
|
else args.log_tracker_name
|
||||||
)
|
)
|
||||||
|
|
||||||
loss_list = []
|
loss_list = []
|
||||||
@@ -350,7 +356,7 @@ def train(args):
|
|||||||
b_size = latents.shape[0]
|
b_size = latents.shape[0]
|
||||||
|
|
||||||
input_ids = batch["input_ids"].to(accelerator.device)
|
input_ids = batch["input_ids"].to(accelerator.device)
|
||||||
encoder_hidden_states = train_util.gethidden_states(
|
encoder_hidden_states = train_util.get_hidden_states(
|
||||||
args, input_ids, tokenizer, text_encoder, weight_dtype
|
args, input_ids, tokenizer, text_encoder, weight_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -450,6 +456,7 @@ def train(args):
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
text_encoder,
|
text_encoder,
|
||||||
unet,
|
unet,
|
||||||
|
controlnet=controlnet,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 指定ステップごとにモデルを保存
|
# 指定ステップごとにモデルを保存
|
||||||
@@ -537,6 +544,7 @@ def train(args):
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
text_encoder,
|
text_encoder,
|
||||||
unet,
|
unet,
|
||||||
|
controlnet=controlnet,
|
||||||
)
|
)
|
||||||
|
|
||||||
# end of epoch
|
# end of epoch
|
||||||
@@ -569,6 +577,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
config_util.add_config_arguments(parser)
|
config_util.add_config_arguments(parser)
|
||||||
custom_train_functions.add_custom_train_arguments(parser)
|
custom_train_functions.add_custom_train_arguments(parser)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_model_as",
|
||||||
|
type=str,
|
||||||
|
default="safetensors",
|
||||||
|
choices=[None, "ckpt", "pt", "safetensors"],
|
||||||
|
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--controlnet_model_name_or_path",
|
"--controlnet_model_name_or_path",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
Reference in New Issue
Block a user