Merge remote-tracking branch 'hina/feature/val-loss' into validation-loss-upstream

Modified implementation for process_batch and cleanup validation
recording
This commit is contained in:
rockerBOO
2025-01-03 00:48:08 -05:00
85 changed files with 23666 additions and 1552 deletions

View File

@@ -1,40 +1,4 @@
import argparse
import json
import math
import os
import random
import time
from multiprocessing import Value
from types import SimpleNamespace
import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
import library.model_util as model_util
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
pyramid_noise_like,
apply_noise_offset,
)
from library.utils import setup_logging, add_logging_arguments
from library.utils import setup_logging
setup_logging()
import logging
@@ -42,6 +6,7 @@ import logging
logger = logging.getLogger(__name__)
<<<<<<< HEAD
# TODO 他のスクリプトと共通化する
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {
@@ -608,8 +573,16 @@ def setup_parser() -> argparse.ArgumentParser:
return parser
=======
from library import train_util
from train_control_net import setup_parser, train
>>>>>>> hina/feature/val-loss
if __name__ == "__main__":
logger.warning(
"The module 'train_controlnet.py' is deprecated. Please use 'train_control_net.py' instead"
" / 'train_controlnet.py'は非推奨です。代わりに'train_control_net.py'を使用してください。"
)
parser = setup_parser()
args = parser.parse_args()