mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix dataloader
This commit is contained in:
@@ -11,31 +11,36 @@
|
|||||||
# - Per-block fused optimizer instances
|
# - Per-block fused optimizer instances
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from multiprocessing import Value
|
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from multiprocessing import Value
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import toml
|
import toml
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from library import utils
|
from library import utils
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import clean_memory_on_device, init_ipex
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux
|
|
||||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
from library import (
|
||||||
from library.utils import setup_logging, add_logging_arguments
|
deepspeed_utils,
|
||||||
|
flux_train_utils,
|
||||||
|
flux_utils,
|
||||||
|
strategy_base,
|
||||||
|
strategy_flux,
|
||||||
|
)
|
||||||
|
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||||
|
from library.utils import add_logging_arguments, setup_logging
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
@@ -46,10 +51,10 @@ import library.config_util as config_util
|
|||||||
|
|
||||||
# import library.sdxl_train_util as sdxl_train_util
|
# import library.sdxl_train_util as sdxl_train_util
|
||||||
from library.config_util import (
|
from library.config_util import (
|
||||||
ConfigSanitizer,
|
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
|
ConfigSanitizer,
|
||||||
)
|
)
|
||||||
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
|
from library.custom_train_functions import add_custom_train_arguments, apply_masked_loss
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
@@ -85,7 +90,6 @@ def train(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
cache_latents = args.cache_latents
|
cache_latents = args.cache_latents
|
||||||
use_dreambooth_method = args.in_json is None
|
|
||||||
|
|
||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
set_seed(args.seed) # 乱数系列を初期化する
|
set_seed(args.seed) # 乱数系列を初期化する
|
||||||
@@ -103,7 +107,7 @@ def train(args):
|
|||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
ignored = ["train_data_dir", "in_json"]
|
ignored = ["train_data_dir", "conditioning_data_dir"]
|
||||||
if any(getattr(args, attr) is not None for attr in ignored):
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
@@ -111,31 +115,17 @@ def train(args):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if use_dreambooth_method:
|
user_config = {
|
||||||
logger.info("Using DreamBooth method.")
|
"datasets": [
|
||||||
user_config = {
|
{
|
||||||
"datasets": [
|
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
|
||||||
{
|
args.train_data_dir,
|
||||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
args.conditioning_data_dir,
|
||||||
args.train_data_dir, args.reg_data_dir
|
args.caption_extension
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
else:
|
|
||||||
logger.info("Training with captions.")
|
|
||||||
user_config = {
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"subsets": [
|
|
||||||
{
|
|
||||||
"image_dir": args.train_data_dir,
|
|
||||||
"metadata_file": args.in_json,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
@@ -653,7 +643,7 @@ def train(args):
|
|||||||
block_samples, block_single_samples = controlnet(
|
block_samples, block_single_samples = controlnet(
|
||||||
img=packed_noisy_model_input,
|
img=packed_noisy_model_input,
|
||||||
img_ids=img_ids,
|
img_ids=img_ids,
|
||||||
controlnet_cond=batch["control_image"].to(accelerator.device),
|
controlnet_img=batch["conditioing_image"].to(accelerator.device),
|
||||||
txt=t5_out,
|
txt=t5_out,
|
||||||
txt_ids=txt_ids,
|
txt_ids=txt_ids,
|
||||||
y=l_pooled,
|
y=l_pooled,
|
||||||
@@ -856,6 +846,18 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
|
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--controlnet_model_name_or_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="controlnet model name or path / controlnetのモデル名またはパス",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--conditioning_data_dir",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="conditioning data directory / 条件付けデータのディレクトリ",
|
||||||
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,15 +2,15 @@
|
|||||||
# license: Apache-2.0 License
|
# license: Apache-2.0 License
|
||||||
|
|
||||||
|
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
|
||||||
from dataclasses import dataclass
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from library import utils
|
from library import utils
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import clean_memory_on_device, init_ipex
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
@@ -18,6 +18,7 @@ import torch
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
from library import custom_offloading_utils
|
from library import custom_offloading_utils
|
||||||
|
|
||||||
# USE_REENTRANT = True
|
# USE_REENTRANT = True
|
||||||
@@ -1251,7 +1252,7 @@ class ControlNetFlux(nn.Module):
|
|||||||
self,
|
self,
|
||||||
img: Tensor,
|
img: Tensor,
|
||||||
img_ids: Tensor,
|
img_ids: Tensor,
|
||||||
controlnet_cond: Tensor,
|
controlnet_img: Tensor,
|
||||||
txt: Tensor,
|
txt: Tensor,
|
||||||
txt_ids: Tensor,
|
txt_ids: Tensor,
|
||||||
timesteps: Tensor,
|
timesteps: Tensor,
|
||||||
@@ -1264,10 +1265,10 @@ class ControlNetFlux(nn.Module):
|
|||||||
|
|
||||||
# running on sequences img
|
# running on sequences img
|
||||||
img = self.img_in(img)
|
img = self.img_in(img)
|
||||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
controlnet_img = self.input_hint_block(controlnet_img)
|
||||||
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
controlnet_img = rearrange(controlnet_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||||
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
controlnet_img = self.pos_embed_input(controlnet_img)
|
||||||
img = img + controlnet_cond
|
img = img + controlnet_img
|
||||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||||
if self.params.guidance_embed:
|
if self.params.guidance_embed:
|
||||||
if guidance is None:
|
if guidance is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user