From f8850296c83ef2091bf1cb0f6e9ba462adfd9045 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 11:34:10 -0500 Subject: [PATCH] Fix validate epoch, cleanup imports --- train_network.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/train_network.py b/train_network.py index f3c8d8c9..11bba71e 100644 --- a/train_network.py +++ b/train_network.py @@ -3,15 +3,13 @@ import argparse import math import os import typing -from typing import List, Optional, Union +from typing import Any, List import sys import random import time import json from multiprocessing import Value -from typing import Any, List import toml -import itertools from tqdm import tqdm @@ -23,8 +21,8 @@ init_ipex() from accelerate.utils import set_seed from accelerate import Accelerator -from diffusers import DDPMScheduler, AutoencoderKL -from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers import DDPMScheduler +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from library import deepspeed_utils, model_util, strategy_base, strategy_sd import library.train_util as train_util @@ -49,7 +47,6 @@ from library.utils import setup_logging, add_logging_arguments setup_logging() import logging -import itertools logger = logging.getLogger(__name__) @@ -1442,7 +1439,7 @@ class NetworkTrainer: should_validate_epoch = ( (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None - else False + else True ) if should_validate_epoch and len(val_dataloader) > 0: