mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'dev' into dataset-cache
This commit is contained in:
@@ -257,9 +257,10 @@ class ConfigSanitizer:
|
||||
}
|
||||
|
||||
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
|
||||
assert (
|
||||
support_dreambooth or support_finetuning or support_controlnet
|
||||
), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
|
||||
assert support_dreambooth or support_finetuning or support_controlnet, (
|
||||
"Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
|
||||
+ " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
|
||||
)
|
||||
|
||||
self.db_subset_schema = self.__merge_dict(
|
||||
self.SUBSET_ASCENDABLE_SCHEMA,
|
||||
@@ -326,7 +327,10 @@ class ConfigSanitizer:
|
||||
|
||||
self.dataset_schema = validate_flex_dataset
|
||||
elif support_dreambooth:
|
||||
self.dataset_schema = self.db_dataset_schema
|
||||
if support_controlnet:
|
||||
self.dataset_schema = self.cn_dataset_schema
|
||||
else:
|
||||
self.dataset_schema = self.db_dataset_schema
|
||||
elif support_finetuning:
|
||||
self.dataset_schema = self.ft_dataset_schema
|
||||
elif support_controlnet:
|
||||
|
||||
@@ -3,11 +3,14 @@ import argparse
|
||||
import random
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
from .utils import setup_logging
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
||||
if hasattr(noise_scheduler, "all_snr"):
|
||||
return
|
||||
@@ -64,7 +67,7 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
|
||||
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
||||
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
|
||||
if v_prediction:
|
||||
snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
|
||||
snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
|
||||
else:
|
||||
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
|
||||
loss = loss * snr_weight
|
||||
@@ -92,13 +95,15 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
|
||||
loss = loss + loss / scale * v_pred_like_loss
|
||||
return loss
|
||||
|
||||
|
||||
def apply_debiased_estimation(loss, timesteps, noise_scheduler):
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||
weight = 1/torch.sqrt(snr_t)
|
||||
weight = 1 / torch.sqrt(snr_t)
|
||||
loss = weight * loss
|
||||
return loss
|
||||
|
||||
|
||||
# TODO train_utilと分散しているのでどちらかに寄せる
|
||||
|
||||
|
||||
@@ -474,6 +479,17 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
||||
return noise
|
||||
|
||||
|
||||
def apply_masked_loss(loss, batch):
|
||||
# mask image is -1 to 1. we need to convert it to 0 to 1
|
||||
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
|
||||
|
||||
# resize to the same size as the loss
|
||||
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
|
||||
mask_image = mask_image / 2 + 0.5
|
||||
loss = loss * mask_image
|
||||
return loss
|
||||
|
||||
|
||||
"""
|
||||
##########################################
|
||||
# Perlin Noise
|
||||
|
||||
139
library/deepspeed_utils.py
Normal file
139
library/deepspeed_utils.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
from accelerate import DeepSpeedPlugin, Accelerator
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def add_deepspeed_arguments(parser: argparse.ArgumentParser):
|
||||
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
|
||||
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
|
||||
parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
|
||||
parser.add_argument(
|
||||
"--offload_optimizer_device",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "cpu", "nvme"],
|
||||
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_optimizer_nvme_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_param_device",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "cpu", "nvme"],
|
||||
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_param_nvme_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zero3_init_flag",
|
||||
action="store_true",
|
||||
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
|
||||
"Only applicable with ZeRO Stage-3.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zero3_save_16bit_model",
|
||||
action="store_true",
|
||||
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16_master_weights_and_gradients",
|
||||
action="store_true",
|
||||
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
|
||||
)
|
||||
|
||||
|
||||
def prepare_deepspeed_args(args: argparse.Namespace):
|
||||
if not args.deepspeed:
|
||||
return
|
||||
|
||||
# To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
|
||||
args.max_data_loader_n_workers = 1
|
||||
|
||||
|
||||
def prepare_deepspeed_plugin(args: argparse.Namespace):
|
||||
if not args.deepspeed:
|
||||
return None
|
||||
|
||||
try:
|
||||
import deepspeed
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
"deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
|
||||
)
|
||||
exit(1)
|
||||
|
||||
deepspeed_plugin = DeepSpeedPlugin(
|
||||
zero_stage=args.zero_stage,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
gradient_clipping=args.max_grad_norm,
|
||||
offload_optimizer_device=args.offload_optimizer_device,
|
||||
offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
|
||||
offload_param_device=args.offload_param_device,
|
||||
offload_param_nvme_path=args.offload_param_nvme_path,
|
||||
zero3_init_flag=args.zero3_init_flag,
|
||||
zero3_save_16bit_model=args.zero3_save_16bit_model,
|
||||
)
|
||||
deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
|
||||
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
|
||||
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
|
||||
)
|
||||
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
||||
if args.mixed_precision.lower() == "fp16":
|
||||
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
|
||||
if args.full_fp16 or args.fp16_master_weights_and_gradients:
|
||||
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
|
||||
deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
|
||||
logger.info("[DeepSpeed] full fp16 enable.")
|
||||
else:
|
||||
logger.info(
|
||||
"[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
|
||||
)
|
||||
|
||||
if args.offload_optimizer_device is not None:
|
||||
logger.info("[DeepSpeed] start to manually build cpu_adam.")
|
||||
deepspeed.ops.op_builder.CPUAdamBuilder().load()
|
||||
logger.info("[DeepSpeed] building cpu_adam done.")
|
||||
|
||||
return deepspeed_plugin
|
||||
|
||||
|
||||
# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
|
||||
def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
||||
# remove None from models
|
||||
models = {k: v for k, v in models.items() if v is not None}
|
||||
|
||||
class DeepSpeedWrapper(torch.nn.Module):
|
||||
def __init__(self, **kw_models) -> None:
|
||||
super().__init__()
|
||||
self.models = torch.nn.ModuleDict()
|
||||
|
||||
for key, model in kw_models.items():
|
||||
if isinstance(model, list):
|
||||
model = torch.nn.ModuleList(model)
|
||||
assert isinstance(
|
||||
model, torch.nn.Module
|
||||
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
||||
self.models.update(torch.nn.ModuleDict({key: model}))
|
||||
|
||||
def get_models(self):
|
||||
return self.models
|
||||
|
||||
ds_model = DeepSpeedWrapper(**models)
|
||||
return ds_model
|
||||
@@ -24,7 +24,6 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
|
||||
|
||||
def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
# load models for each process
|
||||
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
|
||||
@@ -70,6 +70,7 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel
|
||||
import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
import library.deepspeed_utils as deepspeed_utils
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -1882,6 +1883,9 @@ class ControlNetDataset(BaseDataset):
|
||||
|
||||
db_subsets = []
|
||||
for subset in subsets:
|
||||
assert (
|
||||
not subset.random_crop
|
||||
), "random_crop is not supported in ControlNetDataset / random_cropはControlNetDatasetではサポートされていません"
|
||||
db_subset = DreamBoothSubset(
|
||||
subset.image_dir,
|
||||
False,
|
||||
@@ -1933,7 +1937,7 @@ class ControlNetDataset(BaseDataset):
|
||||
|
||||
# assert all conditioning data exists
|
||||
missing_imgs = []
|
||||
cond_imgs_with_img = set()
|
||||
cond_imgs_with_pair = set()
|
||||
for image_key, info in self.dreambooth_dataset_delegate.image_data.items():
|
||||
db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key]
|
||||
subset = None
|
||||
@@ -1947,23 +1951,29 @@ class ControlNetDataset(BaseDataset):
|
||||
logger.warning(f"not directory: {subset.conditioning_data_dir}")
|
||||
continue
|
||||
|
||||
img_basename = os.path.basename(info.absolute_path)
|
||||
ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename)
|
||||
if not os.path.exists(ctrl_img_path):
|
||||
img_basename = os.path.splitext(os.path.basename(info.absolute_path))[0]
|
||||
ctrl_img_path = glob_images(subset.conditioning_data_dir, img_basename)
|
||||
if len(ctrl_img_path) < 1:
|
||||
missing_imgs.append(img_basename)
|
||||
continue
|
||||
ctrl_img_path = ctrl_img_path[0]
|
||||
ctrl_img_path = os.path.abspath(ctrl_img_path) # normalize path
|
||||
|
||||
info.cond_img_path = ctrl_img_path
|
||||
cond_imgs_with_img.add(ctrl_img_path)
|
||||
cond_imgs_with_pair.add(os.path.splitext(ctrl_img_path)[0]) # remove extension because Windows is case insensitive
|
||||
|
||||
extra_imgs = []
|
||||
for subset in subsets:
|
||||
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
|
||||
extra_imgs.extend(
|
||||
[cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
|
||||
)
|
||||
conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path
|
||||
extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])
|
||||
|
||||
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
|
||||
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
|
||||
assert (
|
||||
len(missing_imgs) == 0
|
||||
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
|
||||
assert (
|
||||
len(extra_imgs) == 0
|
||||
), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
|
||||
|
||||
self.conditioning_image_transforms = IMAGE_TRANSFORMS
|
||||
|
||||
@@ -3097,6 +3107,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
|
||||
) # TODO move to SDXL training, because it is not supported by SD1/2
|
||||
parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
|
||||
|
||||
parser.add_argument(
|
||||
"--ddp_timeout",
|
||||
type=int,
|
||||
@@ -3159,6 +3170,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--noise_offset",
|
||||
type=float,
|
||||
@@ -3332,6 +3344,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
)
|
||||
|
||||
|
||||
def add_masked_loss_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--conditioning_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="conditioning data directory / 条件付けデータのディレクトリ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--masked_loss",
|
||||
action="store_true",
|
||||
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
|
||||
)
|
||||
|
||||
|
||||
def verify_training_args(args: argparse.Namespace):
|
||||
r"""
|
||||
Verify training arguments. Also reflect highvram option to global variable
|
||||
@@ -4150,6 +4176,10 @@ def load_tokenizer(args: argparse.Namespace):
|
||||
|
||||
|
||||
def prepare_accelerator(args: argparse.Namespace):
|
||||
"""
|
||||
this function also prepares deepspeed plugin
|
||||
"""
|
||||
|
||||
if args.logging_dir is None:
|
||||
logging_dir = None
|
||||
else:
|
||||
@@ -4195,6 +4225,8 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
),
|
||||
)
|
||||
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
|
||||
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
@@ -4202,6 +4234,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
project_dir=logging_dir,
|
||||
kwargs_handlers=kwargs_handlers,
|
||||
dynamo_backend=dynamo_backend,
|
||||
deepspeed_plugin=deepspeed_plugin,
|
||||
)
|
||||
print("accelerator device:", accelerator.device)
|
||||
return accelerator
|
||||
@@ -4272,7 +4305,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
|
||||
|
||||
|
||||
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
|
||||
# load models for each process
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||
@@ -4283,7 +4315,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
accelerator.device if args.lowram else "cpu",
|
||||
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
|
||||
)
|
||||
|
||||
# work on low-ram device
|
||||
if args.lowram:
|
||||
text_encoder.to(accelerator.device)
|
||||
@@ -4292,7 +4323,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user