fix cond image normlization, add independent LR for control

This commit is contained in:
Kohya S
2024-10-03 21:32:21 +09:00
parent 793999d116
commit c2440f9e53
3 changed files with 46 additions and 7 deletions

View File

@@ -12,7 +12,6 @@ from accelerate import init_empty_weights
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
from .utils import setup_logging
setup_logging()
@@ -378,4 +377,6 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
def sample_images(*args, **kwargs):
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)

View File

@@ -31,6 +31,7 @@ import hashlib
import subprocess
from io import BytesIO
import toml
# from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
@@ -912,6 +913,23 @@ class BaseDataset(torch.utils.data.Dataset):
if info.image_size is None:
info.image_size = self.get_image_size(info.absolute_path)
# # run in parallel
# max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes)
# with ThreadPoolExecutor(max_workers) as executor:
# futures = []
# for info in tqdm(self.image_data.values(), desc="loading image sizes"):
# if info.image_size is None:
# def get_and_set_image_size(info):
# info.image_size = self.get_image_size(info.absolute_path)
# futures.append(executor.submit(get_and_set_image_size, info))
# # consume futures to reduce memory usage and prevent Ctrl-C hang
# if len(futures) >= max_workers:
# for future in futures:
# future.result()
# futures = []
# for future in futures:
# future.result()
if self.enable_bucket:
logger.info("make buckets")
else:
@@ -1826,7 +1844,7 @@ class DreamBoothDataset(BaseDataset):
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
missing_captions = []
for img_path in img_paths:
for img_path in tqdm(img_paths, desc="read caption"):
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
if cap_for_img is None and subset.class_tokens is None:
logger.warning(