From 1214f35985b233788d835b7380b4c315f18d59bb Mon Sep 17 00:00:00 2001
From: ddPn08
Date: Thu, 1 Jun 2023 20:15:06 +0900
Subject: [PATCH] update diffusers to 1.16 | train_db

---
 train_db.py | 33 +++++++++++++++------------------
 1 file changed, 15 insertions(+), 18 deletions(-)

diff --git a/train_db.py b/train_db.py
index 7ec06354..2fd1c7c5 100644
--- a/train_db.py
+++ b/train_db.py
@@ -2,18 +2,15 @@
 # XXX dropped option: fine_tune
 
 import gc
-import time
 import argparse
 import itertools
 import math
 import os
-import toml
 from multiprocessing import Value
 
 from tqdm import tqdm
 import torch
 from accelerate.utils import set_seed
-import diffusers
 from diffusers import DDPMScheduler
 
 import library.train_util as train_util
@@ -138,7 +135,7 @@ def train(args):
     unet.requires_grad_(True)  # 念のため追加
     text_encoder.requires_grad_(train_text_encoder)
     if not train_text_encoder:
-        print("Text Encoder is not trained.")
+        accelerator.print("Text Encoder is not trained.")
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
@@ -150,7 +147,7 @@ def train(args):
         vae.to(accelerator.device, dtype=weight_dtype)
 
     # 学習に必要なクラスを準備する
-    print("prepare optimizer, data loader etc.")
+    accelerator.print("prepare optimizer, data loader etc.")
     if train_text_encoder:
         trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters())
     else:
@@ -175,7 +172,7 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(
             len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
         )
-        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
     # データセット側にも学習ステップを送信
     train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -191,7 +188,7 @@ def train(args):
         assert (
             args.mixed_precision == "fp16"
         ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
-        print("enable full fp16 training.")
+        accelerator.print("enable full fp16 training.")
         unet.to(weight_dtype)
         text_encoder.to(weight_dtype)
 
@@ -224,15 +221,15 @@ def train(args):
 
     # 学習する
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-    print("running training / 学習開始")
-    print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
-    print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
-    print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
-    print(f"  num epochs / epoch数: {num_train_epochs}")
-    print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
-    print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
-    print(f"  gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
-    print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
+    accelerator.print("running training / 学習開始")
+    accelerator.print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+    accelerator.print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
+    accelerator.print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+    accelerator.print(f"  num epochs / epoch数: {num_train_epochs}")
+    accelerator.print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
+    accelerator.print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+    accelerator.print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+    accelerator.print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
 
     progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
     global_step = 0
@@ -247,7 +244,7 @@ def train(args):
     loss_list = []
     loss_total = 0.0
     for epoch in range(num_train_epochs):
-        print(f"\nepoch {epoch+1}/{num_train_epochs}")
+        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
 
         # 指定したステップ数までText Encoderを学習する:epoch最初の状態
@@ -260,7 +257,7 @@ def train(args):
             current_step.value = global_step
             # 指定したステップ数でText Encoderの学習を止める
             if global_step == args.stop_text_encoder_training:
-                print(f"stop text encoder training at step {global_step}")
+                accelerator.print(f"stop text encoder training at step {global_step}")
                 if not args.gradient_checkpointing:
                     text_encoder.train(False)
                 text_encoder.requires_grad_(False)
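
Context for the change: aside from the dropped imports, every hunk is the same mechanical substitution of `print` with `accelerator.print`. In Accelerate, `Accelerator.print` forwards to the built-in `print` only on the main process, so each status line is emitted once rather than once per GPU when the script runs under a multi-process `accelerate launch`. A minimal standalone sketch of the difference (assumes an Accelerate install; the messages are illustrative and not part of train_db.py):

    from accelerate import Accelerator

    accelerator = Accelerator()

    # The built-in print() runs in every process, so under
    # `accelerate launch --num_processes 4` this line appears four times.
    print(f"plain print from process {accelerator.process_index}")

    # accelerator.print() forwards to print() only on the main process,
    # so this line appears exactly once regardless of process count.
    accelerator.print("running training / 学習開始")

The removed imports (`time`, `toml`, and the bare `diffusers` module) appear to be unused by the module after this update; `DDPMScheduler` is the only name the script still needs from diffusers.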