mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix cond image normalization, add independent LR for control net
This commit is contained in:
@@ -12,7 +12,6 @@ from accelerate import init_empty_weights
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer
|
||||
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
||||
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -378,4 +377,6 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
|
||||
|
||||
|
||||
def sample_images(*args, **kwargs):
    """Forward to ``train_util.sample_images_common`` with the SDXL pipeline class.

    The ``SdxlStableDiffusionLongPromptWeightingPipeline`` class is passed as
    the first argument; every positional and keyword argument given to this
    function is passed through unchanged after it.

    Returns:
        Whatever ``train_util.sample_images_common`` returns.
    """
    # The pipeline module is imported at call time, not at module import time.
    from library.sdxl_lpw_stable_diffusion import (
        SdxlStableDiffusionLongPromptWeightingPipeline as pipeline_class,
    )

    return train_util.sample_images_common(pipeline_class, *args, **kwargs)
|
||||
|
||||
@@ -31,6 +31,7 @@ import hashlib
|
||||
import subprocess
|
||||
from io import BytesIO
|
||||
import toml
|
||||
# from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -912,6 +913,23 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if info.image_size is None:
|
||||
info.image_size = self.get_image_size(info.absolute_path)
|
||||
|
||||
# # run in parallel
|
||||
# max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes)
|
||||
# with ThreadPoolExecutor(max_workers) as executor:
|
||||
# futures = []
|
||||
# for info in tqdm(self.image_data.values(), desc="loading image sizes"):
|
||||
# if info.image_size is None:
|
||||
# def get_and_set_image_size(info):
|
||||
# info.image_size = self.get_image_size(info.absolute_path)
|
||||
# futures.append(executor.submit(get_and_set_image_size, info))
|
||||
# # consume futures to reduce memory usage and prevent Ctrl-C hang
|
||||
# if len(futures) >= max_workers:
|
||||
# for future in futures:
|
||||
# future.result()
|
||||
# futures = []
|
||||
# for future in futures:
|
||||
# future.result()
|
||||
|
||||
if self.enable_bucket:
|
||||
logger.info("make buckets")
|
||||
else:
|
||||
@@ -1826,7 +1844,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
captions = []
|
||||
missing_captions = []
|
||||
for img_path in img_paths:
|
||||
for img_path in tqdm(img_paths, desc="read caption"):
|
||||
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
|
||||
if cap_for_img is None and subset.class_tokens is None:
|
||||
logger.warning(
|
||||
|
||||
@@ -253,11 +253,20 @@ def train(args):
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
trainable_params = list(control_net.parameters())
|
||||
# for p in trainable_params:
|
||||
# p.requires_grad = True
|
||||
logger.info(f"trainable params count: {len(trainable_params)}")
|
||||
logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
||||
trainable_params = []
|
||||
ctrlnet_params = []
|
||||
unet_params = []
|
||||
for name, param in control_net.named_parameters():
|
||||
if name.startswith("controlnet_"):
|
||||
ctrlnet_params.append(param)
|
||||
else:
|
||||
unet_params.append(param)
|
||||
trainable_params.append({"params": ctrlnet_params, "lr": args.control_net_lr})
|
||||
trainable_params.append({"params": unet_params, "lr": args.learning_rate})
|
||||
all_params = ctrlnet_params + unet_params
|
||||
|
||||
logger.info(f"trainable params count: {len(all_params)}")
|
||||
logger.info(f"number of trainable parameters: {sum(p.numel() for p in all_params)}")
|
||||
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
@@ -456,6 +465,8 @@ def train(args):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
control_net.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(control_net):
|
||||
@@ -510,6 +521,9 @@ def train(args):
|
||||
|
||||
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
|
||||
|
||||
# '-1 to +1' to '0 to 1'
|
||||
controlnet_image = (controlnet_image + 1) / 2
|
||||
|
||||
with accelerator.autocast():
|
||||
input_resi_add, mid_add = control_net(
|
||||
noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image
|
||||
@@ -690,6 +704,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--control_net_lr",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="learning rate for controlnet / controlnetの学習率",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user