mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'dev' into dataset-cache
This commit is contained in:
@@ -70,6 +70,7 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel
|
||||
import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
import library.deepspeed_utils as deepspeed_utils
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -1882,6 +1883,9 @@ class ControlNetDataset(BaseDataset):
|
||||
|
||||
db_subsets = []
|
||||
for subset in subsets:
|
||||
assert (
|
||||
not subset.random_crop
|
||||
), "random_crop is not supported in ControlNetDataset / random_cropはControlNetDatasetではサポートされていません"
|
||||
db_subset = DreamBoothSubset(
|
||||
subset.image_dir,
|
||||
False,
|
||||
@@ -1933,7 +1937,7 @@ class ControlNetDataset(BaseDataset):
|
||||
|
||||
# assert all conditioning data exists
|
||||
missing_imgs = []
|
||||
cond_imgs_with_img = set()
|
||||
cond_imgs_with_pair = set()
|
||||
for image_key, info in self.dreambooth_dataset_delegate.image_data.items():
|
||||
db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key]
|
||||
subset = None
|
||||
@@ -1947,23 +1951,29 @@ class ControlNetDataset(BaseDataset):
|
||||
logger.warning(f"not directory: {subset.conditioning_data_dir}")
|
||||
continue
|
||||
|
||||
img_basename = os.path.basename(info.absolute_path)
|
||||
ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename)
|
||||
if not os.path.exists(ctrl_img_path):
|
||||
img_basename = os.path.splitext(os.path.basename(info.absolute_path))[0]
|
||||
ctrl_img_path = glob_images(subset.conditioning_data_dir, img_basename)
|
||||
if len(ctrl_img_path) < 1:
|
||||
missing_imgs.append(img_basename)
|
||||
continue
|
||||
ctrl_img_path = ctrl_img_path[0]
|
||||
ctrl_img_path = os.path.abspath(ctrl_img_path) # normalize path
|
||||
|
||||
info.cond_img_path = ctrl_img_path
|
||||
cond_imgs_with_img.add(ctrl_img_path)
|
||||
cond_imgs_with_pair.add(os.path.splitext(ctrl_img_path)[0]) # remove extension because Windows is case insensitive
|
||||
|
||||
extra_imgs = []
|
||||
for subset in subsets:
|
||||
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
|
||||
extra_imgs.extend(
|
||||
[cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
|
||||
)
|
||||
conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path
|
||||
extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])
|
||||
|
||||
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
|
||||
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
|
||||
assert (
|
||||
len(missing_imgs) == 0
|
||||
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
|
||||
assert (
|
||||
len(extra_imgs) == 0
|
||||
), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
|
||||
|
||||
self.conditioning_image_transforms = IMAGE_TRANSFORMS
|
||||
|
||||
@@ -3097,6 +3107,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
|
||||
) # TODO move to SDXL training, because it is not supported by SD1/2
|
||||
parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
|
||||
|
||||
parser.add_argument(
|
||||
"--ddp_timeout",
|
||||
type=int,
|
||||
@@ -3159,6 +3170,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--noise_offset",
|
||||
type=float,
|
||||
@@ -3332,6 +3344,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
)
|
||||
|
||||
|
||||
def add_masked_loss_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--conditioning_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="conditioning data directory / 条件付けデータのディレクトリ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--masked_loss",
|
||||
action="store_true",
|
||||
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
|
||||
)
|
||||
|
||||
|
||||
def verify_training_args(args: argparse.Namespace):
|
||||
r"""
|
||||
Verify training arguments. Also reflect highvram option to global variable
|
||||
@@ -4150,6 +4176,10 @@ def load_tokenizer(args: argparse.Namespace):
|
||||
|
||||
|
||||
def prepare_accelerator(args: argparse.Namespace):
|
||||
"""
|
||||
this function also prepares deepspeed plugin
|
||||
"""
|
||||
|
||||
if args.logging_dir is None:
|
||||
logging_dir = None
|
||||
else:
|
||||
@@ -4195,6 +4225,8 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
),
|
||||
)
|
||||
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
|
||||
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
@@ -4202,6 +4234,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
project_dir=logging_dir,
|
||||
kwargs_handlers=kwargs_handlers,
|
||||
dynamo_backend=dynamo_backend,
|
||||
deepspeed_plugin=deepspeed_plugin,
|
||||
)
|
||||
print("accelerator device:", accelerator.device)
|
||||
return accelerator
|
||||
@@ -4272,7 +4305,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
|
||||
|
||||
|
||||
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
|
||||
# load models for each process
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||
@@ -4283,7 +4315,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
accelerator.device if args.lowram else "cpu",
|
||||
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
|
||||
)
|
||||
|
||||
# work on low-ram device
|
||||
if args.lowram:
|
||||
text_encoder.to(accelerator.device)
|
||||
@@ -4292,7 +4323,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user