mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix dataloader
This commit is contained in:
@@ -11,31 +11,36 @@
|
||||
# - Per-block fused optimizer instances
|
||||
|
||||
import argparse
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from multiprocessing import Value
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from library import utils
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.device_utils import clean_memory_on_device, init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
from library import (
|
||||
deepspeed_utils,
|
||||
flux_train_utils,
|
||||
flux_utils,
|
||||
strategy_base,
|
||||
strategy_flux,
|
||||
)
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
from library.utils import add_logging_arguments, setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -46,10 +51,10 @@ import library.config_util as config_util
|
||||
|
||||
# import library.sdxl_train_util as sdxl_train_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
ConfigSanitizer,
|
||||
)
|
||||
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
|
||||
from library.custom_train_functions import add_custom_train_arguments, apply_masked_loss
|
||||
|
||||
|
||||
def train(args):
|
||||
@@ -85,7 +90,6 @@ def train(args):
|
||||
)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
@@ -103,7 +107,7 @@ def train(args):
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
ignored = ["train_data_dir", "conditioing_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
@@ -111,31 +115,17 @@ def train(args):
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
logger.info("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||
args.train_data_dir, args.reg_data_dir
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.info("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
|
||||
args.train_data_dir,
|
||||
args.conditioning_data_dir,
|
||||
args.caption_extension
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
@@ -648,12 +638,12 @@ def train(args):
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
if not args.apply_t5_attn_mask:
|
||||
t5_attn_mask = None
|
||||
|
||||
|
||||
with accelerator.autocast():
|
||||
block_samples, block_single_samples = controlnet(
|
||||
img=packed_noisy_model_input,
|
||||
img_ids=img_ids,
|
||||
controlnet_cond=batch["control_image"].to(accelerator.device),
|
||||
controlnet_img=batch["conditioing_image"].to(accelerator.device),
|
||||
txt=t5_out,
|
||||
txt_ids=txt_ids,
|
||||
y=l_pooled,
|
||||
@@ -856,6 +846,18 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--controlnet_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="controlnet model name or path / controlnetのモデル名またはパス",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--conditioning_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="conditioning data directory / 条件付けデータのディレクトリ",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user