mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
Merge remote-tracking branch 'hina/feature/val-loss' into validation-loss-upstream
Modified implementation for process_batch and cleanup validation recording
This commit is contained in:
@@ -1,40 +1,4 @@
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from multiprocessing import Value
|
||||
from types import SimpleNamespace
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library import deepspeed_utils
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler, ControlNetModel
|
||||
from safetensors.torch import load_file
|
||||
|
||||
import library.model_util as model_util
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
)
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -42,6 +6,7 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
# TODO 他のスクリプトと共通化する
|
||||
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||
logs = {
|
||||
@@ -608,8 +573,16 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
return parser
|
||||
|
||||
=======
|
||||
from library import train_util
|
||||
from train_control_net import setup_parser, train
|
||||
>>>>>>> hina/feature/val-loss
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.warning(
|
||||
"The module 'train_controlnet.py' is deprecated. Please use 'train_control_net.py' instead"
|
||||
" / 'train_controlnet.py'は非推奨です。代わりに'train_control_net.py'を使用してください。"
|
||||
)
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user