mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix gradient accum not used for lr scheduler
This commit is contained in:
@@ -1,22 +1,15 @@
|
|||||||
import gc
|
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
|
||||||
import shutil
|
|
||||||
import time
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import gc
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from transformers import CLIPTokenizer
|
|
||||||
import diffusers
|
import diffusers
|
||||||
from diffusers import DDPMScheduler, StableDiffusionPipeline
|
from diffusers import DDPMScheduler
|
||||||
import numpy as np
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
import library.model_util as model_util
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
from library.train_util import DreamBoothDataset, FineTuningDataset
|
from library.train_util import DreamBoothDataset, FineTuningDataset
|
||||||
|
|
||||||
@@ -139,7 +132,7 @@ def train(args):
|
|||||||
|
|
||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
|
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
@@ -216,17 +209,16 @@ def train(args):
|
|||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
|
|
||||||
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
|
||||||
network.on_epoch_start(text_encoder, unet)
|
network.on_epoch_start(text_encoder, unet)
|
||||||
|
|
||||||
loss_total = 0
|
loss_total = 0
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
with accelerator.accumulate(network):
|
with accelerator.accumulate(network):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# latentに変換
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
if batch["latents"] is not None:
|
|
||||||
latents = batch["latents"].to(accelerator.device)
|
latents = batch["latents"].to(accelerator.device)
|
||||||
else:
|
else:
|
||||||
|
# latentに変換
|
||||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||||
latents = latents * 0.18215
|
latents = latents * 0.18215
|
||||||
b_size = latents.shape[0]
|
b_size = latents.shape[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user