update diffusers to 1.16 | train_db

Author: ddPn08
Date: 2023-06-01 20:15:06 +09:00
Parent: e743ee5d5c
Commit: 1214f35985

--- a/train_db.py
+++ b/train_db.py

@@ -2,18 +2,15 @@
 # XXX dropped option: fine_tune
 
 import gc
-import time
 import argparse
 import itertools
 import math
 import os
-import toml
 from multiprocessing import Value
 
 from tqdm import tqdm
 import torch
 from accelerate.utils import set_seed
-import diffusers
 from diffusers import DDPMScheduler
 
 import library.train_util as train_util
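A hedged reading of the import cleanup above: `time`, `toml`, and the module-level `diffusers` import are dropped because nothing in this file references them directly anymore; the one diffusers symbol in use, DDPMScheduler, still arrives through the named import. A minimal sketch of why the bare import is redundant:

# Minimal sketch: the named import binds the same class object that the
# module-level import exposes, so "import diffusers" adds nothing once no
# other attribute of the diffusers namespace is referenced.
import diffusers
from diffusers import DDPMScheduler

assert diffusers.DDPMScheduler is DDPMScheduler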
@@ -138,7 +135,7 @@ def train(args):
     unet.requires_grad_(True)  # 念のため追加
     text_encoder.requires_grad_(train_text_encoder)
     if not train_text_encoder:
-        print("Text Encoder is not trained.")
+        accelerator.print("Text Encoder is not trained.")
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
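The print-to-accelerator.print swap that starts here recurs through the rest of the commit. Accelerator.print is a standard accelerate API that prints only once per machine instead of once per process, which matters for multi-GPU launches. A minimal sketch, assuming an ordinary accelerate setup (not code from this commit):

# Minimal sketch: under e.g. "accelerate launch --num_processes 4",
# plain print() runs in every process and duplicates each log line;
# Accelerator.print() only prints on the main process of each machine.
from accelerate import Accelerator

accelerator = Accelerator()
print("appears once per process")              # 4 copies on a 4-GPU launch
accelerator.print("appears once per machine")  # a single, deduplicated line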
@@ -150,7 +147,7 @@ def train(args):
     vae.to(accelerator.device, dtype=weight_dtype)
 
     # 学習に必要なクラスを準備する
-    print("prepare optimizer, data loader etc.")
+    accelerator.print("prepare optimizer, data loader etc.")
     if train_text_encoder:
         trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters())
     else:
@@ -175,7 +172,7 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(
             len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
         )
-        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
     # データセット側にも学習ステップを送信
     train_dataset_group.set_max_train_steps(args.max_train_steps)
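For the step override above, the optimizer steps per epoch are the dataloader length divided across processes and gradient-accumulation steps, rounded up. A worked example with hypothetical numbers (none of these values come from the commit):

import math

# Hypothetical values for illustration only.
max_train_epochs = 10
num_batches = 500   # len(train_dataloader)
num_processes = 2   # accelerator.num_processes
grad_accum = 4      # args.gradient_accumulation_steps

# Mirrors the overridden computation: ceil(500 / 2 / 4) = 63 steps per epoch.
max_train_steps = max_train_epochs * math.ceil(num_batches / num_processes / grad_accum)
print(max_train_steps)  # 630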
@@ -191,7 +188,7 @@ def train(args):
         assert (
             args.mixed_precision == "fp16"
         ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
-        print("enable full fp16 training.")
+        accelerator.print("enable full fp16 training.")
         unet.to(weight_dtype)
         text_encoder.to(weight_dtype)
 
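On the full_fp16 branch above: the assertion ties the all-fp16 mode to accelerate's mixed_precision="fp16" setting before the U-Net and text encoder are cast to half precision. A minimal sketch of that guard-then-cast pattern, using a stand-in module rather than the real models:

import torch
import torch.nn as nn

mixed_precision = "fp16"  # stand-in for args.mixed_precision
assert mixed_precision == "fp16", "full_fp16 requires mixed_precision='fp16'"

weight_dtype = torch.float16
model = nn.Linear(4, 4).to(weight_dtype)  # as unet/text_encoder are cast above
print(model.weight.dtype)  # torch.float16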
@@ -224,15 +221,15 @@ def train(args):
     # 学習する
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-    print("running training / 学習開始")
-    print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
-    print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
-    print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
-    print(f"  num epochs / epoch数: {num_train_epochs}")
-    print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
-    print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
-    print(f"  gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
-    print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
+    accelerator.print("running training / 学習開始")
+    accelerator.print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+    accelerator.print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
+    accelerator.print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+    accelerator.print(f"  num epochs / epoch数: {num_train_epochs}")
+    accelerator.print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
+    accelerator.print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+    accelerator.print(f"  gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+    accelerator.print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
 
     progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
     global_step = 0
 
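The total batch size logged above multiplies the per-device batch size by the process count and the gradient-accumulation steps, i.e. the number of samples contributing to one optimizer update. With hypothetical numbers:

# Hypothetical values for illustration only.
train_batch_size = 4             # per-device batch size
num_processes = 2                # accelerator.num_processes
gradient_accumulation_steps = 8  # args.gradient_accumulation_steps

# Samples behind a single optimizer step: 4 * 2 * 8 = 64.
total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
print(total_batch_size)  # 64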
@@ -247,7 +244,7 @@ def train(args):
     loss_list = []
     loss_total = 0.0
     for epoch in range(num_train_epochs):
-        print(f"\nepoch {epoch+1}/{num_train_epochs}")
+        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
 
         # 指定したステップ数までText Encoderを学習するepoch最初の状態
@@ -260,7 +257,7 @@ def train(args):
             current_step.value = global_step
             # 指定したステップ数でText Encoderの学習を止める
             if global_step == args.stop_text_encoder_training:
-                print(f"stop text encoder training at step {global_step}")
+                accelerator.print(f"stop text encoder training at step {global_step}")
                 if not args.gradient_checkpointing:
                     text_encoder.train(False)
                 text_encoder.requires_grad_(False)
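The last hunk freezes the text encoder once global_step reaches stop_text_encoder_training: eval mode plus requires_grad_(False) stops both the stochastic layers and the gradient flow, and the train(False) call is skipped under gradient checkpointing, which needs train mode to recompute activations. A minimal sketch with a stand-in module (not the real CLIP text encoder):

import torch.nn as nn

text_encoder = nn.Linear(8, 8)      # stand-in for the CLIP text encoder
text_encoder.train(False)           # eval mode: disables dropout, etc.
text_encoder.requires_grad_(False)  # no more gradients or weight updates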