fix dataloader

Author: minux302
Date: 2024-11-16 14:49:29 +09:00
parent 42f6edf3a8
commit e358b118af
2 changed files with 52 additions and 49 deletions


@@ -11,31 +11,36 @@
# - Per-block fused optimizer instances
import argparse
-from concurrent.futures import ThreadPoolExecutor
import copy
import math
import os
-from multiprocessing import Value
import time
+from concurrent.futures import ThreadPoolExecutor
+from multiprocessing import Value
from typing import List, Optional, Tuple, Union
import toml
-from tqdm import tqdm
import torch
import torch.nn as nn
+from tqdm import tqdm
from library import utils
-from library.device_utils import init_ipex, clean_memory_on_device
+from library.device_utils import clean_memory_on_device, init_ipex
init_ipex()
from accelerate.utils import set_seed
-from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux
-from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
import library.train_util as train_util
-from library.utils import setup_logging, add_logging_arguments
+from library import (
+    deepspeed_utils,
+    flux_train_utils,
+    flux_utils,
+    strategy_base,
+    strategy_flux,
+)
+from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
+from library.utils import add_logging_arguments, setup_logging
setup_logging()
import logging
@@ -46,10 +51,10 @@ import library.config_util as config_util
# import library.sdxl_train_util as sdxl_train_util
from library.config_util import (
-ConfigSanitizer,
BlueprintGenerator,
+ConfigSanitizer,
)
-from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
+from library.custom_train_functions import add_custom_train_arguments, apply_masked_loss
def train(args):
@@ -85,7 +90,6 @@ def train(args):
)
cache_latents = args.cache_latents
-use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
@@ -103,7 +107,7 @@ def train(args):
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
ignored = ["train_data_dir", "conditioing_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
@@ -111,31 +115,17 @@ def train(args):
)
)
else:
-if use_dreambooth_method:
-    logger.info("Using DreamBooth method.")
-    user_config = {
-        "datasets": [
-            {
-                "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
-                    args.train_data_dir, args.reg_data_dir
-                )
-            }
-        ]
-    }
-else:
-    logger.info("Training with captions.")
-    user_config = {
-        "datasets": [
-            {
-                "subsets": [
-                    {
-                        "image_dir": args.train_data_dir,
-                        "metadata_file": args.in_json,
-                    }
-                ]
-            }
-        ]
-    }
+user_config = {
+    "datasets": [
+        {
+            "subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
+                args.train_data_dir,
+                args.conditioning_data_dir,
+                args.caption_extension
+            )
+        }
+    ]
+}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
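
For reference, when no dataset_config TOML is supplied, this branch builds the ControlNet dataset config from --train_data_dir, --conditioning_data_dir, and --caption_extension. A rough, self-contained sketch of the resulting user_config shape follows; the subset key names ("image_dir", "conditioning_data_dir", "caption_extension") are assumptions, not values read from config_util.generate_controlnet_subsets_config_by_subdirs.

# Illustrative sketch only: approximates the user_config built when no dataset_config
# TOML is given. The key names below are assumptions, not taken from config_util.
import json

def build_controlnet_user_config(train_data_dir, conditioning_data_dir, caption_extension=".txt"):
    subset = {
        "image_dir": train_data_dir,                     # training images
        "conditioning_data_dir": conditioning_data_dir,  # paired conditioning images
        "caption_extension": caption_extension,
    }
    return {"datasets": [{"subsets": [subset]}]}

print(json.dumps(build_controlnet_user_config("./train", "./cond"), indent=2))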
@@ -648,12 +638,12 @@ def train(args):
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None
with accelerator.autocast():
block_samples, block_single_samples = controlnet(
img=packed_noisy_model_input,
img_ids=img_ids,
-controlnet_cond=batch["control_image"].to(accelerator.device),
+controlnet_img=batch["conditioing_image"].to(accelerator.device),
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
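
For reference, the conditioning image now comes straight from the dataloader batch and is moved to the accelerator device inside the autocast block before the ControlNet forward pass. A minimal stand-alone sketch of that device move follows; the dummy batch, the key spelling (copied from the diff above), and the tensor shape are illustrative assumptions.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Dummy stand-in for one dataloader item; shape (batch, channels, height, width) is assumed.
batch = {"conditioing_image": torch.rand(1, 3, 512, 512)}  # key spelled as in the diff above

controlnet_img = batch["conditioing_image"].to(device)
print(controlnet_img.device, controlnet_img.shape)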
@@ -856,6 +846,18 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
)
+parser.add_argument(
+    "--controlnet_model_name_or_path",
+    type=str,
+    default=None,
+    help="controlnet model name or path / controlnetのモデル名またはパス",
+)
+parser.add_argument(
+    "--conditioning_data_dir",
+    type=str,
+    default=None,
+    help="conditioning data directory / 条件付けデータのディレクトリ",
+)
return parser
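
For reference, the two new options can be exercised on a throwaway parser as in the sketch below; the sample paths are placeholders, not values from this commit.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--controlnet_model_name_or_path", type=str, default=None)
parser.add_argument("--conditioning_data_dir", type=str, default=None)

args = parser.parse_args([
    "--controlnet_model_name_or_path", "./flux_controlnet.safetensors",  # placeholder path
    "--conditioning_data_dir", "./cond_images",                          # placeholder path
])
print(args.controlnet_model_name_or_path, args.conditioning_data_dir)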