mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
Support new optimizer Schedule free (#1250)
* init * use no schedule * fix typo * update for eval() * fix typo
This commit is contained in:
35
fine_tune.py
35
fine_tune.py
@@ -255,18 +255,31 @@ def train(args):
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
|
||||
else:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
ds_model, optimizer, train_dataloader = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader
|
||||
)
|
||||
else:
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
training_models = [ds_model]
|
||||
else:
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if args.train_text_encoder:
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader
|
||||
)
|
||||
else:
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
@@ -324,6 +337,8 @@ def train(args):
|
||||
m.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.train()
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(*training_models):
|
||||
with torch.no_grad():
|
||||
@@ -390,9 +405,13 @@ def train(args):
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
if not args.optimizer_type.lower().endswith("schedulefree"):
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.eval()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
@@ -3087,7 +3087,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする"
|
||||
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
@@ -4087,6 +4087,21 @@ def get_optimizer(args, trainable_params):
|
||||
logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
|
||||
optimizer_class = torch.optim.AdamW
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type.endswith("schedulefree".lower()):
|
||||
try:
|
||||
import schedulefree as sf
|
||||
except ImportError:
|
||||
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
|
||||
if optimizer_type == "AdamWScheduleFree".lower():
|
||||
optimizer_class = sf.AdamWScheduleFree
|
||||
logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "SGDScheduleFree".lower():
|
||||
optimizer_class = sf.SGDScheduleFree
|
||||
logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
if optimizer is None:
|
||||
# 任意のoptimizerを使う
|
||||
|
||||
@@ -415,9 +415,14 @@ def train(args):
|
||||
text_encoder2=text_encoder2 if train_text_encoder2 else None,
|
||||
)
|
||||
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
ds_model, optimizer, train_dataloader = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader
|
||||
)
|
||||
else:
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
training_models = [ds_model]
|
||||
|
||||
else:
|
||||
@@ -428,7 +433,10 @@ def train(args):
|
||||
text_encoder1 = accelerator.prepare(text_encoder1)
|
||||
if train_text_encoder2:
|
||||
text_encoder2 = accelerator.prepare(text_encoder2)
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
|
||||
else:
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||
if args.cache_text_encoder_outputs:
|
||||
@@ -503,6 +511,8 @@ def train(args):
|
||||
m.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.train()
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(*training_models):
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
@@ -626,9 +636,13 @@ def train(args):
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
if not args.optimizer_type.lower().endswith("schedulefree"):
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.eval()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
@@ -286,11 +286,19 @@ def train(args):
|
||||
unet.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.train()
|
||||
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
||||
|
||||
else:
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.eval()
|
||||
unet.eval()
|
||||
|
||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||
@@ -390,6 +398,8 @@ def train(args):
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.train()
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(unet):
|
||||
with torch.no_grad():
|
||||
@@ -481,9 +491,13 @@ def train(args):
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
if not args.optimizer_type.lower().endswith("schedulefree"):
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.eval()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
@@ -254,15 +254,24 @@ def train(args):
|
||||
network.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
unet, network, optimizer, train_dataloader = accelerator.prepare(
|
||||
unet, network, optimizer, train_dataloader
|
||||
)
|
||||
else:
|
||||
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
network: control_net_lllite.ControlNetLLLite
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.train()
|
||||
else:
|
||||
unet.eval()
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.eval()
|
||||
|
||||
network.prepare_grad_etc()
|
||||
|
||||
@@ -357,6 +366,8 @@ def train(args):
|
||||
network.on_epoch_start() # train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.train()
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(network):
|
||||
with torch.no_grad():
|
||||
@@ -449,9 +460,13 @@ def train(args):
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
if not args.optimizer_type.lower().endswith("schedulefree"):
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.eval()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
@@ -276,9 +276,14 @@ def train(args):
|
||||
controlnet.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
controlnet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
controlnet, optimizer, train_dataloader = accelerator.prepare(
|
||||
controlnet, optimizer, train_dataloader
|
||||
)
|
||||
else:
|
||||
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
controlnet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
unet.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
@@ -393,6 +398,8 @@ def train(args):
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.train()
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(controlnet):
|
||||
with torch.no_grad():
|
||||
@@ -472,6 +479,9 @@ def train(args):
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.eval()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
35
train_db.py
35
train_db.py
@@ -229,19 +229,32 @@ def train(args):
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
|
||||
else:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
ds_model, optimizer, train_dataloader = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader
|
||||
)
|
||||
else:
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
training_models = [ds_model]
|
||||
|
||||
else:
|
||||
if train_text_encoder:
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader
|
||||
)
|
||||
else:
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
training_models = [unet, text_encoder]
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
training_models = [unet]
|
||||
|
||||
if not train_text_encoder:
|
||||
@@ -307,6 +320,8 @@ def train(args):
|
||||
text_encoder.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.train()
|
||||
current_step.value = global_step
|
||||
# 指定したステップ数でText Encoderの学習を止める
|
||||
if global_step == args.stop_text_encoder_training:
|
||||
@@ -384,9 +399,13 @@ def train(args):
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
if not args.optimizer_type.lower().endswith("schedulefree"):
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.eval()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
@@ -420,9 +420,14 @@ class NetworkTrainer:
|
||||
text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
|
||||
network=network,
|
||||
)
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
ds_model, optimizer, train_dataloader = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader
|
||||
)
|
||||
else:
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
training_model = ds_model
|
||||
else:
|
||||
if train_unet:
|
||||
@@ -437,15 +442,23 @@ class NetworkTrainer:
|
||||
text_encoders = [text_encoder]
|
||||
else:
|
||||
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
||||
|
||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
network, optimizer, train_dataloader = accelerator.prepare(
|
||||
network, optimizer, train_dataloader
|
||||
)
|
||||
else:
|
||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
training_model = network
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
# according to TI example in Diffusers, train is required
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.train()
|
||||
unet.train()
|
||||
|
||||
for t_enc in text_encoders:
|
||||
t_enc.train()
|
||||
|
||||
@@ -454,6 +467,8 @@ class NetworkTrainer:
|
||||
t_enc.text_model.embeddings.requires_grad_(True)
|
||||
|
||||
else:
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.eval()
|
||||
unet.eval()
|
||||
for t_enc in text_encoders:
|
||||
t_enc.eval()
|
||||
@@ -804,6 +819,8 @@ class NetworkTrainer:
|
||||
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.train()
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(training_model):
|
||||
on_step_start(text_encoder, unet)
|
||||
@@ -909,7 +926,8 @@ class NetworkTrainer:
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
if not args.optimizer_type.lower().endswith("schedulefree"):
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if args.scale_weight_norms:
|
||||
@@ -920,6 +938,9 @@ class NetworkTrainer:
|
||||
else:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.eval()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
@@ -416,14 +416,24 @@ class TextualInversionTrainer:
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if len(text_encoders) == 1:
|
||||
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
text_encoder_or_list, optimizer, train_dataloader = accelerator.prepare(
|
||||
text_encoder_or_list, optimizer, train_dataloader
|
||||
)
|
||||
else:
|
||||
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
elif len(text_encoders) == 2:
|
||||
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
text_encoder1, text_encoder2, optimizer, train_dataloader = accelerator.prepare(
|
||||
text_encoders[0], text_encoders[1], optimizer, train_dataloader
|
||||
)
|
||||
else:
|
||||
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
|
||||
|
||||
@@ -452,8 +462,12 @@ class TextualInversionTrainer:
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
|
||||
# TODO U-Netをオリジナルに置き換えたのでいらないはずなので、後で確認して消す
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.train()
|
||||
unet.train()
|
||||
else:
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.eval()
|
||||
unet.eval()
|
||||
|
||||
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
|
||||
@@ -557,6 +571,8 @@ class TextualInversionTrainer:
|
||||
loss_total = 0
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.train()
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(text_encoders[0]):
|
||||
with torch.no_grad():
|
||||
@@ -627,6 +643,9 @@ class TextualInversionTrainer:
|
||||
index_no_updates
|
||||
]
|
||||
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.eval()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
@@ -335,9 +335,14 @@ def train(args):
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
text_encoder, optimizer, train_dataloader = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader
|
||||
)
|
||||
else:
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
|
||||
# logger.info(len(index_no_updates), torch.sum(index_no_updates))
|
||||
@@ -354,8 +359,12 @@ def train(args):
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
|
||||
unet.train()
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.train()
|
||||
else:
|
||||
unet.eval()
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.eval()
|
||||
|
||||
if not cache_latents:
|
||||
vae.requires_grad_(False)
|
||||
@@ -438,6 +447,8 @@ def train(args):
|
||||
loss_total = 0
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.train()
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(text_encoder):
|
||||
with torch.no_grad():
|
||||
@@ -496,7 +507,8 @@ def train(args):
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
if not args.optimizer_type.lower().endswith("schedulefree"):
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
@@ -505,6 +517,9 @@ def train(args):
|
||||
index_no_updates
|
||||
]
|
||||
|
||||
if (args.optimizer_type.lower().endswith("schedulefree")):
|
||||
optimizer.eval()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
Reference in New Issue
Block a user