From fbaf373c8a88af398836d84937819a48313ad9f3 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Mon, 9 Jan 2023 13:13:37 +0900
Subject: [PATCH] fix gradient accum not used for lr scheduler

---
 train_network.py | 18 +++++-------------
 1 file changed, 5 insertions(+), 13 deletions(-)

diff --git a/train_network.py b/train_network.py
index e557b1de..9f292b97 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1,22 +1,15 @@
-import gc
 import importlib
-import json
-import shutil
-import time
 import argparse
+import gc
 import math
 import os
 
 from tqdm import tqdm
 import torch
 from accelerate.utils import set_seed
-from transformers import CLIPTokenizer
 import diffusers
-from diffusers import DDPMScheduler, StableDiffusionPipeline
-import numpy as np
-import cv2
+from diffusers import DDPMScheduler
 
-import library.model_util as model_util
 import library.train_util as train_util
 from library.train_util import DreamBoothDataset, FineTuningDataset
 
@@ -139,7 +132,7 @@ def train(args):
 
   # prepare lr scheduler
   lr_scheduler = diffusers.optimization.get_scheduler(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
+      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
 
   # experimental feature: fp16 training including gradients; cast the whole model to fp16
   if args.full_fp16:
@@ -216,17 +209,16 @@ def train(args):
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
 
-    # train the Text Encoder only up to the specified number of steps: state at the start of the epoch
     network.on_epoch_start(text_encoder, unet)
 
     loss_total = 0
     for step, batch in enumerate(train_dataloader):
       with accelerator.accumulate(network):
         with torch.no_grad():
-          # convert images to latents
-          if batch["latents"] is not None:
+          if "latents" in batch and batch["latents"] is not None:
             latents = batch["latents"].to(accelerator.device)
           else:
+            # convert images to latents
             latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
             latents = latents * 0.18215
           b_size = latents.shape[0]
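
Why multiplying by args.gradient_accumulation_steps is needed: in this training loop lr_scheduler.step() runs once per micro-batch, including the iterations where accelerator.accumulate() skips the real optimizer update, so the scheduler ticks gradient_accumulation_steps times for every optimizer step. Built with num_training_steps=args.max_train_steps, a decaying schedule would bottom out after only 1/gradient_accumulation_steps of the run. Below is a minimal standalone sketch of the effect; it is not part of the patch, the constants, the "cosine" schedule name, and the dummy AdamW optimizer are made up for illustration, and only diffusers.optimization.get_scheduler is the API the patch actually calls.

import torch
import diffusers

max_train_steps = 100            # desired number of optimizer updates (made up)
gradient_accumulation_steps = 4  # micro-batches per optimizer update (made up)

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-4)

# With the fix, the schedule spans every scheduler tick rather than only the
# optimizer updates, so it reaches its end exactly when training ends.
lr_scheduler = diffusers.optimization.get_scheduler(
    "cosine", optimizer, num_warmup_steps=0,
    num_training_steps=max_train_steps * gradient_accumulation_steps)

for micro_step in range(max_train_steps * gradient_accumulation_steps):
    # ... forward/backward on one micro-batch would go here ...
    if (micro_step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()       # a real parameter update, once per window
        optimizer.zero_grad()
    lr_scheduler.step()        # ticks every micro-batch, as in train_network.py

print(lr_scheduler.get_last_lr())  # ~[0.0]: the schedule ends with training

Without the fix, the same print after max_train_steps scheduler ticks would already show a learning rate near zero, a quarter of the way through this hypothetical run.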