use original ControlNet instead of Diffusers

Kohya S
2024-09-29 23:07:34 +09:00
parent e0c3630203
commit 8919b31145
5 changed files with 526 additions and 237 deletions

View File

@@ -43,8 +43,8 @@ from diffusers import (
)
from einops import rearrange
from tqdm import tqdm
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor
from accelerate import init_empty_weights
import PIL
from PIL import Image
from PIL.PngImagePlugin import PngInfo
@@ -58,6 +58,7 @@ import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.sdxl_original_control_net import SdxlControlNet
from library.original_unet import FlashAttentionFunction
from networks.control_net_lllite import ControlNetLLLite
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
@@ -352,8 +353,8 @@ class PipelineLike:
self.token_replacements_list.append({})
# ControlNet
self.control_nets: List[ControlNetInfo] = [] # only for SD 1.5
self.control_net_lllites: List[ControlNetLLLite] = []
self.control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = []
self.control_net_lllites: List[Tuple[ControlNetLLLite, float]] = []
self.control_net_enabled = True  # if control_nets is empty, ControlNet is inactive whether this is True or False
self.gradual_latent: GradualLatent = None
@@ -542,7 +543,7 @@ class PipelineLike:
else:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
if self.control_net_lllites:
if self.control_net_lllites or (self.control_nets and self.is_sdxl):
# reuse the guide image as the ControlNet hint; for (SD) ControlNet this is done on the ControlNet side
if isinstance(clip_guide_images, PIL.Image.Image):
clip_guide_images = [clip_guide_images]
@@ -731,7 +732,12 @@ class PipelineLike:
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
if self.control_nets:
guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
if not self.is_sdxl:
guided_hints = original_control_net.get_guided_hints(
self.control_nets, num_latent_input, batch_size, clip_guide_images
)
else:
clip_guide_images = clip_guide_images * 0.5 + 0.5 # [-1, 1] => [0, 1]
each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets)
if self.control_net_lllites:
@@ -793,7 +799,7 @@ class PipelineLike:
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# disable ControlNet-LLLite if ratio is set. ControlNet is disabled in ControlNetInfo
# disable ControlNet-LLLite or SDXL ControlNet if ratio is set. ControlNet is disabled in ControlNetInfo
if self.control_net_lllites:
for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_net_lllites, each_control_net_enabled)):
if not enabled or ratio >= 1.0:
@@ -802,9 +808,16 @@ class PipelineLike:
logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
control_net.set_cond_image(None)
each_control_net_enabled[j] = False
if self.control_nets and self.is_sdxl:
for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)):
if not enabled or ratio >= 1.0:
continue
if ratio < i / len(timesteps):
logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
each_control_net_enabled[j] = False
# predict the noise residual
if self.control_nets and self.control_net_enabled:
if self.control_nets and self.control_net_enabled and not self.is_sdxl:
if regional_network:
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt
@@ -823,6 +836,31 @@ class PipelineLike:
text_embeddings,
text_emb_last,
).sample
elif self.control_nets:
input_resi_add_list = []
mid_add_list = []
for (control_net, _), enabled in zip(self.control_nets, each_control_net_enabled):
if not enabled:
continue
input_resi_add, mid_add = control_net(
latent_model_input, t, text_embeddings, vector_embeddings, clip_guide_images
)
input_resi_add_list.append(input_resi_add)
mid_add_list.append(mid_add)
if len(input_resi_add_list) == 0:
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
else:
if len(input_resi_add_list) > 1:
# get mean of input_resi_add_list and mid_add_list
input_resi_add_mean = []
for idx in range(len(input_resi_add_list[0])):
input_resi_add_mean.append(
torch.mean(torch.stack([input_resi_add_list[j][idx] for j in range(len(input_resi_add_list))], dim=0), dim=0)
)
input_resi_add = input_resi_add_mean
mid_add = torch.mean(torch.stack(mid_add_list), dim=0)
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add)
elif self.is_sdxl:
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
else:
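
For reference, when several SDXL ControlNets are active, the branch above reduces their outputs by an element-wise mean before they reach the U-Net. A minimal standalone sketch of that reduction, assuming each ControlNet returns (input_resi_add, mid_add) as in SdxlControlNet.forward (the helper name is illustrative, not part of the commit):

from typing import List, Tuple
import torch

def average_control_net_outputs(
    outputs: List[Tuple[List[torch.Tensor], torch.Tensor]],
) -> Tuple[List[torch.Tensor], torch.Tensor]:
    # mean over the ControlNet axis, one entry per down-block residual
    input_resi_add = [
        torch.mean(torch.stack(tensors, dim=0), dim=0)
        for tensors in zip(*[o[0] for o in outputs])
    ]
    # mean of the mid-block residuals
    mid_add = torch.mean(torch.stack([o[1] for o in outputs], dim=0), dim=0)
    return input_resi_add, mid_add
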
@@ -1827,16 +1865,37 @@ def main(args):
upscaler.to(dtype).to(device)
# ControlNet handling
control_nets: List[ControlNetInfo] = []
control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = []
if args.control_net_models:
for i, model in enumerate(args.control_net_models):
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
if not is_sdxl:
for i, model in enumerate(args.control_net_models):
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
prep = original_control_net.load_preprocess(prep_type)
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
prep = original_control_net.load_preprocess(prep_type)
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
else:
for i, model_file in enumerate(args.control_net_models):
multiplier = (
1.0
if not args.control_net_multipliers or len(args.control_net_multipliers) <= i
else args.control_net_multipliers[i]
)
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
logger.info(f"loading SDXL ControlNet: {model_file}")
from safetensors.torch import load_file
state_dict = load_file(model_file)
logger.info(f"Initalizing SDXL ControlNet with multiplier: {multiplier}")
with init_empty_weights():
control_net = SdxlControlNet(multiplier=multiplier)
control_net.load_state_dict(state_dict)
control_net.to(dtype).to(device)
control_nets.append((control_net, ratio))
control_net_lllites: List[Tuple[ControlNetLLLite, float]] = []
if args.control_net_lllite_models:

View File

@@ -8,7 +8,7 @@ from typing import List
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from library import model_util
from library import sdxl_original_unet
from .utils import setup_logging
from library.utils import setup_logging
setup_logging()
import logging

View File

@@ -0,0 +1,272 @@
# some parts are modified from Diffusers library (Apache License 2.0)
import math
from types import SimpleNamespace
from typing import Any, Optional
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import sdxl_original_unet
from library.sdxl_model_util import convert_sdxl_unet_state_dict_to_diffusers, convert_diffusers_unet_state_dict_to_sdxl
class ControlNetConditioningEmbedding(nn.Module):
def __init__(self):
super().__init__()
dims = [16, 32, 96, 256]
self.conv_in = nn.Conv2d(3, dims[0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList([])
for i in range(len(dims) - 1):
channel_in = dims[i]
channel_out = dims[i + 1]
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
self.conv_out = nn.Conv2d(dims[-1], 320, kernel_size=3, padding=1)
nn.init.zeros_(self.conv_out.weight)  # zero-init so the embedding contributes nothing at the start of training
nn.init.zeros_(self.conv_out.bias)  # zero module bias
def forward(self, x):
x = self.conv_in(x)
x = F.silu(x)
for block in self.blocks:
x = block(x)
x = F.silu(x)
x = self.conv_out(x)
return x
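
Because conv_out is zero-initialized, a freshly constructed embedding maps any conditioning image to all zeros, so the ControlNet starts as an exact no-op on the U-Net. A quick check of that property (illustration only):

import torch
emb = ControlNetConditioningEmbedding()
out = emb(torch.rand(1, 3, 1024, 1024))
assert torch.all(out == 0)  # zero-init output conv => zero contribution at step 0
print(out.shape)  # torch.Size([1, 320, 128, 128]), matching the latent resolution
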
class SdxlControlNet(sdxl_original_unet.SdxlUNet2DConditionModel):
def __init__(self, multiplier: Optional[float] = None, **kwargs):
super().__init__(**kwargs)
self.multiplier = multiplier
# remove unet layers
self.output_blocks = nn.ModuleList([])
del self.out
self.controlnet_cond_embedding = ControlNetConditioningEmbedding()
dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280]
self.controlnet_down_blocks = nn.ModuleList([])
for dim in dims:
self.controlnet_down_blocks.append(nn.Conv2d(dim, dim, kernel_size=1))
nn.init.zeros_(self.controlnet_down_blocks[-1].weight) # zero module weight
nn.init.zeros_(self.controlnet_down_blocks[-1].bias) # zero module bias
self.controlnet_mid_block = nn.Conv2d(1280, 1280, kernel_size=1)
nn.init.zeros_(self.controlnet_mid_block.weight) # zero module weight
nn.init.zeros_(self.controlnet_mid_block.bias) # zero module bias
def init_from_unet(self, unet: sdxl_original_unet.SdxlUNet2DConditionModel):
unet_sd = unet.state_dict()
unet_sd = {k: v for k, v in unet_sd.items() if not k.startswith("out")}
sd = super().state_dict()
sd.update(unet_sd)
info = super().load_state_dict(sd, strict=True, assign=True)
return info
def load_state_dict(self, state_dict: dict, strict: bool = True, assign: bool = True) -> Any:
# convert state_dict to SAI format
unet_sd = {}
for k in list(state_dict.keys()):
if not k.startswith("controlnet_"):
unet_sd[k] = state_dict.pop(k)
unet_sd = convert_diffusers_unet_state_dict_to_sdxl(unet_sd)
state_dict.update(unet_sd)
return super().load_state_dict(state_dict, strict=strict, assign=assign)
def state_dict(self, destination=None, prefix="", keep_vars=False):
# convert state_dict to Diffusers format
state_dict = super().state_dict(destination, prefix, keep_vars)
control_net_sd = {}
for k in list(state_dict.keys()):
if k.startswith("controlnet_"):
control_net_sd[k] = state_dict.pop(k)
state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict)
state_dict.update(control_net_sd)
return state_dict
def forward(
self,
x: torch.Tensor,
timesteps: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
cond_image: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
# broadcast timesteps to batch dimension
timesteps = timesteps.expand(x.shape[0])
t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
emb = emb + self.label_emb(y)
def call_module(module, h, emb, context):
x = h
for layer in module:
if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
x = layer(x, context)
else:
x = layer(x)
return x
h = x
multiplier = self.multiplier if self.multiplier is not None else 1.0
hs = []
for i, module in enumerate(self.input_blocks):
h = call_module(module, h, emb, context)
if i == 0:
h = self.controlnet_cond_embedding(cond_image) + h
hs.append(self.controlnet_down_blocks[i](h) * multiplier)
h = call_module(self.middle_block, h, emb, context)
h = self.controlnet_mid_block(h) * multiplier
return hs, h
class SdxlControlledUNet(sdxl_original_unet.SdxlUNet2DConditionModel):
"""
This class is for training purposes only.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
# broadcast timesteps to batch dimension
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
emb = emb + self.label_emb(y)
def call_module(module, h, emb, context):
x = h
for layer in module:
if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
x = layer(x, context)
else:
x = layer(x)
return x
h = x
for module in self.input_blocks:
h = call_module(module, h, emb, context)
hs.append(h)
h = call_module(self.middle_block, h, emb, context)
h = h + mid_add
for module in self.output_blocks:
resi = hs.pop() + input_resi_add.pop()
h = torch.cat([h, resi], dim=1)
h = call_module(module, h, emb, context)
h = h.type(x.dtype)
h = call_module(self.out, h, emb, context)
return h
if __name__ == "__main__":
import time
logger.info("create unet")
unet = SdxlControlledUNet()
unet.to("cuda", torch.bfloat16)
unet.set_use_sdpa(True)
unet.set_gradient_checkpointing(True)
unet.train()
logger.info("create control_net")
control_net = SdxlControlNet()
control_net.to("cuda")
control_net.set_use_sdpa(True)
control_net.set_gradient_checkpointing(True)
control_net.train()
logger.info("Initialize control_net from unet")
control_net.init_from_unet(unet)
unet.requires_grad_(False)
control_net.requires_grad_(True)
# pseudo training loop for checking memory usage
logger.info("preparing optimizer")
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
import bitsandbytes
optimizer = bitsandbytes.adam.Adam8bit(control_net.parameters(), lr=1e-3) # not working
# optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
# optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
# import transformers
# optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2
scaler = torch.cuda.amp.GradScaler(enabled=True)
logger.info("start training")
steps = 10
batch_size = 1
for step in range(steps):
logger.info(f"step {step}")
if step == 1:
time_start = time.perf_counter()
x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
t = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda")
txt = torch.randn(batch_size, 77, 2048).cuda()
vector = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda()
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
input_resi_add, mid_add = control_net(x, t, txt, vector, cond_img)
output = unet(x, t, txt, vector, input_resi_add, mid_add)
target = torch.randn_like(output)
loss = torch.nn.functional.mse_loss(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
time_end = time.perf_counter()
logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
logger.info("finish training")
sd = control_net.state_dict()
from safetensors.torch import save_file
save_file(sd, r"E:\Work\SD\Tmp\sdxl\ctrl\control_net.safetensors")
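As a usage note, here is a minimal inference-style sketch of wiring the two classes above together, mirroring the shapes in the test code (an illustration, not part of the commit):

import torch
from library import sdxl_original_unet
from library.sdxl_original_control_net import SdxlControlNet, SdxlControlledUNet

unet = SdxlControlledUNet().to("cuda").eval()
control_net = SdxlControlNet(multiplier=1.0).to("cuda").eval()
control_net.init_from_unet(unet)  # or load trained weights instead

x = torch.randn(1, 4, 128, 128, device="cuda")  # noisy latents for a 1024x1024 image
t = torch.randint(0, 1000, (1,), device="cuda")
ctx = torch.randn(1, 77, 2048, device="cuda")  # text embeddings
y = torch.randn(1, sdxl_original_unet.ADM_IN_CHANNELS, device="cuda")  # pooled + size embeddings
cond = torch.rand(1, 3, 1024, 1024, device="cuda")  # conditioning image in [0, 1]

with torch.no_grad():
    input_resi_add, mid_add = control_net(x, t, ctx, y, cond)
    noise_pred = unet(x, t, ctx, y, input_resi_add, mid_add)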

View File

@@ -30,7 +30,7 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from .utils import setup_logging
from library.utils import setup_logging
setup_logging()
import logging
@@ -1156,9 +1156,9 @@ class InferSdxlUNet2DConditionModel:
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
r"""
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink and ControlNet.
"""
_self = self.delegate
@@ -1209,6 +1209,8 @@ class InferSdxlUNet2DConditionModel:
hs.append(h)
h = call_module(_self.middle_block, h, emb, context)
if mid_add is not None:
h = h + mid_add
for module in _self.output_blocks:
# Deep Shrink
@@ -1217,7 +1219,11 @@ class InferSdxlUNet2DConditionModel:
# print("upsample", h.shape, hs[-1].shape)
h = resize_like(h, hs[-1])
h = torch.cat([h, hs.pop()], dim=1)
resi = hs.pop()
if input_resi_add is not None:
resi = resi + input_resi_add.pop()
h = torch.cat([h, resi], dim=1)
h = call_module(module, h, emb, context)
# Deep Shrink: in case of depth 0
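
Note the consumption order: ControlNet residuals are produced while walking the input blocks (shallow to deep) and popped while walking the output blocks (deep to shallow), so input_resi_add behaves as a stack that pairs each skip connection with its residual. A toy illustration with hypothetical shapes:

import torch

hs = [torch.randn(1, c, 16, 16) for c in (320, 640, 1280)]  # skip connections, shallow to deep
input_resi_add = [torch.randn_like(h) for h in hs]  # one ControlNet residual per input block

while hs:
    resi = hs.pop() + input_resi_add.pop()  # deepest skip meets deepest residual
    # the real model then runs: h = call_module(module, torch.cat([h, resi], dim=1), emb, context)
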

View File

@@ -14,6 +14,7 @@ init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from accelerate import init_empty_weights
from diffusers import DDPMScheduler, ControlNetModel
from diffusers.utils.torch_utils import is_compiled_module
from safetensors.torch import load_file
@@ -23,6 +24,9 @@ from library import (
sdxl_model_util,
sdxl_original_unet,
sdxl_train_util,
strategy_base,
strategy_sd,
strategy_sdxl,
)
import library.model_util as model_util
@@ -41,6 +45,7 @@ from library.custom_train_functions import (
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
from library.sdxl_original_control_net import SdxlControlNet, SdxlControlledUNet
from library.utils import setup_logging, add_logging_arguments
setup_logging()
@@ -58,10 +63,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
}
if args.optimizer_type.lower().startswith("DAdapt".lower()):
logs["lr/d*lr"] = (
lr_scheduler.optimizers[-1].param_groups[0]["d"]
* lr_scheduler.optimizers[-1].param_groups[0]["lr"]
)
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
return logs
@@ -79,7 +81,14 @@ def train(args):
args.seed = random.randint(0, 2**32)
set_seed(args.seed)
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# prepare caching strategy: this must be set before preparing the dataset, because the dataset may use this strategy for initialization
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# prepare the dataset
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
@@ -106,17 +115,18 @@ def train(args):
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = (train_dataset_group if args.max_data_loader_n_workers == 0 else None)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(32)
if args.debug_dataset:
train_dataset_group.set_current_strategies()  # dataset needs to know the strategies explicitly
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
@@ -162,86 +172,99 @@ def train(args):
unet,
logit_scale,
ckpt_info,
) = sdxl_train_util.load_target_model(
args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype
)
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
# convert U-Net
with torch.no_grad():
du_unet_sd = sdxl_model_util.convert_sdxl_unet_state_dict_to_diffusers(unet.state_dict())
unet.to("cpu")
clean_memory_on_device(accelerator.device)
del unet
unet = sdxl_model_util.UNet2DConditionModel(**sdxl_model_util.DIFFUSERS_SDXL_UNET_CONFIG)
unet.load_state_dict(du_unet_sd)
unet.to(accelerator.device) # reduce main memory usage
controlnet = ControlNetModel.from_unet(unet)
# convert U-Net to Controlled U-Net
logger.info("convert U-Net to Controlled U-Net")
unet_sd = unet.state_dict()
with init_empty_weights():
unet = SdxlControlledUNet()
unet.load_state_dict(unet_sd, strict=True, assign=True)
del unet_sd
if args.controlnet_model_name_or_path:
filename = args.controlnet_model_name_or_path
if os.path.isfile(filename):
if os.path.splitext(filename)[1] == ".safetensors":
state_dict = load_file(filename)
else:
state_dict = torch.load(filename)
state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict)
controlnet.load_state_dict(state_dict)
elif os.path.isdir(filename):
controlnet = ControlNetModel.from_pretrained(filename)
# make control net
logger.info("make ControlNet")
if args.controlnet_model_path:
with init_empty_weights():
control_net = SdxlControlNet()
logger.info(f"load ControlNet from {args.controlnet_model_path}")
filename = args.controlnet_model_path
if os.path.splitext(filename)[1] == ".safetensors":
state_dict = load_file(filename)
else:
state_dict = torch.load(filename)
info = control_net.load_state_dict(state_dict, strict=True, assign=True)
logger.info(f"ControlNet loaded from {filename}: {info}")
else:
control_net = SdxlControlNet()
logger.info("initialize ControlNet from U-Net")
info = control_net.init_from_unet(unet)
logger.info(f"ControlNet initialized from U-Net: {info}")
# prepare for training
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(
vae,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
)
train_dataset_group.new_cache_latents(vae, accelerator.is_main_process)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# cache Text Encoder outputs
if args.cache_text_encoder_outputs:
# Text Encoders are in eval mode with no grad
with torch.no_grad():
train_dataset_group.cache_text_encoder_outputs(
(tokenizer1, tokenizer2),
(text_encoder1, text_encoder2),
accelerator.device,
None,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, False
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process)
accelerator.wait_for_everyone()
# incorporate xformers or memory efficient attention into the model
# train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
if args.xformers:
unet.enable_xformers_memory_efficient_attention()
controlnet.enable_xformers_memory_efficient_attention()
unet.set_use_memory_efficient_attention(True, False)
control_net.set_use_memory_efficient_attention(True, False)
elif args.sdpa:
unet.set_use_sdpa(True)
control_net.set_use_sdpa(True)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
controlnet.enable_gradient_checkpointing()
control_net.enable_gradient_checkpointing()
# prepare the classes required for training
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(filter(lambda p: p.requires_grad, controlnet.parameters()))
trainable_params = list(control_net.parameters())
# for p in trainable_params:
# p.requires_grad = True
logger.info(f"trainable params count: {len(trainable_params)}")
logger.info(
f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}"
)
logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# prepare dataloader
# strategies are set here because they cannot be referenced from another process; they are copied along with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# note that persistent_workers cannot be used when the number of DataLoader processes is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
@@ -257,7 +280,7 @@ def train(args):
# calculate the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader)/ accelerator.num_processes/ args.gradient_accumulation_steps
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
@@ -267,9 +290,7 @@ def train(args):
train_dataset_group.set_max_train_steps(args.max_train_steps)
# prepare the lr scheduler
lr_scheduler = train_util.get_scheduler_fix(
args, optimizer, accelerator.num_processes
)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# experimental feature: perform fp16/bf16 training including gradients, casting the entire model to fp16/bf16
if args.full_fp16:
@@ -277,19 +298,17 @@ def train(args):
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
controlnet.to(weight_dtype)
unet.to(weight_dtype)
control_net.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
controlnet.to(weight_dtype)
unet.to(weight_dtype)
control_net.to(weight_dtype)
# accelerator is supposed to handle this properly
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
controlnet, optimizer, train_dataloader, lr_scheduler
control_net, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
control_net, optimizer, train_dataloader, lr_scheduler
)
if args.fused_backward_pass:
@@ -314,10 +333,8 @@ def train(args):
text_encoder2.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
# transform DDP after prepare
controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet
controlnet.train()
unet.eval()
control_net.train()
# move to CPU when caching Text Encoder outputs
if args.cache_text_encoder_outputs:
@@ -362,26 +379,15 @@ def train(args):
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(
range(args.max_train_steps),
smoothing=0,
disable=not accelerator.is_local_main_process,
desc="steps",
)
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
clip_sample=False,
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(
noise_scheduler
)
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process:
init_kwargs = {}
@@ -390,11 +396,7 @@ def train(args):
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
(
"controlnet_train"
if args.log_tracker_name is None
else args.log_tracker_name
),
("sdxl_control_net_train" if args.log_tracker_name is None else args.log_tracker_name),
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
@@ -409,10 +411,8 @@ def train(args):
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
sai_metadata["modelspec.architecture"] = (
sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet"
)
state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict())
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet"
state_dict = model.state_dict()
if save_dtype is not None:
for key in list(state_dict.keys()):
@@ -436,19 +436,19 @@ def train(args):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
sdxl_train_util.sample_images(
accelerator,
args,
0,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
controlnet=controlnet,
)
# # For --sample_at_first
# sdxl_train_util.sample_images(
# accelerator,
# args,
# 0,
# global_step,
# accelerator.device,
# vae,
# [tokenizer1, tokenizer2],
# [text_encoder1, text_encoder2],
# unet,
# controlnet=control_net,
# )
# training loop
for epoch in range(num_train_epochs):
@@ -457,121 +457,63 @@ def train(args):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(controlnet):
with accelerator.accumulate(control_net):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = (
batch["latents"]
.to(accelerator.device)
.to(dtype=weight_dtype)
)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# convert to latents
latents = (
vae.encode(batch["images"].to(dtype=vae_dtype))
.latent_dist.sample()
.to(dtype=weight_dtype)
)
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
# if NaN is found, print a warning and replace with zeros
if torch.any(torch.isnan(latents)):
accelerator.print(
"NaN found in latents, replacing with zeros"
)
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if (
"text_encoder_outputs1_list" not in batch
or batch["text_encoder_outputs1_list"] is None
):
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Text Encoder outputs are cached
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list
encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype)
pool2 = pool2.to(accelerator.device, dtype=weight_dtype)
else:
input_ids1, input_ids2 = batch["input_ids_list"]
with torch.no_grad():
# Get the text embedding for conditioning
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = (
train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
input_ids2,
tokenizer1,
tokenizer2,
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
)
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2]
)
else:
encoder_hidden_states1 = (
batch["text_encoder_outputs1_list"]
.to(accelerator.device)
.to(weight_dtype)
)
encoder_hidden_states2 = (
batch["text_encoder_outputs2_list"]
.to(accelerator.device)
.to(weight_dtype)
)
pool2 = (
batch["text_encoder_pool2_list"]
.to(accelerator.device)
.to(weight_dtype)
)
if args.full_fp16:
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype)
pool2 = pool2.to(weight_dtype)
# get size embeddings
orig_size = batch["original_sizes_hw"]
crop_size = batch["crop_top_lefts"]
target_size = batch["target_sizes_hw"]
# embs = sdxl_train_util.get_size_embeddings(
# orig_size, crop_size, target_size, accelerator.device
# ).to(weight_dtype)
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
embs = torch.cat([orig_size, crop_size, target_size]).to(accelerator.device).to(weight_dtype) #B,6
# concat embeddings
#vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
vector_embedding_dict = {
"text_embeds": pool2,
"time_ids": embs
}
text_embedding = torch.cat(
[encoder_hidden_states1, encoder_hidden_states2], dim=2
).to(weight_dtype)
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = (
train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
with accelerator.autocast():
down_block_res_samples, mid_block_res_sample = controlnet(
noisy_latents,
timesteps,
encoder_hidden_states=text_embedding,
added_cond_kwargs=vector_embedding_dict,
controlnet_cond=controlnet_image,
return_dict=False,
input_resi_add, mid_add = control_net(
noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image
)
# Predict the noise residual
noise_pred = unet(
noisy_latents,
timesteps,
encoder_hidden_states=text_embedding,
added_cond_kwargs=vector_embedding_dict,
down_block_additional_residuals=[
sample.to(dtype=weight_dtype) for sample in down_block_res_samples
],
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
return_dict=False,
)[0]
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, input_resi_add, mid_add)
if args.v_parameterization:
# v-parameterization training
@@ -580,7 +522,7 @@ def train(args):
target = noise
loss = train_util.conditional_loss(
noise_pred.float(),target.float(),reduction="none",loss_type=args.loss_type,huber_c=huber_c,
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])
@@ -588,7 +530,7 @@ def train(args):
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss,timesteps,noise_scheduler,args.min_snr_gamma,args.v_parameterization)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
@@ -601,7 +543,7 @@ def train(args):
accelerator.backward(loss)
if not args.fused_backward_pass:
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = controlnet.parameters()
params_to_clip = control_net.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
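
Condensed, each training step computes the ControlNet residuals, predicts noise with the frozen controlled U-Net, and backpropagates an MSE loss through the ControlNet only. A stripped-down sketch under those assumptions (mixed precision, caching, and the accelerate plumbing are omitted; the helper name is illustrative):

import torch
import torch.nn.functional as F

def control_net_train_step(control_net, unet, optimizer, noise_scheduler,
                           latents, text_emb, vector_emb, cond_image):
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
    )
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # ControlNet residuals, then noise prediction with the frozen controlled U-Net
    input_resi_add, mid_add = control_net(noisy_latents, timesteps, text_emb, vector_emb, cond_image)
    noise_pred = unet(noisy_latents, timesteps, text_emb, vector_emb, input_resi_add, mid_add)

    loss = F.mse_loss(noise_pred.float(), noise.float())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return loss.detach()
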
@@ -616,25 +558,25 @@ def train(args):
progress_bar.update(1)
global_step += 1
sdxl_train_util.sample_images(
accelerator,
args,
None,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
controlnet=controlnet,
)
# sdxl_train_util.sample_images(
# accelerator,
# args,
# None,
# global_step,
# accelerator.device,
# vae,
# [tokenizer1, tokenizer2],
# [text_encoder1, text_encoder2],
# unet,
# controlnet=control_net,
# )
# save the model every specified number of steps
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name,unwrap_model(controlnet))
save_model(ckpt_name, unwrap_model(control_net))
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
@@ -650,14 +592,14 @@ def train(args):
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
@@ -668,7 +610,7 @@ def train(args):
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name,unwrap_model(controlnet))
save_model(ckpt_name, unwrap_model(control_net))
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
@@ -688,13 +630,13 @@ def train(args):
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
controlnet=controlnet,
controlnet=control_net,
)
# end of epoch
if is_main_process:
controlnet = unwrap_model(controlnet)
control_net = unwrap_model(control_net)
accelerator.end_training()
@@ -703,9 +645,7 @@ def train(args):
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(
ckpt_name, controlnet, force_sync_upload=True
)
save_model(ckpt_name, control_net, force_sync_upload=True)
logger.info("model saved.")
@@ -717,26 +657,38 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)
# train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
# train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument(
"--controlnet_model_name_or_path",
"--controlnet_model_path",
type=str,
default=None,
help="controlnet model name or path / controlnetのモデル名またはパス",
)
parser.add_argument(
"--conditioning_data_dir",
type=str,
default=None,
help="conditioning data directory / 条件付けデータのディレクトリ",
)
parser.add_argument(
"--save_model_as",
type=str,
default="safetensors",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
return parser