fix cond image normlization, add independent LR for control

This commit is contained in:
Kohya S
2024-10-03 21:32:21 +09:00
parent 793999d116
commit c2440f9e53
3 changed files with 46 additions and 7 deletions

View File

@@ -12,7 +12,6 @@ from accelerate import init_empty_weights
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
from .utils import setup_logging
setup_logging()
@@ -378,4 +377,6 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
def sample_images(*args, **kwargs):
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)

View File

@@ -31,6 +31,7 @@ import hashlib
import subprocess
from io import BytesIO
import toml
# from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
@@ -912,6 +913,23 @@ class BaseDataset(torch.utils.data.Dataset):
if info.image_size is None:
info.image_size = self.get_image_size(info.absolute_path)
# # run in parallel
# max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes)
# with ThreadPoolExecutor(max_workers) as executor:
# futures = []
# for info in tqdm(self.image_data.values(), desc="loading image sizes"):
# if info.image_size is None:
# def get_and_set_image_size(info):
# info.image_size = self.get_image_size(info.absolute_path)
# futures.append(executor.submit(get_and_set_image_size, info))
# # consume futures to reduce memory usage and prevent Ctrl-C hang
# if len(futures) >= max_workers:
# for future in futures:
# future.result()
# futures = []
# for future in futures:
# future.result()
if self.enable_bucket:
logger.info("make buckets")
else:
@@ -1826,7 +1844,7 @@ class DreamBoothDataset(BaseDataset):
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
missing_captions = []
for img_path in img_paths:
for img_path in tqdm(img_paths, desc="read caption"):
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
if cap_for_img is None and subset.class_tokens is None:
logger.warning(

View File

@@ -253,11 +253,20 @@ def train(args):
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(control_net.parameters())
# for p in trainable_params:
# p.requires_grad = True
logger.info(f"trainable params count: {len(trainable_params)}")
logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
trainable_params = []
ctrlnet_params = []
unet_params = []
for name, param in control_net.named_parameters():
if name.startswith("controlnet_"):
ctrlnet_params.append(param)
else:
unet_params.append(param)
trainable_params.append({"params": ctrlnet_params, "lr": args.control_net_lr})
trainable_params.append({"params": unet_params, "lr": args.learning_rate})
all_params = ctrlnet_params + unet_params
logger.info(f"trainable params count: {len(all_params)}")
logger.info(f"number of trainable parameters: {sum(p.numel() for p in all_params)}")
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
@@ -456,6 +465,8 @@ def train(args):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
control_net.train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(control_net):
@@ -510,6 +521,9 @@ def train(args):
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
# '-1 to +1' to '0 to 1'
controlnet_image = (controlnet_image + 1) / 2
with accelerator.autocast():
input_resi_add, mid_add = control_net(
noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image
@@ -690,6 +704,12 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--control_net_lr",
type=float,
default=1e-4,
help="learning rate for controlnet / controlnetの学習率",
)
return parser