update diffusers to 1.16 | finetune

ddPn08
2023-06-01 20:47:54 +09:00
parent 1214f35985
commit 4f8ce00477


@@ -5,13 +5,11 @@ import argparse
 import gc
 import math
 import os
-import toml
 from multiprocessing import Value
 
 from tqdm import tqdm
 import torch
 from accelerate.utils import set_seed
-import diffusers
 from diffusers import DDPMScheduler
 
 import library.train_util as train_util
@@ -128,11 +126,11 @@ def train(args):
 
     # モデルに xformers とか memory efficient attention を組み込む
     if args.diffusers_xformers:
-        print("Use xformers by Diffusers")
+        accelerator.print("Use xformers by Diffusers")
         set_diffusers_xformers_flag(unet, True)
     else:
         # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
-        print("Disable Diffusers' xformers")
+        accelerator.print("Disable Diffusers' xformers")
         set_diffusers_xformers_flag(unet, False)
         train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
 
@@ -157,7 +155,7 @@ def train(args):
     training_models.append(unet)
 
     if args.train_text_encoder:
-        print("enable text encoder training")
+        accelerator.print("enable text encoder training")
         if args.gradient_checkpointing:
             text_encoder.gradient_checkpointing_enable()
         training_models.append(text_encoder)
@@ -183,7 +181,7 @@ def train(args):
     params_to_optimize = params
 
     # 学習に必要なクラスを準備する
-    print("prepare optimizer, data loader etc.")
+    accelerator.print("prepare optimizer, data loader etc.")
     _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
 
     # dataloaderを準備する
@@ -203,7 +201,7 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(
             len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
         )
-        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
     # データセット側にも学習ステップを送信
     train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -216,7 +214,7 @@ def train(args):
         assert (
             args.mixed_precision == "fp16"
         ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
-        print("enable full fp16 training.")
+        accelerator.print("enable full fp16 training.")
         unet.to(weight_dtype)
         text_encoder.to(weight_dtype)
 
@@ -246,14 +244,14 @@ def train(args):
 
     # 学習する
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-    print("running training / 学習開始")
-    print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
-    print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
-    print(f" num epochs / epoch数: {num_train_epochs}")
-    print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
-    print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
-    print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
-    print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
+    accelerator.print("running training / 学習開始")
+    accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
+    accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+    accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
+    accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
+    accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+    accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+    accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
 
     progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
     global_step = 0
@@ -266,7 +264,7 @@ def train(args):
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)
 
     for epoch in range(num_train_epochs):
-        print(f"\nepoch {epoch+1}/{num_train_epochs}")
+        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
 
         for m in training_models:
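
The recurring change in these hunks swaps bare print() for Accelerator.print() from Hugging Face Accelerate, which emits a message only on the (local) main process, so a multi-GPU run does not repeat every log line once per worker. A minimal sketch of that pattern, assuming the accelerate package is installed; the message strings below are illustrative and not taken from the training script:

from accelerate import Accelerator

accelerator = Accelerator()

# Every process in a distributed run reaches this line, but
# Accelerator.print only writes on the (local) main process,
# so the message appears once in the log.
accelerator.print(f"running on {accelerator.num_processes} process(es)")

# A plain print(), by contrast, is repeated by each launched process.
print(f"hello from process {accelerator.process_index}")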