mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
make deepspeed_utils
This commit is contained in:
33
fine_tune.py
33
fine_tune.py
@@ -10,7 +10,9 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from library import deepspeed_utils
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
@@ -42,6 +44,7 @@ from library.custom_train_functions import (
|
|||||||
def train(args):
|
def train(args):
|
||||||
train_util.verify_training_args(args)
|
train_util.verify_training_args(args)
|
||||||
train_util.prepare_dataset_args(args, True)
|
train_util.prepare_dataset_args(args, True)
|
||||||
|
deepspeed_utils.prepare_deepspeed_args(args)
|
||||||
setup_logging(args, reset=True)
|
setup_logging(args, reset=True)
|
||||||
|
|
||||||
cache_latents = args.cache_latents
|
cache_latents = args.cache_latents
|
||||||
@@ -219,7 +222,7 @@ def train(args):
|
|||||||
batch_size=1,
|
batch_size=1,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=collator,
|
collate_fn=collator,
|
||||||
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
|
num_workers=n_workers,
|
||||||
persistent_workers=args.persistent_data_loader_workers,
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -248,21 +251,16 @@ def train(args):
|
|||||||
text_encoder.to(weight_dtype)
|
text_encoder.to(weight_dtype)
|
||||||
|
|
||||||
if args.deepspeed:
|
if args.deepspeed:
|
||||||
training_models_dict = {}
|
|
||||||
training_models_dict["unet"] = unet
|
|
||||||
if args.train_text_encoder: training_models_dict["text_encoder"] = text_encoder
|
|
||||||
|
|
||||||
ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
|
|
||||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
|
||||||
|
|
||||||
training_models = []
|
|
||||||
unet = ds_model.models["unet"]
|
|
||||||
training_models.append(unet)
|
|
||||||
if args.train_text_encoder:
|
if args.train_text_encoder:
|
||||||
text_encoder = ds_model.models["text_encoder"]
|
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
|
||||||
training_models.append(text_encoder)
|
else:
|
||||||
|
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
|
||||||
else: # acceleratorがなんかよろしくやってくれるらしい
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||||
|
)
|
||||||
|
training_models = [ds_model]
|
||||||
|
else:
|
||||||
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
if args.train_text_encoder:
|
if args.train_text_encoder:
|
||||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||||
@@ -327,13 +325,13 @@ def train(args):
|
|||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
current_step.value = global_step
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
with accelerator.accumulate(*training_models):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype)
|
latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype)
|
||||||
else:
|
else:
|
||||||
# latentに変換
|
# latentに変換
|
||||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype)
|
||||||
latents = latents * 0.18215
|
latents = latents * 0.18215
|
||||||
b_size = latents.shape[0]
|
b_size = latents.shape[0]
|
||||||
|
|
||||||
@@ -493,6 +491,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
train_util.add_sd_models_arguments(parser)
|
train_util.add_sd_models_arguments(parser)
|
||||||
train_util.add_dataset_arguments(parser, False, True, True)
|
train_util.add_dataset_arguments(parser, False, True, True)
|
||||||
train_util.add_training_arguments(parser, False)
|
train_util.add_training_arguments(parser, False)
|
||||||
|
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||||
train_util.add_sd_saving_arguments(parser)
|
train_util.add_sd_saving_arguments(parser)
|
||||||
train_util.add_optimizer_arguments(parser)
|
train_util.add_optimizer_arguments(parser)
|
||||||
config_util.add_config_arguments(parser)
|
config_util.add_config_arguments(parser)
|
||||||
|
|||||||
139
library/deepspeed_utils.py
Normal file
139
library/deepspeed_utils.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
from accelerate import DeepSpeedPlugin, Accelerator
|
||||||
|
|
||||||
|
from .utils import setup_logging
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def add_deepspeed_arguments(parser: argparse.ArgumentParser):
|
||||||
|
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
|
||||||
|
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
|
||||||
|
parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--offload_optimizer_device",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
choices=[None, "cpu", "nvme"],
|
||||||
|
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--offload_optimizer_nvme_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--offload_param_device",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
choices=[None, "cpu", "nvme"],
|
||||||
|
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--offload_param_nvme_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--zero3_init_flag",
|
||||||
|
action="store_true",
|
||||||
|
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
|
||||||
|
"Only applicable with ZeRO Stage-3.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--zero3_save_16bit_model",
|
||||||
|
action="store_true",
|
||||||
|
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp16_master_weights_and_gradients",
|
||||||
|
action="store_true",
|
||||||
|
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_deepspeed_args(args: argparse.Namespace):
|
||||||
|
if not args.deepspeed:
|
||||||
|
return
|
||||||
|
|
||||||
|
# To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
|
||||||
|
args.max_data_loader_n_workers = 1
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_deepspeed_plugin(args: argparse.Namespace):
|
||||||
|
if not args.deepspeed:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import deepspeed
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(
|
||||||
|
"deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
|
||||||
|
)
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
deepspeed_plugin = DeepSpeedPlugin(
|
||||||
|
zero_stage=args.zero_stage,
|
||||||
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
|
gradient_clipping=args.max_grad_norm,
|
||||||
|
offload_optimizer_device=args.offload_optimizer_device,
|
||||||
|
offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
|
||||||
|
offload_param_device=args.offload_param_device,
|
||||||
|
offload_param_nvme_path=args.offload_param_nvme_path,
|
||||||
|
zero3_init_flag=args.zero3_init_flag,
|
||||||
|
zero3_save_16bit_model=args.zero3_save_16bit_model,
|
||||||
|
)
|
||||||
|
deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
|
||||||
|
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
|
||||||
|
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
|
||||||
|
)
|
||||||
|
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
||||||
|
if args.mixed_precision.lower() == "fp16":
|
||||||
|
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
|
||||||
|
if args.full_fp16 or args.fp16_master_weights_and_gradients:
|
||||||
|
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
|
||||||
|
deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
|
||||||
|
logger.info("[DeepSpeed] full fp16 enable.")
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.offload_optimizer_device is not None:
|
||||||
|
logger.info("[DeepSpeed] start to manually build cpu_adam.")
|
||||||
|
deepspeed.ops.op_builder.CPUAdamBuilder().load()
|
||||||
|
logger.info("[DeepSpeed] building cpu_adam done.")
|
||||||
|
|
||||||
|
return deepspeed_plugin
|
||||||
|
|
||||||
|
|
||||||
|
# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
|
||||||
|
def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
||||||
|
# remove None from models
|
||||||
|
models = {k: v for k, v in models.items() if v is not None}
|
||||||
|
|
||||||
|
class DeepSpeedWrapper(torch.nn.Module):
|
||||||
|
def __init__(self, **kw_models) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.models = torch.nn.ModuleDict()
|
||||||
|
|
||||||
|
for key, model in kw_models.items():
|
||||||
|
if isinstance(model, list):
|
||||||
|
model = torch.nn.ModuleList(model)
|
||||||
|
assert isinstance(
|
||||||
|
model, torch.nn.Module
|
||||||
|
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
||||||
|
self.models.update(torch.nn.ModuleDict({key: model}))
|
||||||
|
|
||||||
|
def get_models(self):
|
||||||
|
return self.models
|
||||||
|
|
||||||
|
ds_model = DeepSpeedWrapper(**models)
|
||||||
|
return ds_model
|
||||||
@@ -21,7 +21,6 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
|
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
|
||||||
from accelerate import DeepSpeedPlugin
|
|
||||||
import glob
|
import glob
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -70,6 +69,7 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel
|
|||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
import library.huggingface_util as huggingface_util
|
import library.huggingface_util as huggingface_util
|
||||||
import library.sai_model_spec as sai_model_spec
|
import library.sai_model_spec as sai_model_spec
|
||||||
|
import library.deepspeed_utils as deepspeed_utils
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
@@ -3243,52 +3243,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
"--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
|
"--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
|
||||||
)
|
)
|
||||||
|
|
||||||
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
|
|
||||||
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
|
|
||||||
parser.add_argument(
|
|
||||||
"--zero_stage",
|
|
||||||
type=int, default=2,
|
|
||||||
choices=[0, 1, 2, 3],
|
|
||||||
help="Possible options are 0,1,2,3."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--offload_optimizer_device",
|
|
||||||
type=str, default=None,
|
|
||||||
choices=[None, "cpu", "nvme"],
|
|
||||||
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--offload_optimizer_nvme_path",
|
|
||||||
type=str, default=None,
|
|
||||||
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--offload_param_device",
|
|
||||||
type=str, default=None,
|
|
||||||
choices=[None, "cpu", "nvme"],
|
|
||||||
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--offload_param_nvme_path",
|
|
||||||
type=str, default=None,
|
|
||||||
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--zero3_init_flag",
|
|
||||||
action="store_true",
|
|
||||||
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
|
|
||||||
"Only applicable with ZeRO Stage-3."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--zero3_save_16bit_model",
|
|
||||||
action="store_true",
|
|
||||||
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--fp16_master_weights_and_gradients",
|
|
||||||
action="store_true",
|
|
||||||
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32."
|
|
||||||
)
|
|
||||||
|
|
||||||
def verify_training_args(args: argparse.Namespace):
|
def verify_training_args(args: argparse.Namespace):
|
||||||
r"""
|
r"""
|
||||||
@@ -4090,6 +4044,10 @@ def load_tokenizer(args: argparse.Namespace):
|
|||||||
|
|
||||||
|
|
||||||
def prepare_accelerator(args: argparse.Namespace):
|
def prepare_accelerator(args: argparse.Namespace):
|
||||||
|
"""
|
||||||
|
this function also prepares deepspeed plugin
|
||||||
|
"""
|
||||||
|
|
||||||
if args.logging_dir is None:
|
if args.logging_dir is None:
|
||||||
logging_dir = None
|
logging_dir = None
|
||||||
else:
|
else:
|
||||||
@@ -4135,7 +4093,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
|
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
|
||||||
deepspeed_plugin = prepare_deepspeed_plugin(args)
|
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
|
||||||
|
|
||||||
accelerator = Accelerator(
|
accelerator = Accelerator(
|
||||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
@@ -4149,62 +4107,6 @@ def prepare_accelerator(args: argparse.Namespace):
|
|||||||
print("accelerator device:", accelerator.device)
|
print("accelerator device:", accelerator.device)
|
||||||
return accelerator
|
return accelerator
|
||||||
|
|
||||||
def prepare_deepspeed_plugin(args: argparse.Namespace):
|
|
||||||
if args.deepspeed is None: return None
|
|
||||||
try:
|
|
||||||
import deepspeed
|
|
||||||
except ImportError as e:
|
|
||||||
print("deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed")
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
deepspeed_plugin = DeepSpeedPlugin(
|
|
||||||
zero_stage=args.zero_stage,
|
|
||||||
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_clipping=args.max_grad_norm,
|
|
||||||
offload_optimizer_device=args.offload_optimizer_device, offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
|
|
||||||
offload_param_device=args.offload_param_device, offload_param_nvme_path=args.offload_param_nvme_path,
|
|
||||||
zero3_init_flag=args.zero3_init_flag, zero3_save_16bit_model=args.zero3_save_16bit_model,
|
|
||||||
)
|
|
||||||
deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size
|
|
||||||
deepspeed_plugin.deepspeed_config['train_batch_size'] = \
|
|
||||||
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ['WORLD_SIZE'])
|
|
||||||
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
|
||||||
if args.mixed_precision.lower() == "fp16":
|
|
||||||
deepspeed_plugin.deepspeed_config['fp16']['initial_scale_power'] = 0 # preventing overflow.
|
|
||||||
if args.full_fp16 or args.fp16_master_weights_and_gradients:
|
|
||||||
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
|
|
||||||
deepspeed_plugin.deepspeed_config['fp16']['fp16_master_weights_and_grads'] = True
|
|
||||||
print("[DeepSpeed] full fp16 enable.")
|
|
||||||
else:
|
|
||||||
print("[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage.")
|
|
||||||
|
|
||||||
if args.offload_optimizer_device is not None:
|
|
||||||
print('[DeepSpeed] start to manually build cpu_adam.')
|
|
||||||
deepspeed.ops.op_builder.CPUAdamBuilder().load()
|
|
||||||
print('[DeepSpeed] building cpu_adam done.')
|
|
||||||
|
|
||||||
return deepspeed_plugin
|
|
||||||
|
|
||||||
def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
|
||||||
class DeepSpeedWrapper(torch.nn.Module):
|
|
||||||
def __init__(self, **kw_models) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.models = torch.nn.ModuleDict()
|
|
||||||
|
|
||||||
for key, model in kw_models.items():
|
|
||||||
if isinstance(model, list):
|
|
||||||
model = torch.nn.ModuleList(model)
|
|
||||||
assert isinstance(model, torch.nn.Module), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
|
||||||
self.models.update(
|
|
||||||
torch.nn.ModuleDict(
|
|
||||||
{key: model}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_models(self):
|
|
||||||
return self.models
|
|
||||||
|
|
||||||
ds_model = DeepSpeedWrapper(**models)
|
|
||||||
return ds_model
|
|
||||||
|
|
||||||
def prepare_dtype(args: argparse.Namespace):
|
def prepare_dtype(args: argparse.Namespace):
|
||||||
weight_dtype = torch.float32
|
weight_dtype = torch.float32
|
||||||
|
|||||||
@@ -11,11 +11,12 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
from library import sdxl_model_util
|
from library import deepspeed_utils, sdxl_model_util
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
|
||||||
@@ -97,6 +98,7 @@ def train(args):
|
|||||||
train_util.verify_training_args(args)
|
train_util.verify_training_args(args)
|
||||||
train_util.prepare_dataset_args(args, True)
|
train_util.prepare_dataset_args(args, True)
|
||||||
sdxl_train_util.verify_sdxl_training_args(args)
|
sdxl_train_util.verify_sdxl_training_args(args)
|
||||||
|
deepspeed_utils.prepare_deepspeed_args(args)
|
||||||
setup_logging(args, reset=True)
|
setup_logging(args, reset=True)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
@@ -361,7 +363,7 @@ def train(args):
|
|||||||
batch_size=1,
|
batch_size=1,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=collator,
|
collate_fn=collator,
|
||||||
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
|
num_workers=n_workers,
|
||||||
persistent_workers=args.persistent_data_loader_workers,
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -398,41 +400,31 @@ def train(args):
|
|||||||
text_encoder1.to(weight_dtype)
|
text_encoder1.to(weight_dtype)
|
||||||
text_encoder2.to(weight_dtype)
|
text_encoder2.to(weight_dtype)
|
||||||
|
|
||||||
|
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
|
||||||
|
if train_text_encoder1:
|
||||||
|
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
||||||
|
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
||||||
|
|
||||||
if args.deepspeed:
|
if args.deepspeed:
|
||||||
training_models_dict = {}
|
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
||||||
if train_unet:
|
args,
|
||||||
training_models_dict["unet"] = unet
|
unet=unet if train_unet else None,
|
||||||
if train_text_encoder1:
|
text_encoder1=text_encoder1 if train_text_encoder1 else None,
|
||||||
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
text_encoder2=text_encoder2 if train_text_encoder2 else None,
|
||||||
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
)
|
||||||
training_models_dict["text_encoder1"] = text_encoder1
|
ds_model = accelerator.prepare(ds_model)
|
||||||
if train_text_encoder2:
|
training_models = [ds_model]
|
||||||
training_models_dict["text_encoder2"] = text_encoder2
|
|
||||||
ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
|
|
||||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
|
||||||
|
|
||||||
training_models = [] # override training_models
|
else:
|
||||||
if train_unet:
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
unet = ds_model.models["unet"]
|
|
||||||
training_models.append(unet)
|
|
||||||
if train_text_encoder1:
|
|
||||||
text_encoder1 = ds_model.models["text_encoder1"]
|
|
||||||
training_models.append(text_encoder1)
|
|
||||||
if train_text_encoder2:
|
|
||||||
text_encoder2 = ds_model.models["text_encoder2"]
|
|
||||||
training_models.append(text_encoder2)
|
|
||||||
|
|
||||||
else: # acceleratorがなんかよろしくやってくれるらしい
|
|
||||||
if train_unet:
|
if train_unet:
|
||||||
unet = accelerator.prepare(unet)
|
unet = accelerator.prepare(unet)
|
||||||
if train_text_encoder1:
|
if train_text_encoder1:
|
||||||
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
|
|
||||||
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
|
||||||
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
|
||||||
text_encoder1 = accelerator.prepare(text_encoder1)
|
text_encoder1 = accelerator.prepare(text_encoder1)
|
||||||
if train_text_encoder2:
|
if train_text_encoder2:
|
||||||
text_encoder2 = accelerator.prepare(text_encoder2)
|
text_encoder2 = accelerator.prepare(text_encoder2)
|
||||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
|
||||||
|
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
@@ -446,8 +438,9 @@ def train(args):
|
|||||||
text_encoder2.to(accelerator.device)
|
text_encoder2.to(accelerator.device)
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||||
if args.full_fp16 and not args.deepspeed:
|
if args.full_fp16:
|
||||||
# During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via its scaler; the DeepSpeed engine handles it instead.
|
# During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via its scaler; the DeepSpeed engine handles it instead.
|
||||||
|
# -> But we think it's ok to patch accelerator even if deepspeed is enabled.
|
||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|
||||||
# resumeする
|
# resumeする
|
||||||
@@ -508,10 +501,10 @@ def train(args):
|
|||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
current_step.value = global_step
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(*training_models):
|
with accelerator.accumulate(*training_models):
|
||||||
with torch.no_grad(): # why this block differ within train_network.py?
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
else:
|
||||||
else:
|
with torch.no_grad():
|
||||||
# latentに変換
|
# latentに変換
|
||||||
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
|
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
|
||||||
|
|
||||||
@@ -519,7 +512,7 @@ def train(args):
|
|||||||
if torch.any(torch.isnan(latents)):
|
if torch.any(torch.isnan(latents)):
|
||||||
accelerator.print("NaN found in latents, replacing with zeros")
|
accelerator.print("NaN found in latents, replacing with zeros")
|
||||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||||
|
|
||||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||||
input_ids1 = batch["input_ids"]
|
input_ids1 = batch["input_ids"]
|
||||||
@@ -768,6 +761,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
train_util.add_sd_models_arguments(parser)
|
train_util.add_sd_models_arguments(parser)
|
||||||
train_util.add_dataset_arguments(parser, True, True, True)
|
train_util.add_dataset_arguments(parser, True, True, True)
|
||||||
train_util.add_training_arguments(parser, False)
|
train_util.add_training_arguments(parser, False)
|
||||||
|
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||||
train_util.add_sd_saving_arguments(parser)
|
train_util.add_sd_saving_arguments(parser)
|
||||||
train_util.add_optimizer_arguments(parser)
|
train_util.add_optimizer_arguments(parser)
|
||||||
config_util.add_config_arguments(parser)
|
config_util.add_config_arguments(parser)
|
||||||
|
|||||||
37
train_db.py
37
train_db.py
@@ -11,7 +11,9 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from library import deepspeed_utils
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
@@ -46,6 +48,7 @@ logger = logging.getLogger(__name__)
|
|||||||
def train(args):
|
def train(args):
|
||||||
train_util.verify_training_args(args)
|
train_util.verify_training_args(args)
|
||||||
train_util.prepare_dataset_args(args, False)
|
train_util.prepare_dataset_args(args, False)
|
||||||
|
deepspeed_utils.prepare_deepspeed_args(args)
|
||||||
setup_logging(args, reset=True)
|
setup_logging(args, reset=True)
|
||||||
|
|
||||||
cache_latents = args.cache_latents
|
cache_latents = args.cache_latents
|
||||||
@@ -187,7 +190,7 @@ def train(args):
|
|||||||
batch_size=1,
|
batch_size=1,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=collator,
|
collate_fn=collator,
|
||||||
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
|
num_workers=n_workers,
|
||||||
persistent_workers=args.persistent_data_loader_workers,
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -220,30 +223,27 @@ def train(args):
|
|||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
if args.deepspeed:
|
if args.deepspeed:
|
||||||
training_models_dict = {}
|
if args.train_text_encoder:
|
||||||
training_models_dict["unet"] = unet
|
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
|
||||||
if train_text_encoder: training_models_dict["text_encoder"] = text_encoder
|
else:
|
||||||
|
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
|
||||||
ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||||
|
)
|
||||||
training_models = []
|
training_models = [ds_model]
|
||||||
unet = ds_model.models["unet"]
|
|
||||||
training_models.append(unet)
|
|
||||||
if train_text_encoder:
|
|
||||||
text_encoder = ds_model.models["text_encoder"]
|
|
||||||
training_models.append(text_encoder)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if train_text_encoder:
|
if train_text_encoder:
|
||||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||||
)
|
)
|
||||||
|
training_models = [unet, text_encoder]
|
||||||
else:
|
else:
|
||||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
training_models = [unet]
|
||||||
|
|
||||||
if not train_text_encoder:
|
if not train_text_encoder:
|
||||||
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
|
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
@@ -312,8 +312,10 @@ def train(args):
|
|||||||
if not args.gradient_checkpointing:
|
if not args.gradient_checkpointing:
|
||||||
text_encoder.train(False)
|
text_encoder.train(False)
|
||||||
text_encoder.requires_grad_(False)
|
text_encoder.requires_grad_(False)
|
||||||
|
if len(training_models) == 2:
|
||||||
|
training_models = training_models[0] # remove text_encoder from training_models
|
||||||
|
|
||||||
with accelerator.accumulate(unet):
|
with accelerator.accumulate(*training_models):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# latentに変換
|
# latentに変換
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
@@ -480,6 +482,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
train_util.add_sd_models_arguments(parser)
|
train_util.add_sd_models_arguments(parser)
|
||||||
train_util.add_dataset_arguments(parser, True, False, True)
|
train_util.add_dataset_arguments(parser, True, False, True)
|
||||||
train_util.add_training_arguments(parser, True)
|
train_util.add_training_arguments(parser, True)
|
||||||
|
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||||
train_util.add_sd_saving_arguments(parser)
|
train_util.add_sd_saving_arguments(parser)
|
||||||
train_util.add_optimizer_arguments(parser)
|
train_util.add_optimizer_arguments(parser)
|
||||||
config_util.add_config_arguments(parser)
|
config_util.add_config_arguments(parser)
|
||||||
|
|||||||
@@ -13,13 +13,14 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
from library import model_util
|
from library import deepspeed_utils, model_util
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
from library.train_util import (
|
from library.train_util import (
|
||||||
@@ -141,6 +142,7 @@ class NetworkTrainer:
|
|||||||
training_started_at = time.time()
|
training_started_at = time.time()
|
||||||
train_util.verify_training_args(args)
|
train_util.verify_training_args(args)
|
||||||
train_util.prepare_dataset_args(args, True)
|
train_util.prepare_dataset_args(args, True)
|
||||||
|
deepspeed_utils.prepare_deepspeed_args(args)
|
||||||
setup_logging(args, reset=True)
|
setup_logging(args, reset=True)
|
||||||
|
|
||||||
cache_latents = args.cache_latents
|
cache_latents = args.cache_latents
|
||||||
@@ -357,7 +359,7 @@ class NetworkTrainer:
|
|||||||
batch_size=1,
|
batch_size=1,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=collator,
|
collate_fn=collator,
|
||||||
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
|
num_workers=n_workers,
|
||||||
persistent_workers=args.persistent_data_loader_workers,
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -414,22 +416,17 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
||||||
if args.deepspeed:
|
if args.deepspeed:
|
||||||
training_models_dict = {}
|
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
||||||
if train_unet: training_models_dict["unet"] = unet
|
args,
|
||||||
if train_text_encoder: training_models_dict["text_encoder"] = text_encoders
|
unet=unet if train_unet else None,
|
||||||
training_models_dict["network"] = network
|
text_encoder1=text_encoders[0] if train_text_encoder else None,
|
||||||
|
text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
|
||||||
ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
|
network=network,
|
||||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
)
|
||||||
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
if train_unet: unet = ds_model.models["unet"]
|
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||||
if train_text_encoder:
|
)
|
||||||
text_encoder = ds_model.models["text_encoder"]
|
training_model = ds_model
|
||||||
if len(ds_model.models["text_encoder"]) > 1:
|
|
||||||
text_encoders = text_encoder
|
|
||||||
else:
|
|
||||||
text_encoders = [text_encoder]
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if train_unet:
|
if train_unet:
|
||||||
unet = accelerator.prepare(unet)
|
unet = accelerator.prepare(unet)
|
||||||
@@ -444,7 +441,10 @@ class NetworkTrainer:
|
|||||||
else:
|
else:
|
||||||
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
||||||
|
|
||||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
|
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
network, optimizer, train_dataloader, lr_scheduler
|
||||||
|
)
|
||||||
|
training_model = network
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
# according to TI example in Diffusers, train is required
|
# according to TI example in Diffusers, train is required
|
||||||
@@ -777,13 +777,13 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
current_step.value = global_step
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(network):
|
with accelerator.accumulate(training_model):
|
||||||
on_step_start(text_encoder, unet)
|
on_step_start(text_encoder, unet)
|
||||||
|
|
||||||
with torch.no_grad():
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
latents = batch["latents"].to(accelerator.device)
|
||||||
latents = batch["latents"].to(accelerator.device)
|
else:
|
||||||
else:
|
with torch.no_grad():
|
||||||
# latentに変換
|
# latentに変換
|
||||||
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
|
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
|
||||||
|
|
||||||
@@ -791,7 +791,7 @@ class NetworkTrainer:
|
|||||||
if torch.any(torch.isnan(latents)):
|
if torch.any(torch.isnan(latents)):
|
||||||
accelerator.print("NaN found in latents, replacing with zeros")
|
accelerator.print("NaN found in latents, replacing with zeros")
|
||||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||||
latents = latents * self.vae_scale_factor
|
latents = latents * self.vae_scale_factor
|
||||||
|
|
||||||
# get multiplier for each sample
|
# get multiplier for each sample
|
||||||
if network_has_multiplier:
|
if network_has_multiplier:
|
||||||
@@ -976,6 +976,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
train_util.add_sd_models_arguments(parser)
|
train_util.add_sd_models_arguments(parser)
|
||||||
train_util.add_dataset_arguments(parser, True, True, True)
|
train_util.add_dataset_arguments(parser, True, True, True)
|
||||||
train_util.add_training_arguments(parser, True)
|
train_util.add_training_arguments(parser, True)
|
||||||
|
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||||
train_util.add_optimizer_arguments(parser)
|
train_util.add_optimizer_arguments(parser)
|
||||||
config_util.add_config_arguments(parser)
|
config_util.add_config_arguments(parser)
|
||||||
custom_train_functions.add_custom_train_arguments(parser)
|
custom_train_functions.add_custom_train_arguments(parser)
|
||||||
|
|||||||
Reference in New Issue
Block a user