Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 22:35:09 +00:00
fix gradient accum not used for lr scheduler
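Judging from the second hunk below, the bug was that the LR schedule was sized in optimizer steps while the script steps the scheduler once per dataloader iteration: under gradient accumulation the optimizer only steps once every `gradient_accumulation_steps` micro-batches, so warmup and decay finished that many times too early. The fix multiplies `num_training_steps` by `args.gradient_accumulation_steps`.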
@@ -1,22 +1,15 @@
-import gc
 import importlib
-import json
-import shutil
-import time
 import argparse
 import gc
 import math
 import os
 
 from tqdm import tqdm
 import torch
 from accelerate.utils import set_seed
-from transformers import CLIPTokenizer
 import diffusers
-from diffusers import DDPMScheduler, StableDiffusionPipeline
-import numpy as np
-import cv2
+from diffusers import DDPMScheduler
 
 import library.model_util as model_util
 import library.train_util as train_util
 from library.train_util import DreamBoothDataset, FineTuningDataset
@@ -139,7 +132,7 @@ def train(args):
 
   # prepare the lr scheduler
   lr_scheduler = diffusers.optimization.get_scheduler(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
+      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
 
   # experimental feature: perform fp16 training including gradients; make the whole model fp16
   if args.full_fp16:
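A minimal sketch of the step accounting behind the fix; the model, optimizer, and concrete numbers here are illustrative assumptions, not code from the repository. With gradient accumulation, `lr_scheduler.step()` runs once per micro-batch, so the schedule has to be sized in scheduler steps rather than optimizer steps:

# Illustrative sketch only; model, optimizer, and step counts are made up.
import torch
from diffusers.optimization import get_scheduler

max_train_steps = 100            # desired optimizer steps
gradient_accumulation_steps = 4  # micro-batches per optimizer step

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# As in the fixed line: size the schedule in scheduler steps.
# (num_warmup_steps stays in scheduler steps too, matching the diff.)
lr_scheduler = get_scheduler(
    "cosine", optimizer, num_warmup_steps=10,
    num_training_steps=max_train_steps * gradient_accumulation_steps)

for micro_step in range(max_train_steps * gradient_accumulation_steps):
    loss = model(torch.randn(2, 4)).mean() / gradient_accumulation_steps
    loss.backward()
    if (micro_step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()       # one optimizer step per accumulation cycle
        optimizer.zero_grad()
    lr_scheduler.step()        # one scheduler step per micro-batch

Without the multiplier, a decaying schedule such as cosine would hit its floor after max_train_steps micro-batches, only a quarter of the way through training in this example.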
@@ -216,17 +209,16 @@ def train(args):
   for epoch in range(num_train_epochs):
     print(f"epoch {epoch+1}/{num_train_epochs}")
 
     # train the Text Encoder up to the specified number of steps: state at the start of the epoch
     network.on_epoch_start(text_encoder, unet)
 
     loss_total = 0
     for step, batch in enumerate(train_dataloader):
       with accelerator.accumulate(network):
         with torch.no_grad():
-          # convert to latents
-          if batch["latents"] is not None:
+          if "latents" in batch and batch["latents"] is not None:
             latents = batch["latents"].to(accelerator.device)
           else:
+            # convert to latents
             latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
           latents = latents * 0.18215
         b_size = latents.shape[0]
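The reworked condition above also covers datasets whose batches carry no "latents" key at all, not just a key set to None. A self-contained sketch of the same pattern; the helper name and signature are hypothetical, while 0.18215 is the Stable Diffusion v1 VAE latent scaling factor applied in the script:

# Hypothetical helper illustrating the guarded latent lookup above.
import torch

def get_latents(batch, vae, weight_dtype, device):
    # use cached latents when the dataset provides them; the `in` check
    # covers datasets that have no "latents" key at all
    if "latents" in batch and batch["latents"] is not None:
        latents = batch["latents"].to(device)
    else:
        # otherwise encode the images with the VAE (a diffusers AutoencoderKL)
        with torch.no_grad():
            latents = vae.encode(batch["images"].to(device, dtype=weight_dtype)).latent_dist.sample()
    return latents * 0.18215  # SD v1 latent scaling factor, as in the script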