mirror of https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00

update diffusers to 1.16 | train_db

train_db.py: 33 lines changed
@@ -2,18 +2,15 @@
 # XXX dropped option: fine_tune
 
 import gc
-import time
 import argparse
 import itertools
 import math
 import os
-import toml
 from multiprocessing import Value
 
 from tqdm import tqdm
 import torch
 from accelerate.utils import set_seed
-import diffusers
 from diffusers import DDPMScheduler
 
 import library.train_util as train_util
@@ -138,7 +135,7 @@ def train(args):
     unet.requires_grad_(True)  # added just to be safe
     text_encoder.requires_grad_(train_text_encoder)
     if not train_text_encoder:
-        print("Text Encoder is not trained.")
+        accelerator.print("Text Encoder is not trained.")
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
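This is the pattern for the rest of the commit: every status `print` becomes `accelerator.print`. Under `accelerate`, a multi-GPU launch runs one Python process per device, so a bare `print` repeats once per process; `Accelerator.print` is a drop-in replacement that only prints on the local main process. A minimal standalone sketch (not part of this commit):

    from accelerate import Accelerator

    accelerator = Accelerator()
    print("hello")              # repeated once per process in a multi-GPU launch
    accelerator.print("hello")  # printed only by the local main process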
@@ -150,7 +147,7 @@ def train(args):
     vae.to(accelerator.device, dtype=weight_dtype)
 
     # prepare the classes needed for training
-    print("prepare optimizer, data loader etc.")
+    accelerator.print("prepare optimizer, data loader etc.")
    if train_text_encoder:
         trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters())
     else:
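When the text encoder is trained, `itertools.chain` merges both models' parameter iterators so a single optimizer updates them together. A minimal sketch with hypothetical stand-in modules:

    import itertools
    import torch

    unet = torch.nn.Linear(4, 4)          # stand-in for the real UNet
    text_encoder = torch.nn.Linear(4, 4)  # stand-in for the real text encoder
    trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters())
    optimizer = torch.optim.AdamW(trainable_params, lr=1e-5)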
@@ -175,7 +172,7 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(
             len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
         )
-        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
     # also send the number of training steps to the dataset side
     train_dataset_group.set_max_train_steps(args.max_train_steps)
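The override computes optimizer steps per epoch as ceil(batches / processes / accumulation steps): each process consumes its own share of the batches, and an optimizer step only happens after accumulation. With hypothetical numbers:

    import math

    num_batches = 1000     # len(train_dataloader), hypothetical
    num_processes = 2      # parallel GPUs, hypothetical
    grad_accum_steps = 4   # gradient_accumulation_steps, hypothetical
    max_train_epochs = 10

    steps_per_epoch = math.ceil(num_batches / num_processes / grad_accum_steps)  # 125
    max_train_steps = max_train_epochs * steps_per_epoch                         # 1250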
@@ -191,7 +188,7 @@ def train(args):
         assert (
             args.mixed_precision == "fp16"
         ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
-        print("enable full fp16 training.")
+        accelerator.print("enable full fp16 training.")
         unet.to(weight_dtype)
         text_encoder.to(weight_dtype)
 
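Note that full_fp16 is stronger than ordinary mixed precision: the model weights themselves are cast to fp16 rather than only the forward computation, halving weight memory at some cost in numerical headroom. A minimal sketch with a hypothetical stand-in module, assuming weight_dtype was resolved to fp16:

    import torch

    weight_dtype = torch.float16           # what full_fp16 resolves weight_dtype to
    unet = torch.nn.Linear(4, 4)           # stand-in for the real UNet
    unet.to(weight_dtype)
    print(next(unet.parameters()).dtype)   # torch.float16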
@@ -224,15 +221,15 @@ def train(args):
 
     # start training
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-    print("running training / 学習開始")
-    print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
-    print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
-    print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
-    print(f" num epochs / epoch数: {num_train_epochs}")
-    print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
-    print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
-    print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
-    print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
+    accelerator.print("running training / 学習開始")
+    accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+    accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
+    accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+    accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
+    accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
+    accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+    accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+    accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
 
     progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
     global_step = 0
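Two details worth noting here. The tqdm bar was already restricted to the local main process via disable=not accelerator.is_local_main_process, which is why only the print calls needed changing. And the effective batch size is the product of the per-device batch size, the process count, and the accumulation steps; with the same hypothetical numbers as above:

    train_batch_size = 4   # per device, hypothetical
    num_processes = 2      # hypothetical
    grad_accum_steps = 4   # hypothetical
    total_batch_size = train_batch_size * num_processes * grad_accum_steps  # 32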
@@ -247,7 +244,7 @@ def train(args):
     loss_list = []
     loss_total = 0.0
     for epoch in range(num_train_epochs):
-        print(f"\nepoch {epoch+1}/{num_train_epochs}")
+        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
 
         # train the Text Encoder up to the specified number of steps: state at the start of the epoch
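current_epoch (and current_step below) are multiprocessing.Value objects, matching the `from multiprocessing import Value` import kept in the first hunk; writing to them from the training loop presumably lets dataset code running in DataLoader worker processes observe training progress. A minimal sketch:

    from multiprocessing import Value

    current_epoch = Value("i", 0)  # shared 32-bit int, visible to child processes
    current_epoch.value = 1
    print(current_epoch.value)     # 1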
@@ -260,7 +257,7 @@ def train(args):
             current_step.value = global_step
             # stop training the Text Encoder at the specified step
             if global_step == args.stop_text_encoder_training:
-                print(f"stop text encoder training at step {global_step}")
+                accelerator.print(f"stop text encoder training at step {global_step}")
                 if not args.gradient_checkpointing:
                     text_encoder.train(False)
                 text_encoder.requires_grad_(False)
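A final note on the two switches used to stop text-encoder training: `requires_grad_(False)` is what actually freezes the parameters (no gradients, no optimizer updates), while `train(False)` only toggles train/eval behavior such as dropout, and is skipped under gradient checkpointing, presumably because checkpointed modules are expected to stay in training mode. A standalone sketch of the distinction, using a hypothetical stand-in module:

    import torch

    text_encoder = torch.nn.Linear(8, 8)  # stand-in for the real text encoder
    text_encoder.train(False)             # eval mode: affects dropout/batchnorm only
    text_encoder.requires_grad_(False)    # freeze: no grads computed for parameters
    print(any(p.requires_grad for p in text_encoder.parameters()))  # False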