fix training starts with debug_dataset

This commit is contained in:
Kohya S
2023-01-07 20:19:25 +09:00
parent 40d1bf3809
commit 2efced0a9a

View File

@@ -3,33 +3,19 @@
import gc
import time
from torch.autograd.function import Function
import argparse
import glob
import itertools
import math
import os
import random
import shutil
from tqdm import tqdm
import torch
from torchvision import transforms
from accelerate import Accelerator
from accelerate.utils import set_seed
from transformers import CLIPTokenizer
import diffusers
from diffusers import DDPMScheduler, StableDiffusionPipeline
import albumentations as albu
import numpy as np
from PIL import Image
import cv2
from einops import rearrange
from torch import einsum
from diffusers import DDPMScheduler
import library.model_util as model_util
import library.train_util as train_util
from library.train_util import DreamBoothDataset, FineTuningDataset
from library.train_util import DreamBoothDataset
def collate_fn(examples):
@@ -52,11 +38,12 @@ def train(args):
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
if args.no_token_padding:
train_dataset.disable_padding()
train_dataset.disable_token_padding()
train_dataset.make_buckets()
if args.debug_dataset:
train_util.debug_dataset(train_dataset)
return
# acceleratorを準備する
print("prepare accelerator")
@@ -311,8 +298,7 @@ def train(args):
accelerator.end_training()
if args.save_state:
print("saving last state.")
accelerator.save_state(os.path.join(args.output_dir, train_util.LAST_STATE_NAME))
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す