Mirror of https://github.com/kohya-ss/sd-scripts.git
Support new optimizer Schedule free (#1250)
* init
* use no schedule
* fix typo
* update for eval()
* fix typo
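Background for the diff below, as a hedged sketch: a schedule-free optimizer folds warmup and weight averaging into the optimizer itself, so no lr_scheduler object is created, prepared, or stepped. The package and argument names here (schedulefree, AdamWScheduleFree, warmup_steps) refer to the upstream schedulefree library and are assumptions about how --optimizer_type is resolved; the actual mapping lives in code not shown in this excerpt.

import torch
import schedulefree  # assumed dependency: pip install schedulefree

model = torch.nn.Linear(16, 16)

optimizer_type = "AdamWScheduleFree"  # hypothetical --optimizer_type value
if optimizer_type.lower().endswith("schedulefree"):
    # schedule-free: warmup is a constructor argument, no LR scheduler object exists
    optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3, warmup_steps=100)
    lr_scheduler = None
else:
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)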
fine_tune.py (35 changed lines)
@@ -255,18 +255,31 @@ def train(args):
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
         else:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]

     else:
         # the accelerator apparently takes care of this for us
         if args.train_text_encoder:
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader
+                )
+            else:
+                unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+                )
         else:
-            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+            else:
+                unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

     # experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
     if args.full_fp16:
@@ -324,6 +337,8 @@ def train(args):
             m.train()

         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(*training_models):
                 with torch.no_grad():
@@ -390,9 +405,13 @@ def train(args):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
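The two loop hunks above toggle the optimizer's mode each step: optimizer.train() before the gradient step and optimizer.eval() after it, with lr_scheduler.step() skipped entirely. A minimal self-contained sketch of that pattern, assuming the schedulefree API; the model, batches, and loss here are placeholders, not code from this repository:

import torch
import schedulefree  # assumed dependency: pip install schedulefree

model = torch.nn.Linear(16, 16)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)
batches = [torch.randn(4, 16) for _ in range(10)]

for step, batch in enumerate(batches):
    optimizer.train()                        # use the stepping weights for the gradient step
    loss = model(batch).pow(2).mean()
    loss.backward()
    optimizer.step()                         # no lr_scheduler.step() in the schedule-free case
    optimizer.zero_grad(set_to_none=True)
    optimizer.eval()                         # switch back to the averaged weights

# with eval() active, saving or validating between steps sees the averaged parameters
torch.save(model.state_dict(), "model.pt")

Calling eval() at the end of every step, as the diff does, means any code that samples, validates, or saves between steps reads the averaged weights rather than the raw stepping weights.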