fix gradient accumulation not used for lr scheduler

Kohya S
2023-01-09 13:13:37 +09:00
parent 6b62c44022
commit fbaf373c8a


@@ -1,22 +1,15 @@
-import gc
 import importlib
-import json
-import shutil
-import time
 import argparse
+import gc
 import math
 import os
 from tqdm import tqdm
 import torch
 from accelerate.utils import set_seed
-from transformers import CLIPTokenizer
 import diffusers
-from diffusers import DDPMScheduler, StableDiffusionPipeline
+from diffusers import DDPMScheduler
-import numpy as np
-import cv2
 
-import library.model_util as model_util
 import library.train_util as train_util
 from library.train_util import DreamBoothDataset, FineTuningDataset
@@ -139,7 +132,7 @@ def train(args):
 
   # prepare the lr scheduler
   lr_scheduler = diffusers.optimization.get_scheduler(
-    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
+    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
 
   # experimental feature: perform fp16 training including gradients (cast the whole model to fp16)
   if args.full_fp16:
@@ -216,17 +209,16 @@ def train(args):
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
 
-    # train the Text Encoder up to the specified number of steps (state at the start of the epoch)
     network.on_epoch_start(text_encoder, unet)
 
     loss_total = 0
     for step, batch in enumerate(train_dataloader):
       with accelerator.accumulate(network):
         with torch.no_grad():
-          # convert to latents
-          if batch["latents"] is not None:
+          if "latents" in batch and batch["latents"] is not None:
             latents = batch["latents"].to(accelerator.device)
           else:
+            # convert to latents
             latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
           latents = latents * 0.18215
         b_size = latents.shape[0]
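
For context on why the scheduler call is scaled: the training loop steps the lr scheduler on every dataloader iteration inside accelerator.accumulate(), while args.max_train_steps counts optimizer updates only, so without the multiplier the schedule would be exhausted after only 1/gradient_accumulation_steps of the run. Below is a minimal self-contained sketch of the corrected call; the concrete numbers and the "cosine" schedule are illustrative assumptions, not values from the commit.

import torch
import diffusers

# illustrative values, not taken from the commit
max_train_steps = 1600              # number of optimizer updates
gradient_accumulation_steps = 4     # micro-batches per optimizer update

# dummy parameter and optimizer so get_scheduler has something to drive
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)

lr_scheduler = diffusers.optimization.get_scheduler(
    "cosine", optimizer,
    num_warmup_steps=0,
    # the scheduler is stepped once per dataloader iteration, so it must span
    # max_train_steps * gradient_accumulation_steps steps, as in the fix above
    num_training_steps=max_train_steps * gradient_accumulation_steps)

# the learning rate now reaches its final value only at the end of the full run
for _ in range(max_train_steps * gradient_accumulation_steps):
    lr_scheduler.step()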