Merge branch 'original-u-net' into sdxl

Kohya S
2023-06-24 09:37:00 +09:00
2 changed files with 118 additions and 110 deletions

View File

@@ -1468,6 +1468,8 @@ class UNet2DConditionModel(nn.Module):
         encoder_hidden_states: torch.Tensor,
         class_labels: Optional[torch.Tensor] = None,
         return_dict: bool = True,
+        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+        mid_block_additional_residual: Optional[torch.Tensor] = None,
     ) -> Union[Dict, Tuple]:
         r"""
         Args:
@@ -1533,9 +1535,20 @@ class UNet2DConditionModel(nn.Module):
             down_block_res_samples += res_samples

+        # add the ControlNet outputs to the skip connections
+        if down_block_additional_residuals is not None:
+            down_block_res_samples = list(down_block_res_samples)
+            for i in range(len(down_block_res_samples)):
+                down_block_res_samples[i] += down_block_additional_residuals[i]
+            down_block_res_samples = tuple(down_block_res_samples)
+
         # 4. mid
         sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)

+        # add the ControlNet output
+        if mid_block_additional_residual is not None:
+            sample += mid_block_additional_residual
+
         # 5. up
         for i, upsample_block in enumerate(self.up_blocks):
             is_final_block = i == len(self.up_blocks) - 1
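
For context, a minimal sketch of how the two new forward arguments are fed in practice, mirroring the training-loop call later in this commit. The controlnet forward shown follows the diffusers ControlNetModel API; the variable names (controlnet_image and friends) are illustrative, not part of this diff:

# Sketch: a diffusers-style ControlNetModel returns one residual per down
# block plus a mid-block residual; the patched forward above adds them onto
# the U-Net skip connections and mid-block output.
down_block_res_samples, mid_block_res_sample = controlnet(
    noisy_latents,
    timesteps,
    encoder_hidden_states=encoder_hidden_states,
    controlnet_cond=controlnet_image,  # conditioning image, name is illustrative
    return_dict=False,
)
noise_pred = unet(
    noisy_latents,
    timesteps,
    encoder_hidden_states,
    down_block_additional_residuals=[s.to(dtype=weight_dtype) for s in down_block_res_samples],
    mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
).sample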

View File

@@ -6,6 +6,7 @@ import os
 import random
 import time
 from multiprocessing import Value
+from types import SimpleNamespace

 from tqdm import tqdm
 import torch
@@ -39,17 +40,14 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
     }
     if args.optimizer_type.lower().startswith("DAdapt".lower()):
-        logs["lr/d*lr"] = (
-            lr_scheduler.optimizers[-1].param_groups[0]["d"]
-            * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
-        )
+        logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
     return logs


 def train(args):
-    session_id = random.randint(0, 2**32)
-    training_started_at = time.time()
+    # session_id = random.randint(0, 2**32)
+    # training_started_at = time.time()
     train_util.verify_training_args(args)
     train_util.prepare_dataset_args(args, True)
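
The collapsed "lr/d*lr" entry reports the effective step size of a DAdaptation-style optimizer: these optimizers keep the configured lr fixed and adapt through a "d" estimate stored in each param group, so the product is the learning rate actually applied. The same computation as a minimal sketch:

group = lr_scheduler.optimizers[-1].param_groups[0]
logs["lr/d*lr"] = group["d"] * group["lr"]  # effective learning rate actually applied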
@@ -88,15 +86,11 @@ def train(args):
         }

     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
-    train_dataset_group = config_util.generate_dataset_group_by_blueprint(
-        blueprint.dataset_group
-    )
+    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
-    ds_for_collater = (
-        train_dataset_group if args.max_data_loader_n_workers == 0 else None
-    )
+    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
     collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)

     if args.debug_dataset:
@@ -115,7 +109,7 @@ def train(args):
     # prepare the accelerator
     print("prepare accelerator")
-    accelerator, unwrap_model = train_util.prepare_accelerator(args)
+    accelerator = train_util.prepare_accelerator(args)
     is_main_process = accelerator.is_main_process

     # prepare a dtype that matches the mixed precision setting and cast as needed
@@ -126,6 +120,69 @@ def train(args):
         args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True
     )

+    # prepare the data that Diffusers' ControlNet uses
+    if args.v2:
+        unet.config = {
+            "act_fn": "silu",
+            "attention_head_dim": [5, 10, 20, 20],
+            "block_out_channels": [320, 640, 1280, 1280],
+            "center_input_sample": False,
+            "cross_attention_dim": 1024,
+            "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"],
+            "downsample_padding": 1,
+            "dual_cross_attention": False,
+            "flip_sin_to_cos": True,
+            "freq_shift": 0,
+            "in_channels": 4,
+            "layers_per_block": 2,
+            "mid_block_scale_factor": 1,
+            "norm_eps": 1e-05,
+            "norm_num_groups": 32,
+            "num_class_embeds": None,
+            "only_cross_attention": False,
+            "out_channels": 4,
+            "sample_size": 96,
+            "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
+            "use_linear_projection": True,
+            "upcast_attention": True,
+            "class_embed_type": None,
+            "resnet_time_scale_shift": "default",
+            "projection_class_embeddings_input_dim": None,
+        }
+    else:
+        unet.config = {
+            "act_fn": "silu",
+            "attention_head_dim": 8,
+            "block_out_channels": [320, 640, 1280, 1280],
+            "center_input_sample": False,
+            "cross_attention_dim": 768,
+            "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"],
+            "downsample_padding": 1,
+            "flip_sin_to_cos": True,
+            "freq_shift": 0,
+            "in_channels": 4,
+            "layers_per_block": 2,
+            "mid_block_scale_factor": 1,
+            "norm_eps": 1e-05,
+            "norm_num_groups": 32,
+            "out_channels": 4,
+            "sample_size": 64,
+            "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
+            "only_cross_attention": False,
+            "use_linear_projection": False,
+            "class_embed_type": None,
+            "num_class_embeds": None,
+            "upcast_attention": False,
+            "resnet_time_scale_shift": "default",
+            "projection_class_embeddings_input_dim": None,
+        }
+    unet.config = SimpleNamespace(**unet.config)
+
     controlnet = ControlNetModel.from_unet(unet)

     if args.controlnet_model_name_or_path:
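
The SimpleNamespace wrapper exists because ControlNetModel.from_unet reads the U-Net configuration via attribute access (unet.config.block_out_channels and so on), while the hand-built table above is a plain dict. A minimal illustration of the conversion:

from types import SimpleNamespace

cfg = SimpleNamespace(**{"cross_attention_dim": 768, "sample_size": 64})
print(cfg.cross_attention_dim)  # 768, attribute access as from_unet expects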
@@ -140,9 +197,8 @@ def train(args):
         elif os.path.isdir(filename):
             controlnet = ControlNetModel.from_pretrained(filename)

     # incorporate xformers or memory-efficient attention into the model
-    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
+    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)

     # prepare for training
     if cache_latents:
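
The extra args.sdpa flag routes attention through PyTorch's fused scaled-dot-product kernel. replace_unet_modules itself lives in library code outside this diff; as a rough sketch of what an SDPA-backed attention forward looks like (the function name and shapes are assumptions, not the library's implementation):

import torch.nn.functional as F

def sdpa_attention(q, k, v):
    # q, k, v: (batch, heads, seq_len, head_dim); PyTorch >= 2.0 picks a fused
    # kernel (FlashAttention / memory-efficient) automatically when available.
    return F.scaled_dot_product_attention(q, k, v)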
@@ -171,15 +227,11 @@ def train(args):
     trainable_params = controlnet.parameters()

-    _, _, optimizer = train_util.get_optimizer(
-        args, trainable_params
-    )
+    _, _, optimizer = train_util.get_optimizer(args, trainable_params)

     # prepare the dataloader
     # a DataLoader worker count of 0 means loading runs in the main process
-    n_workers = min(
-        args.max_data_loader_n_workers, os.cpu_count() - 1
-    )  # cpu_count-1, but capped at the specified value
+    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but capped at the specified value
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset_group,
@@ -193,21 +245,15 @@ def train(args):
     # calculate the number of training steps
     if args.max_train_epochs is not None:
         args.max_train_steps = args.max_train_epochs * math.ceil(
-            len(train_dataloader)
-            / accelerator.num_processes
-            / args.gradient_accumulation_steps
+            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
         )
-        accelerator.print(
-            f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
-        )
+        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

     # also send the number of training steps to the dataset side
     train_dataset_group.set_max_train_steps(args.max_train_steps)

     # prepare the lr scheduler
-    lr_scheduler = train_util.get_scheduler_fix(
-        args, optimizer, accelerator.num_processes
-    )
+    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

     # experimental feature: perform fp16 training including gradients, making the whole model fp16
     if args.full_fp16:
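
The collapsed step-count expression divides the dataloader length by both the process count and the gradient-accumulation steps before multiplying by the epoch count. A worked example with illustrative numbers:

import math

steps_per_epoch = math.ceil(1000 / 2 / 4)  # 1000 batches, 2 processes, accumulation 4 -> 125
max_train_steps = 10 * steps_per_epoch     # 10 epochs -> 1250 optimizer steps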
@@ -245,31 +291,21 @@ def train(args):
         train_util.resume_from_local_or_hf_if_specified(accelerator, args)

     # calculate the number of epochs
-    num_update_steps_per_epoch = math.ceil(
-        len(train_dataloader) / args.gradient_accumulation_steps
-    )
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
-        args.save_every_n_epochs = (
-            math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
-        )
+        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

     # start training
     # TODO: find a way to handle total batch size when there are multiple datasets
     accelerator.print("running training / 学習開始")
-    accelerator.print(
-        f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}"
-    )
+    accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
     accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
     accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
     accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
-    accelerator.print(
-        f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
-    )
+    accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
     # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
-    accelerator.print(
-        f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}"
-    )
+    accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
     accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

     progress_bar = tqdm(
@@ -288,11 +324,7 @@ def train(args):
         clip_sample=False,
     )
     if accelerator.is_main_process:
-        accelerator.init_trackers(
-            "controlnet_train"
-            if args.log_tracker_name is None
-            else args.log_tracker_name
-        )
+        accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name)

     loss_list = []
     loss_total = 0.0
@@ -321,9 +353,7 @@ def train(args):
         torch.save(state_dict, ckpt_file)
         if args.huggingface_repo_id is not None:
-            huggingface_util.upload(
-                args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload
-            )
+            huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)

     def remove_model(old_ckpt_name):
         old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
@@ -345,23 +375,17 @@ def train(args):
                         latents = batch["latents"].to(accelerator.device)
                     else:
                         # convert to latents
-                        latents = vae.encode(
-                            batch["images"].to(dtype=weight_dtype)
-                        ).latent_dist.sample()
+                        latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                         latents = latents * 0.18215
                     b_size = latents.shape[0]

                     input_ids = batch["input_ids"].to(accelerator.device)
-                    encoder_hidden_states = train_util.get_hidden_states(
-                        args, input_ids, tokenizer, text_encoder, weight_dtype
-                    )
+                    encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)

                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents, device=latents.device)
                 if args.noise_offset:
-                    noise = apply_noise_offset(
-                        latents, noise, args.noise_offset, args.adaptive_noise_scale
-                    )
+                    noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
                 elif args.multires_noise_iterations:
                     noise = pyramid_noise_like(
                         noise,
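
apply_noise_offset is defined in library code not shown in this diff; one common formulation of the noise-offset technique, offered as a sketch (the adaptive scale handled by the real helper is omitted here):

import torch

def apply_noise_offset_sketch(latents, noise, noise_offset):
    # one random offset per (sample, channel), broadcast over height and width,
    # which biases training toward learning overall brightness shifts
    offset = torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
    return noise + noise_offset * offset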
@@ -398,13 +422,8 @@ def train(args):
                     noisy_latents,
                     timesteps,
                     encoder_hidden_states,
-                    down_block_additional_residuals=[
-                        sample.to(dtype=weight_dtype)
-                        for sample in down_block_res_samples
-                    ],
-                    mid_block_additional_residual=mid_block_res_sample.to(
-                        dtype=weight_dtype
-                    ),
+                    down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples],
+                    mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
                 ).sample

                 if args.v_parameterization:
@@ -413,18 +432,14 @@ def train(args):
                 else:
                     target = noise

-                loss = torch.nn.functional.mse_loss(
-                    noise_pred.float(), target.float(), reduction="none"
-                )
+                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                 loss = loss.mean([1, 2, 3])

                 loss_weights = batch["loss_weights"]  # per-sample weight
                 loss = loss * loss_weights

                 if args.min_snr_gamma:
-                    loss = apply_snr_weight(
-                        loss, timesteps, noise_scheduler, args.min_snr_gamma
-                    )
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)

                 loss = loss.mean()  # already a mean, so no need to divide by batch_size
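
apply_snr_weight applies Min-SNR-gamma loss weighting; a sketch of the usual formulation, assuming epsilon prediction and a scheduler exposing alphas_cumprod:

import torch

def min_snr_weight_sketch(loss, timesteps, noise_scheduler, gamma):
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t); clamping the weight at gamma
    # down-weights low-noise (high-SNR) timesteps.
    alpha_bar = noise_scheduler.alphas_cumprod.to(loss.device)[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)
    return loss * torch.minimum(snr, torch.full_like(snr, gamma)) / snr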
@@ -456,31 +471,21 @@ def train(args):
                     )

                     # save the model every specified number of steps
-                    if (
-                        args.save_every_n_steps is not None
-                        and global_step % args.save_every_n_steps == 0
-                    ):
+                    if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                         accelerator.wait_for_everyone()
                         if accelerator.is_main_process:
-                            ckpt_name = train_util.get_step_ckpt_name(
-                                args, "." + args.save_model_as, global_step
-                            )
+                            ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
                             save_model(
-                                ckpt_name, unwrap_model(controlnet),
+                                ckpt_name,
+                                accelerator.unwrap_model(controlnet),
                             )

                             if args.save_state:
-                                train_util.save_and_remove_state_stepwise(
-                                    args, accelerator, global_step
-                                )
+                                train_util.save_and_remove_state_stepwise(args, accelerator, global_step)

-                            remove_step_no = train_util.get_remove_step_no(
-                                args, global_step
-                            )
+                            remove_step_no = train_util.get_remove_step_no(args, global_step)
                             if remove_step_no is not None:
-                                remove_ckpt_name = train_util.get_step_ckpt_name(
-                                    args, "." + args.save_model_as, remove_step_no
-                                )
+                                remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
                                 remove_model(remove_ckpt_name)

                 current_loss = loss.detach().item()
@@ -509,26 +514,18 @@ def train(args):
         # save the model every specified number of epochs
         if args.save_every_n_epochs is not None:
-            saving = (epoch + 1) % args.save_every_n_epochs == 0 and (
-                epoch + 1
-            ) < num_train_epochs
+            saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
             if is_main_process and saving:
-                ckpt_name = train_util.get_epoch_ckpt_name(
-                    args, "." + args.save_model_as, epoch + 1
-                )
-                save_model(ckpt_name, unwrap_model(controlnet))
+                ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
+                save_model(ckpt_name, accelerator.unwrap_model(controlnet))

                 remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
                 if remove_epoch_no is not None:
-                    remove_ckpt_name = train_util.get_epoch_ckpt_name(
-                        args, "." + args.save_model_as, remove_epoch_no
-                    )
+                    remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
                     remove_model(remove_ckpt_name)

                 if args.save_state:
-                    train_util.save_and_remove_state_on_epoch_end(
-                        args, accelerator, epoch + 1
-                    )
+                    train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)

         train_util.sample_images(
             accelerator,
@@ -545,20 +542,18 @@ def train(args):
     # end of epoch

     if is_main_process:
-        controlnet = unwrap_model(controlnet)
+        controlnet = accelerator.unwrap_model(controlnet)

     accelerator.end_training()

     if is_main_process and args.save_state:
         train_util.save_state_on_train_end(args, accelerator)

-    del accelerator  # delete this since memory is needed after this point
+    # del accelerator  # delete this since memory is needed after this point -> actually it is still used by print, so keep it

     if is_main_process:
         ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
-        save_model(
-            ckpt_name, controlnet, force_sync_upload=True
-        )
+        save_model(ckpt_name, controlnet, force_sync_upload=True)
         print("model saved.")