Merge branch 'dev' into dataset-cache

Kohya S
2024-03-26 19:43:40 +09:00
22 changed files with 534 additions and 277 deletions


@@ -70,6 +70,7 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
 import library.model_util as model_util
 import library.huggingface_util as huggingface_util
 import library.sai_model_spec as sai_model_spec
+import library.deepspeed_utils as deepspeed_utils
 from library.utils import setup_logging

 setup_logging()
@@ -1882,6 +1883,9 @@ class ControlNetDataset(BaseDataset):
         db_subsets = []
         for subset in subsets:
+            assert (
+                not subset.random_crop
+            ), "random_crop is not supported in ControlNetDataset / random_cropはControlNetDatasetではサポートされていません"
             db_subset = DreamBoothSubset(
                 subset.image_dir,
                 False,
@@ -1933,7 +1937,7 @@ class ControlNetDataset(BaseDataset):
         # assert all conditioning data exists
         missing_imgs = []
-        cond_imgs_with_img = set()
+        cond_imgs_with_pair = set()
         for image_key, info in self.dreambooth_dataset_delegate.image_data.items():
             db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key]
             subset = None
@@ -1947,23 +1951,29 @@ class ControlNetDataset(BaseDataset):
                 logger.warning(f"not directory: {subset.conditioning_data_dir}")
                 continue

-            img_basename = os.path.basename(info.absolute_path)
-            ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename)
-            if not os.path.exists(ctrl_img_path):
+            img_basename = os.path.splitext(os.path.basename(info.absolute_path))[0]
+            ctrl_img_path = glob_images(subset.conditioning_data_dir, img_basename)
+            if len(ctrl_img_path) < 1:
                 missing_imgs.append(img_basename)
                 continue

+            ctrl_img_path = ctrl_img_path[0]
+            ctrl_img_path = os.path.abspath(ctrl_img_path)  # normalize path
             info.cond_img_path = ctrl_img_path
-            cond_imgs_with_img.add(ctrl_img_path)
+            cond_imgs_with_pair.add(os.path.splitext(ctrl_img_path)[0])  # remove extension because Windows is case insensitive

         extra_imgs = []
         for subset in subsets:
             conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
-            extra_imgs.extend(
-                [cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
-            )
+            conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths]  # normalize path
+            extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])

-        assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
-        assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
+        assert (
+            len(missing_imgs) == 0
+        ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
+        assert (
+            len(extra_imgs) == 0
+        ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"

         self.conditioning_image_transforms = IMAGE_TRANSFORMS
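For illustration, the matching idea in the hunk above is to pair a training image with its conditioning image by extension-less basename (via glob_images) instead of requiring an identical filename, so e.g. img_001.jpg can pair with img_001.png. A minimal standalone sketch of that idea; find_conditioning_image is a made-up helper that uses plain glob rather than the repository's glob_images:

import glob
import os

def find_conditioning_image(train_image_path, conditioning_data_dir):
    # match by basename without extension; the extensions of the pair may differ
    stem = os.path.splitext(os.path.basename(train_image_path))[0]
    candidates = glob.glob(os.path.join(conditioning_data_dir, stem + ".*"))
    if len(candidates) < 1:
        return None  # the caller would record this as a missing conditioning image
    return os.path.abspath(candidates[0])  # normalize path; first match wins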
@@ -3097,6 +3107,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
         "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
     )  # TODO move to SDXL training, because it is not supported by SD1/2
+    parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
     parser.add_argument(
         "--ddp_timeout",
         type=int,
@@ -3159,6 +3170,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
         default=None,
         help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインするオプション",
     )
     parser.add_argument(
         "--noise_offset",
         type=float,
@@ -3332,6 +3344,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
     )

+def add_masked_loss_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--conditioning_data_dir",
+        type=str,
+        default=None,
+        help="conditioning data directory / 条件付けデータのディレクトリ",
+    )
+    parser.add_argument(
+        "--masked_loss",
+        action="store_true",
+        help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
+    )
+
+
 def verify_training_args(args: argparse.Namespace):
     r"""
     Verify training arguments. Also reflect highvram option to global variable
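For context, --masked_loss means the conditioning image is used as a per-pixel weight when computing the training loss. A minimal sketch of what such a masked loss typically looks like, assuming pred/target latents and a mask already resized to the latent resolution; this is illustrative only, not the repository's exact implementation:

import torch
import torch.nn.functional as F

def masked_mse_loss(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # pred, target: (B, C, H, W) latents; mask: (B, 1, H, W) with values in [0, 1]
    loss = F.mse_loss(pred, target, reduction="none")
    loss = loss * mask  # suppress loss outside the masked region
    # average over masked elements (the broadcast mask covers all C channels)
    return loss.sum() / (mask.sum() * pred.shape[1]).clamp(min=1.0)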
@@ -4150,6 +4176,10 @@ def load_tokenizer(args: argparse.Namespace):
 def prepare_accelerator(args: argparse.Namespace):
+    """
+    this function also prepares deepspeed plugin
+    """
+
     if args.logging_dir is None:
         logging_dir = None
     else:
@@ -4195,6 +4225,8 @@ def prepare_accelerator(args: argparse.Namespace):
),
)
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
@@ -4202,6 +4234,7 @@ def prepare_accelerator(args: argparse.Namespace):
         project_dir=logging_dir,
         kwargs_handlers=kwargs_handlers,
         dynamo_backend=dynamo_backend,
+        deepspeed_plugin=deepspeed_plugin,
     )
     print("accelerator device:", accelerator.device)
     return accelerator
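For reference, prepare_deepspeed_plugin comes from the new library.deepspeed_utils module; presumably it returns an accelerate DeepSpeedPlugin (or None when DeepSpeed is not requested), which Accelerator accepts directly. A rough sketch of that wiring using the public Accelerate API; args.deepspeed and args.zero_stage are assumed placeholder names, not necessarily the repository's flags:

from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

def build_accelerator(args):
    deepspeed_plugin = None
    if getattr(args, "deepspeed", False):  # assumed opt-in flag
        deepspeed_plugin = DeepSpeedPlugin(
            zero_stage=args.zero_stage,  # e.g. ZeRO stage 2
            gradient_accumulation_steps=args.gradient_accumulation_steps,
        )
    # deepspeed_plugin=None leaves the accelerator in its normal (non-DeepSpeed) mode
    return Accelerator(
        mixed_precision=args.mixed_precision,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        deepspeed_plugin=deepspeed_plugin,
    )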
@@ -4272,7 +4305,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", unet_use_linear_projection_in_v2=False):
 def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
     # load models for each process
     for pi in range(accelerator.state.num_processes):
         if pi == accelerator.state.local_process_index:
             logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
@@ -4283,7 +4315,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
                 accelerator.device if args.lowram else "cpu",
                 unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
             )

             # work on low-ram device
             if args.lowram:
                 text_encoder.to(accelerator.device)
@@ -4292,7 +4323,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
             clean_memory_on_device(accelerator.device)
         accelerator.wait_for_everyone()

     return text_encoder, vae, unet, load_stable_diffusion_format
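The loop in load_target_model above serializes model loading so that only one process reads the checkpoint at a time, with a barrier after each turn, which limits peak RAM and disk pressure on multi-GPU machines. The same pattern in isolation, as a sketch; load_sequentially is a made-up helper, not part of the repository:

from accelerate import Accelerator

def load_sequentially(accelerator: Accelerator, load_fn):
    # run load_fn on one process at a time; the barrier makes the ranks take turns
    result = None
    for pi in range(accelerator.state.num_processes):
        if pi == accelerator.state.local_process_index:
            result = load_fn()
        accelerator.wait_for_everyone()
    return result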