mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix validate epoch, cleanup imports
This commit is contained in:
@@ -3,15 +3,13 @@ import argparse
|
||||
import math
|
||||
import os
|
||||
import typing
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, List
|
||||
import sys
|
||||
import random
|
||||
import time
|
||||
import json
|
||||
from multiprocessing import Value
|
||||
from typing import Any, List
|
||||
import toml
|
||||
import itertools
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -23,8 +21,8 @@ init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
from accelerate import Accelerator
|
||||
from diffusers import DDPMScheduler, AutoencoderKL
|
||||
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
||||
from diffusers import DDPMScheduler
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
|
||||
|
||||
import library.train_util as train_util
|
||||
@@ -49,7 +47,6 @@ from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
import itertools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1442,7 +1439,7 @@ class NetworkTrainer:
|
||||
should_validate_epoch = (
|
||||
(epoch + 1) % args.validate_every_n_epochs == 0
|
||||
if args.validate_every_n_epochs is not None
|
||||
else False
|
||||
else True
|
||||
)
|
||||
|
||||
if should_validate_epoch and len(val_dataloader) > 0:
|
||||
|
||||
Reference in New Issue
Block a user