mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 01:12:41 +00:00
fix training starts with debug_dataset
This commit is contained in:
24
train_db.py
24
train_db.py
@@ -3,33 +3,19 @@
|
||||
|
||||
import gc
|
||||
import time
|
||||
from torch.autograd.function import Function
|
||||
import argparse
|
||||
import glob
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import set_seed
|
||||
from transformers import CLIPTokenizer
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler, StableDiffusionPipeline
|
||||
import albumentations as albu
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import cv2
|
||||
from einops import rearrange
|
||||
from torch import einsum
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
import library.model_util as model_util
|
||||
import library.train_util as train_util
|
||||
from library.train_util import DreamBoothDataset, FineTuningDataset
|
||||
from library.train_util import DreamBoothDataset
|
||||
|
||||
|
||||
def collate_fn(examples):
|
||||
@@ -52,11 +38,12 @@ def train(args):
|
||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight,
|
||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
|
||||
if args.no_token_padding:
|
||||
train_dataset.disable_padding()
|
||||
train_dataset.disable_token_padding()
|
||||
train_dataset.make_buckets()
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset)
|
||||
return
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
@@ -311,8 +298,7 @@ def train(args):
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state:
|
||||
print("saving last state.")
|
||||
accelerator.save_state(os.path.join(args.output_dir, train_util.LAST_STATE_NAME))
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
|
||||
Reference in New Issue
Block a user