diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py
index f009b577..aaf77b8d 100644
--- a/library/sdxl_train_util.py
+++ b/library/sdxl_train_util.py
@@ -12,7 +12,6 @@ from accelerate import init_empty_weights
 from tqdm import tqdm
 from transformers import CLIPTokenizer
 from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
-from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
 from .utils import setup_logging
 
 setup_logging()
@@ -378,4 +377,6 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
 
 
 def sample_images(*args, **kwargs):
+    from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
+
     return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
diff --git a/library/train_util.py b/library/train_util.py
index b559616f..07c253a0 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -31,6 +31,7 @@ import hashlib
 import subprocess
 from io import BytesIO
 import toml
+# from concurrent.futures import ThreadPoolExecutor, as_completed
 
 from tqdm import tqdm
 
@@ -912,6 +913,23 @@ class BaseDataset(torch.utils.data.Dataset):
             if info.image_size is None:
                 info.image_size = self.get_image_size(info.absolute_path)
 
+        # # run in parallel
+        # max_workers = min(os.cpu_count(), len(self.image_data))  # TODO consider multi-gpu (processes)
+        # with ThreadPoolExecutor(max_workers) as executor:
+        #     futures = []
+        #     for info in tqdm(self.image_data.values(), desc="loading image sizes"):
+        #         if info.image_size is None:
+        #             def get_and_set_image_size(info):
+        #                 info.image_size = self.get_image_size(info.absolute_path)
+        #             futures.append(executor.submit(get_and_set_image_size, info))
+        #         # consume futures to reduce memory usage and prevent Ctrl-C hang
+        #         if len(futures) >= max_workers:
+        #             for future in futures:
+        #                 future.result()
+        #             futures = []
+        #     for future in futures:
+        #         future.result()
+
         if self.enable_bucket:
             logger.info("make buckets")
         else:
@@ -1826,7 +1844,7 @@ class DreamBoothDataset(BaseDataset):
         # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
         captions = []
         missing_captions = []
-        for img_path in img_paths:
+        for img_path in tqdm(img_paths, desc="read caption"):
             cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
             if cap_for_img is None and subset.class_tokens is None:
                 logger.warning(
diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py
index 583a27dc..b902cda6 100644
--- a/sdxl_train_control_net.py
+++ b/sdxl_train_control_net.py
@@ -253,11 +253,20 @@ def train(args):
     # 学習に必要なクラスを準備する
     accelerator.print("prepare optimizer, data loader etc.")
 
-    trainable_params = list(control_net.parameters())
-    # for p in trainable_params:
-    #     p.requires_grad = True
-    logger.info(f"trainable params count: {len(trainable_params)}")
-    logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
+    trainable_params = []
+    ctrlnet_params = []
+    unet_params = []
+    for name, param in control_net.named_parameters():
+        if name.startswith("controlnet_"):
+            ctrlnet_params.append(param)
+        else:
+            unet_params.append(param)
+    trainable_params.append({"params": ctrlnet_params, "lr": args.control_net_lr})
+    trainable_params.append({"params": unet_params, "lr": args.learning_rate})
+    all_params = ctrlnet_params + unet_params
+
+    logger.info(f"trainable params count: {len(all_params)}")
+    logger.info(f"number of trainable parameters: {sum(p.numel() for p in all_params)}")
 
     _, _, optimizer = train_util.get_optimizer(args, trainable_params)
 
@@ -456,6 +465,8 @@ def train(args):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
         current_epoch.value = epoch + 1
 
+        control_net.train()
+
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
             with accelerator.accumulate(control_net):
@@ -510,6 +521,9 @@ def train(args):
                 controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
 
+                # '-1 to +1' to '0 to 1'
+                controlnet_image = (controlnet_image + 1) / 2
+
                 with accelerator.autocast():
                     input_resi_add, mid_add = control_net(
                         noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image
                     )
@@ -690,6 +704,12 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
     )
+    parser.add_argument(
+        "--control_net_lr",
+        type=float,
+        default=1e-4,
+        help="learning rate for controlnet / controlnetの学習率",
+    )
 
     return parser