mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support deepspeed
This commit is contained in:
41
fine_tune.py
41
fine_tune.py
@@ -102,6 +102,7 @@ def train(args):
|
|||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
|
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
|
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||||
@@ -152,7 +153,7 @@ def train(args):
|
|||||||
|
|
||||||
# 学習を準備する
|
# 学習を準備する
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -187,7 +188,7 @@ def train(args):
|
|||||||
if not cache_latents:
|
if not cache_latents:
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
|
|
||||||
for m in training_models:
|
for m in training_models:
|
||||||
m.requires_grad_(True)
|
m.requires_grad_(True)
|
||||||
@@ -214,7 +215,7 @@ def train(args):
|
|||||||
batch_size=1,
|
batch_size=1,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=collator,
|
collate_fn=collator,
|
||||||
num_workers=n_workers,
|
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
|
||||||
persistent_workers=args.persistent_data_loader_workers,
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -240,13 +241,33 @@ def train(args):
|
|||||||
unet.to(weight_dtype)
|
unet.to(weight_dtype)
|
||||||
text_encoder.to(weight_dtype)
|
text_encoder.to(weight_dtype)
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
if args.deepspeed:
|
||||||
if args.train_text_encoder:
|
# wrapping model
|
||||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
class DeepSpeedModel(torch.nn.Module):
|
||||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
def __init__(self, unet, text_encoder, vae) -> None:
|
||||||
)
|
super().__init__()
|
||||||
else:
|
self.unet = unet
|
||||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
|
||||||
|
self.vae = vae
|
||||||
|
def get_models(self):
|
||||||
|
return self.unet, self.text_encoders, self.vae
|
||||||
|
|
||||||
|
unet.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
[t_enc.to(accelerator.device, dtype=weight_dtype) for t_enc in text_encoders]
|
||||||
|
ds_model = DeepSpeedModel(unet, text_encoders, vae)
|
||||||
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
# Now, ds_model is an instance of DeepSpeedEngine.
|
||||||
|
unet, text_encoders, vae = ds_model.get_models() # for compatiblility
|
||||||
|
vae.to(vae_dtype)
|
||||||
|
text_encoder = text_encoders
|
||||||
|
|
||||||
|
else: # acceleratorがなんかよろしくやってくれるらしい
|
||||||
|
if args.train_text_encoder:
|
||||||
|
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
|
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
|
||||||
|
from accelerate import DeepSpeedPlugin
|
||||||
import gc
|
import gc
|
||||||
import glob
|
import glob
|
||||||
import math
|
import math
|
||||||
@@ -3124,6 +3125,47 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
"--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
|
"--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
|
||||||
|
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
|
||||||
|
parser.add_argument(
|
||||||
|
"--zero_stage",
|
||||||
|
type=int, default=2,
|
||||||
|
choices=[0, 1, 2, 3],
|
||||||
|
help="Possible options are 0,1,2,3."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--offload_optimizer",
|
||||||
|
type=str, default=None,
|
||||||
|
choices=[None, "cpu", "nvme"],
|
||||||
|
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--offload_optimizer_nvme_path",
|
||||||
|
type=str, default=None,
|
||||||
|
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--offload_param_device",
|
||||||
|
type=str, default=None,
|
||||||
|
choices=[None, "cpu", "nvme"],
|
||||||
|
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--offload_param_nvme_path",
|
||||||
|
type=str, default=None,
|
||||||
|
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--zero3_init_flag",
|
||||||
|
action="store_true",
|
||||||
|
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
|
||||||
|
"Only applicable with ZeRO Stage-3."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--zero3_save_16bit_model",
|
||||||
|
action="store_true",
|
||||||
|
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3."
|
||||||
|
)
|
||||||
|
|
||||||
def verify_training_args(args: argparse.Namespace):
|
def verify_training_args(args: argparse.Namespace):
|
||||||
if args.v_parameterization and not args.v2:
|
if args.v_parameterization and not args.v2:
|
||||||
@@ -3912,6 +3954,17 @@ def prepare_accelerator(args: argparse.Namespace):
|
|||||||
else None,
|
else None,
|
||||||
)
|
)
|
||||||
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
|
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
|
||||||
|
deepspeed_plugin = None
|
||||||
|
if args.deepspeed:
|
||||||
|
deepspeed_plugin = DeepSpeedPlugin(
|
||||||
|
zero_stage=args.zero_stage,
|
||||||
|
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_clipping=args.max_grad_norm,
|
||||||
|
offload_optimizer=args.offload_optimizer, offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
|
||||||
|
offload_param_device=args.offload_param_device, offload_param_nvme_path=args.offload_param_nvme_path,
|
||||||
|
zero3_init_flag=args.zero3_init_flag, zero3_save_16bit_model=args.zero3_save_16bit_model,
|
||||||
|
)
|
||||||
|
deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size
|
||||||
|
|
||||||
accelerator = Accelerator(
|
accelerator = Accelerator(
|
||||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
mixed_precision=args.mixed_precision,
|
mixed_precision=args.mixed_precision,
|
||||||
@@ -3919,6 +3972,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
|||||||
project_dir=logging_dir,
|
project_dir=logging_dir,
|
||||||
kwargs_handlers=kwargs_handlers,
|
kwargs_handlers=kwargs_handlers,
|
||||||
dynamo_backend=dynamo_backend,
|
dynamo_backend=dynamo_backend,
|
||||||
|
deepspeed_plugin=deepspeed_plugin,
|
||||||
)
|
)
|
||||||
return accelerator
|
return accelerator
|
||||||
|
|
||||||
|
|||||||
@@ -354,7 +354,7 @@ def train(args):
|
|||||||
batch_size=1,
|
batch_size=1,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=collator,
|
collate_fn=collator,
|
||||||
num_workers=n_workers,
|
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
|
||||||
persistent_workers=args.persistent_data_loader_workers,
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -389,18 +389,37 @@ def train(args):
|
|||||||
text_encoder1.to(weight_dtype)
|
text_encoder1.to(weight_dtype)
|
||||||
text_encoder2.to(weight_dtype)
|
text_encoder2.to(weight_dtype)
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
if args.deepspeed:
|
||||||
if train_unet:
|
# Wrapping model for DeepSpeed
|
||||||
unet = accelerator.prepare(unet)
|
class DeepSpeedModel(torch.nn.Module):
|
||||||
if train_text_encoder1:
|
def __init__(self, unet, text_encoder, vae) -> None:
|
||||||
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
|
super().__init__()
|
||||||
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
self.unet = unet
|
||||||
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
|
||||||
text_encoder1 = accelerator.prepare(text_encoder1)
|
self.vae = vae
|
||||||
if train_text_encoder2:
|
|
||||||
text_encoder2 = accelerator.prepare(text_encoder2)
|
def get_models(self):
|
||||||
|
return self.unet, self.text_encoders, self.vae
|
||||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
text_encoders = [text_encoder1, text_encoder2]
|
||||||
|
unet.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
[t_enc.to(accelerator.device, dtype=weight_dtype) for t_enc in text_encoders]
|
||||||
|
ds_model = DeepSpeedModel(unet, text_encoders, vae)
|
||||||
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
# Now, ds_model is an instance of DeepSpeedEngine.
|
||||||
|
unet, text_encoders, vae = ds_model.get_models() # for compatiblility
|
||||||
|
vae.to(vae_dtype) # to avoid explicitly half-vae
|
||||||
|
text_encoder1, text_encoder2 = text_encoders[0], text_encoders[1]
|
||||||
|
else: # acceleratorがなんかよろしくやってくれるらしい
|
||||||
|
if train_unet:
|
||||||
|
unet = accelerator.prepare(unet)
|
||||||
|
if train_text_encoder1:
|
||||||
|
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
|
||||||
|
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
||||||
|
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
||||||
|
text_encoder1 = accelerator.prepare(text_encoder1)
|
||||||
|
if train_text_encoder2:
|
||||||
|
text_encoder2 = accelerator.prepare(text_encoder2)
|
||||||
|
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
|
|||||||
39
train_db.py
39
train_db.py
@@ -184,7 +184,7 @@ def train(args):
|
|||||||
batch_size=1,
|
batch_size=1,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=collator,
|
collate_fn=collator,
|
||||||
num_workers=n_workers,
|
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
|
||||||
persistent_workers=args.persistent_data_loader_workers,
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -214,15 +214,36 @@ def train(args):
|
|||||||
text_encoder.to(weight_dtype)
|
text_encoder.to(weight_dtype)
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
if train_text_encoder:
|
if args.deepspeed:
|
||||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
# wrapping model
|
||||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
class DeepSpeedModel(torch.nn.Module):
|
||||||
)
|
def __init__(self, unet, text_encoder, vae) -> None:
|
||||||
else:
|
super().__init__()
|
||||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
self.unet = unet
|
||||||
|
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
|
||||||
|
self.vae = vae
|
||||||
|
|
||||||
|
def get_models(self):
|
||||||
|
return self.unet, self.text_encoders, self.vae
|
||||||
|
|
||||||
if not train_text_encoder:
|
unet.to(accelerator.device, dtype=weight_dtype)
|
||||||
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
|
[t_enc.to(accelerator.device, dtype=weight_dtype) for t_enc in text_encoders]
|
||||||
|
ds_model = DeepSpeedModel(unet, text_encoders, vae)
|
||||||
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
# Now, ds_model is an instance of DeepSpeedEngine.
|
||||||
|
unet, text_encoders, vae = ds_model.get_models() # for compatiblility
|
||||||
|
vae.to(vae_dtype) # to avoid explicitly half-vae
|
||||||
|
text_encoder = text_encoders
|
||||||
|
else:
|
||||||
|
if train_text_encoder:
|
||||||
|
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
|
if not train_text_encoder:
|
||||||
|
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
|
|||||||
@@ -353,18 +353,26 @@ class NetworkTrainer:
|
|||||||
batch_size=1,
|
batch_size=1,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=collator,
|
collate_fn=collator,
|
||||||
num_workers=n_workers,
|
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
|
||||||
persistent_workers=args.persistent_data_loader_workers,
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 学習ステップ数を計算する
|
# 学習ステップ数を計算する
|
||||||
if args.max_train_epochs is not None:
|
if args.max_train_epochs is not None:
|
||||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
if args.deepspeed:
|
||||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||||
)
|
len(train_dataloader) / args.gradient_accumulation_steps
|
||||||
accelerator.print(
|
)
|
||||||
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
accelerator.print(
|
||||||
)
|
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||||
|
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||||
|
)
|
||||||
|
accelerator.print(
|
||||||
|
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||||
|
)
|
||||||
|
|
||||||
# データセット側にも学習ステップを送信
|
# データセット側にも学習ステップを送信
|
||||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||||
@@ -409,20 +417,42 @@ class NetworkTrainer:
|
|||||||
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
||||||
if train_unet:
|
if args.deepspeed:
|
||||||
unet = accelerator.prepare(unet)
|
# wrapping model
|
||||||
|
class DeepSpeedModel(torch.nn.Module):
|
||||||
|
def __init__(self, unet, text_encoder, vae, network) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.unet = unet
|
||||||
|
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
|
||||||
|
self.vae = vae
|
||||||
|
self.network = network
|
||||||
|
|
||||||
|
def get_models(self):
|
||||||
|
return self.unet, self.text_encoders, self.vae, self.network
|
||||||
|
|
||||||
|
unet.to(accelerator.device, dtype=unet_weight_dtype)
|
||||||
|
[t_enc.to(accelerator.device, dtype=te_weight_dtype) for t_enc in text_encoders]
|
||||||
|
ds_model = DeepSpeedModel(unet, text_encoders, vae, network)
|
||||||
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
# Now, ds_model is an instance of DeepSpeedEngine.
|
||||||
|
unet, text_encoders, vae, network = ds_model.get_models() # for compatiblility
|
||||||
|
vae.to(vae_dtype) # to avoid explicitly half-vae
|
||||||
|
text_encoder = text_encoders
|
||||||
else:
|
else:
|
||||||
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
|
if train_unet:
|
||||||
if train_text_encoder:
|
unet = accelerator.prepare(unet)
|
||||||
if len(text_encoders) > 1:
|
|
||||||
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
|
|
||||||
else:
|
else:
|
||||||
text_encoder = accelerator.prepare(text_encoder)
|
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
|
||||||
text_encoders = [text_encoder]
|
if train_text_encoder:
|
||||||
else:
|
if len(text_encoders) > 1:
|
||||||
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
|
||||||
|
else:
|
||||||
|
text_encoder = accelerator.prepare(text_encoder)
|
||||||
|
text_encoders = [text_encoder]
|
||||||
|
else:
|
||||||
|
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
||||||
|
|
||||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
|
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
# according to TI example in Diffusers, train is required
|
# according to TI example in Diffusers, train is required
|
||||||
|
|||||||
Reference in New Issue
Block a user