mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix: call train()/eval() correctly for the schedule-free optimizer (#1605)
This commit is contained in:
@@ -11,6 +11,9 @@ The command to install PyTorch is as follows:
|
|||||||
|
|
||||||
### Recent Updates
|
### Recent Updates
|
||||||
|
|
||||||
|
Sep 18, 2024 (update 1):
|
||||||
|
Fixed an issue where train()/eval() was not called properly with the schedule-free optimizer. The schedule-free optimizer can be used in FLUX.1 LoRA training and fine-tuning for now.
|
||||||
|
|
||||||
Sep 18, 2024:
|
Sep 18, 2024:
|
||||||
|
|
||||||
- Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details.
|
- Schedule-free optimizer is added. Thanks to sdbds! See PR [#1600](https://github.com/kohya-ss/sd-scripts/pull/1600) for details.
|
||||||
|
|||||||
@@ -347,8 +347,13 @@ def train(args):
|
|||||||
|
|
||||||
logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
|
logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
|
||||||
|
|
||||||
|
if train_util.is_schedulefree_optimizer(optimizers[0], args):
|
||||||
|
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
|
||||||
|
optimizer_train_fn = lambda: None # dummy function
|
||||||
|
optimizer_eval_fn = lambda: None # dummy function
|
||||||
else:
|
else:
|
||||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
||||||
|
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||||
|
|
||||||
# prepare dataloader
|
# prepare dataloader
|
||||||
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||||
@@ -760,6 +765,7 @@ def train(args):
|
|||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
|
optimizer_eval_fn()
|
||||||
flux_train_utils.sample_images(
|
flux_train_utils.sample_images(
|
||||||
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
||||||
)
|
)
|
||||||
@@ -778,6 +784,7 @@ def train(args):
|
|||||||
global_step,
|
global_step,
|
||||||
accelerator.unwrap_model(flux),
|
accelerator.unwrap_model(flux),
|
||||||
)
|
)
|
||||||
|
optimizer_train_fn()
|
||||||
|
|
||||||
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||||
if len(accelerator.trackers) > 0:
|
if len(accelerator.trackers) > 0:
|
||||||
@@ -800,6 +807,7 @@ def train(args):
|
|||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
optimizer_eval_fn()
|
||||||
if args.save_every_n_epochs is not None:
|
if args.save_every_n_epochs is not None:
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
|
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
|
||||||
@@ -816,12 +824,14 @@ def train(args):
|
|||||||
flux_train_utils.sample_images(
|
flux_train_utils.sample_images(
|
||||||
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
||||||
)
|
)
|
||||||
|
optimizer_train_fn()
|
||||||
|
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
# if is_main_process:
|
# if is_main_process:
|
||||||
flux = accelerator.unwrap_model(flux)
|
flux = accelerator.unwrap_model(flux)
|
||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
optimizer_eval_fn()
|
||||||
|
|
||||||
if args.save_state or args.save_state_on_train_end:
|
if args.save_state or args.save_state_on_train_end:
|
||||||
train_util.save_state_on_train_end(args, accelerator)
|
train_util.save_state_on_train_end(args, accelerator)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import shutil
|
|||||||
import time
|
import time
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
NamedTuple,
|
NamedTuple,
|
||||||
@@ -4715,6 +4716,18 @@ def get_optimizer(args, trainable_params):
|
|||||||
return optimizer_name, optimizer_args, optimizer
|
return optimizer_name, optimizer_args, optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def get_optimizer_train_eval_fn(optimizer: Optimizer, args: argparse.Namespace) -> Tuple[Callable, Callable]:
|
||||||
|
if not is_schedulefree_optimizer(optimizer, args):
|
||||||
|
# return dummy func
|
||||||
|
return lambda: None, lambda: None
|
||||||
|
|
||||||
|
# get train and eval functions from optimizer
|
||||||
|
train_fn = optimizer.train
|
||||||
|
eval_fn = optimizer.eval
|
||||||
|
|
||||||
|
return train_fn, eval_fn
|
||||||
|
|
||||||
|
|
||||||
def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool:
|
def is_schedulefree_optimizer(optimizer: Optimizer, args: argparse.Namespace) -> bool:
|
||||||
return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
|
return args.optimizer_type.lower().endswith("schedulefree".lower()) # or args.optimizer_schedulefree_wrapper
|
||||||
|
|
||||||
|
|||||||
@@ -498,6 +498,7 @@ class NetworkTrainer:
|
|||||||
# accelerator.print(f"trainable_params: {k} = {v}")
|
# accelerator.print(f"trainable_params: {k} = {v}")
|
||||||
|
|
||||||
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||||
|
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||||
|
|
||||||
# prepare dataloader
|
# prepare dataloader
|
||||||
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
|
||||||
@@ -1199,6 +1200,7 @@ class NetworkTrainer:
|
|||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
|
optimizer_eval_fn()
|
||||||
self.sample_images(
|
self.sample_images(
|
||||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
|
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
|
||||||
)
|
)
|
||||||
@@ -1217,6 +1219,7 @@ class NetworkTrainer:
|
|||||||
if remove_step_no is not None:
|
if remove_step_no is not None:
|
||||||
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
||||||
remove_model(remove_ckpt_name)
|
remove_model(remove_ckpt_name)
|
||||||
|
optimizer_train_fn()
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||||
@@ -1243,6 +1246,7 @@ class NetworkTrainer:
|
|||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# 指定エポックごとにモデルを保存
|
# 指定エポックごとにモデルを保存
|
||||||
|
optimizer_eval_fn()
|
||||||
if args.save_every_n_epochs is not None:
|
if args.save_every_n_epochs is not None:
|
||||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||||
if is_main_process and saving:
|
if is_main_process and saving:
|
||||||
@@ -1258,6 +1262,7 @@ class NetworkTrainer:
|
|||||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||||
|
|
||||||
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
|
||||||
|
optimizer_train_fn()
|
||||||
|
|
||||||
# end of epoch
|
# end of epoch
|
||||||
|
|
||||||
@@ -1268,6 +1273,7 @@ class NetworkTrainer:
|
|||||||
network = accelerator.unwrap_model(network)
|
network = accelerator.unwrap_model(network)
|
||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
optimizer_eval_fn()
|
||||||
|
|
||||||
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
||||||
train_util.save_state_on_train_end(args, accelerator)
|
train_util.save_state_on_train_end(args, accelerator)
|
||||||
|
|||||||
Reference in New Issue
Block a user