Fix validate epoch, cleanup imports

This commit is contained in:
rockerBOO
2025-01-06 11:34:10 -05:00
parent c64d1a22fc
commit f8850296c8

View File

@@ -3,15 +3,13 @@ import argparse
import math import math
import os import os
import typing import typing
from typing import List, Optional, Union from typing import Any, List
import sys import sys
import random import random
import time import time
import json import json
from multiprocessing import Value from multiprocessing import Value
from typing import Any, List
import toml import toml
import itertools
from tqdm import tqdm from tqdm import tqdm
@@ -23,8 +21,8 @@ init_ipex()
from accelerate.utils import set_seed from accelerate.utils import set_seed
from accelerate import Accelerator from accelerate import Accelerator
from diffusers import DDPMScheduler, AutoencoderKL from diffusers import DDPMScheduler
from diffusers.models.modeling_outputs import AutoencoderKLOutput from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from library import deepspeed_utils, model_util, strategy_base, strategy_sd from library import deepspeed_utils, model_util, strategy_base, strategy_sd
import library.train_util as train_util import library.train_util as train_util
@@ -49,7 +47,6 @@ from library.utils import setup_logging, add_logging_arguments
setup_logging() setup_logging()
import logging import logging
import itertools
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -1442,7 +1439,7 @@ class NetworkTrainer:
should_validate_epoch = ( should_validate_epoch = (
(epoch + 1) % args.validate_every_n_epochs == 0 (epoch + 1) % args.validate_every_n_epochs == 0
if args.validate_every_n_epochs is not None if args.validate_every_n_epochs is not None
else False else True
) )
if should_validate_epoch and len(val_dataloader) > 0: if should_validate_epoch and len(val_dataloader) > 0: