From e358b118afbc93f63dbb5ab6d2412ec553ea9cd7 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sat, 16 Nov 2024 14:49:29 +0900 Subject: [PATCH] fix dataloader --- flux_train_control_net.py | 84 ++++++++++++++++++++------------------- library/flux_models.py | 17 ++++---- 2 files changed, 52 insertions(+), 49 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 8a7be75f..ee4d0ebf 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -11,31 +11,36 @@ # - Per-block fused optimizer instances import argparse -from concurrent.futures import ThreadPoolExecutor import copy import math import os -from multiprocessing import Value import time +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Value from typing import List, Optional, Tuple, Union + import toml - -from tqdm import tqdm - import torch import torch.nn as nn +from tqdm import tqdm + from library import utils -from library.device_utils import init_ipex, clean_memory_on_device +from library.device_utils import clean_memory_on_device, init_ipex init_ipex() from accelerate.utils import set_seed -from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux -from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler import library.train_util as train_util - -from library.utils import setup_logging, add_logging_arguments +from library import ( + deepspeed_utils, + flux_train_utils, + flux_utils, + strategy_base, + strategy_flux, +) +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler +from library.utils import add_logging_arguments, setup_logging setup_logging() import logging @@ -46,10 +51,10 @@ import library.config_util as config_util # import library.sdxl_train_util as sdxl_train_util from library.config_util import ( - ConfigSanitizer, BlueprintGenerator, + ConfigSanitizer, ) -from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments +from library.custom_train_functions import add_custom_train_arguments, apply_masked_loss def train(args): @@ -85,7 +90,6 @@ def train(args): ) cache_latents = args.cache_latents - use_dreambooth_method = args.in_json is None if args.seed is not None: set_seed(args.seed) # 乱数系列を初期化する @@ -103,7 +107,7 @@ def train(args): if args.dataset_config is not None: logger.info(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) - ignored = ["train_data_dir", "in_json"] + ignored = ["train_data_dir", "conditioing_data_dir"] if any(getattr(args, attr) is not None for attr in ignored): logger.warning( "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( @@ -111,31 +115,17 @@ def train(args): ) ) else: - if use_dreambooth_method: - logger.info("Using DreamBooth method.") - user_config = { - "datasets": [ - { - "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs( - args.train_data_dir, args.reg_data_dir - ) - } - ] - } - else: - logger.info("Training with captions.") - user_config = { - "datasets": [ - { - "subsets": [ - { - "image_dir": args.train_data_dir, - "metadata_file": args.in_json, - } - ] - } - ] - } + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension + ) + } + ] + } blueprint = blueprint_generator.generate(user_config, args) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) @@ -648,12 +638,12 @@ def train(args): l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds if not args.apply_t5_attn_mask: t5_attn_mask = None - + with accelerator.autocast(): block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, img_ids=img_ids, - controlnet_cond=batch["control_image"].to(accelerator.device), + controlnet_img=batch["conditioing_image"].to(accelerator.device), txt=t5_out, txt_ids=txt_ids, y=l_pooled, @@ -856,6 +846,18 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="controlnet model name or path / controlnetのモデル名またはパス", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index a3bd1974..b52ea6f0 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -2,15 +2,15 @@ # license: Apache-2.0 License -from concurrent.futures import Future, ThreadPoolExecutor -from dataclasses import dataclass import math import os import time +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass from typing import Dict, List, Optional, Union from library import utils -from library.device_utils import init_ipex, clean_memory_on_device +from library.device_utils import clean_memory_on_device, init_ipex init_ipex() @@ -18,6 +18,7 @@ import torch from einops import rearrange from torch import Tensor, nn from torch.utils.checkpoint import checkpoint + from library import custom_offloading_utils # USE_REENTRANT = True @@ -1251,7 +1252,7 @@ class ControlNetFlux(nn.Module): self, img: Tensor, img_ids: Tensor, - controlnet_cond: Tensor, + controlnet_img: Tensor, txt: Tensor, txt_ids: Tensor, timesteps: Tensor, @@ -1264,10 +1265,10 @@ class ControlNetFlux(nn.Module): # running on sequences img img = self.img_in(img) - controlnet_cond = self.input_hint_block(controlnet_cond) - controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - controlnet_cond = self.pos_embed_input(controlnet_cond) - img = img + controlnet_cond + controlnet_img = self.input_hint_block(controlnet_img) + controlnet_img = rearrange(controlnet_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + controlnet_img = self.pos_embed_input(controlnet_img) + img = img + controlnet_img vec = self.time_in(timestep_embedding(timesteps, 256)) if self.params.guidance_embed: if guidance is None: