simplify code for schedule-free optimizer

Kohya S
2024-05-04 21:03:47 +09:00
parent c68712635c
commit 5fe9ded188
11 changed files with 211 additions and 202 deletions

View File

@@ -250,36 +250,32 @@ def train(args):
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
if args.deepspeed:
if args.train_text_encoder:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
if args.optimizer_type.lower().endswith("schedulefree"):
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
else:
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
training_models = [ds_model]
else:
# the accelerator apparently handles things for us
if args.train_text_encoder:
if args.optimizer_type.lower().endswith("schedulefree"):
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader
)
else:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader)
else:
if args.optimizer_type.lower().endswith("schedulefree"):
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
# experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
@@ -337,8 +333,7 @@ def train(args):
m.train()
for step, batch in enumerate(train_dataloader):
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(*training_models):
with torch.no_grad():
@@ -369,7 +364,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
# Predict the noise residual
with accelerator.autocast():
@@ -383,7 +380,9 @@ def train(args):
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
@@ -395,7 +394,9 @@ def train(args):
loss = loss.mean() # mean over batch dimension
else:
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -405,12 +406,10 @@ def train(args):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
lr_scheduler.step()
lr_scheduler.step() # if schedule-free optimizer is used, this is a no-op
optimizer.zero_grad(set_to_none=True)
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -490,7 +489,7 @@ def train(args):
accelerator.end_training()
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
del accelerator # delete this because memory is needed after this point
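For reference, a minimal sketch of the pattern the training scripts converge on in this commit: compute the flag once, prepare the lr_scheduler only when it is a real one, and route optimizer.train()/optimizer.eval() through no-op wrappers. Here accelerator, unet, optimizer, lr_scheduler, train_dataloader and compute_loss are placeholders, not code taken from the repository.

def training_loop_sketch(args, accelerator, unet, optimizer, lr_scheduler, train_dataloader, compute_loss):
    use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")

    # model/optimizer/dataloader are always prepared; the scheduler only if it is a real one
    unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
    if not use_schedule_free_optimizer:
        lr_scheduler = accelerator.prepare(lr_scheduler)

    # no-op wrappers so the loop body never re-checks the optimizer type
    if use_schedule_free_optimizer:
        optimizer_train_if_needed = lambda: optimizer.train()
        optimizer_eval_if_needed = lambda: optimizer.eval()
    else:
        optimizer_train_if_needed = lambda: None
        optimizer_eval_if_needed = lambda: None

    for step, batch in enumerate(train_dataloader):
        optimizer_train_if_needed()  # schedule-free optimizers must be in train mode while stepping
        loss = compute_loss(batch)
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()  # no-op dummy scheduler when a schedule-free optimizer is used
        optimizer.zero_grad(set_to_none=True)
        optimizer_eval_if_needed()  # expose averaged weights again for sampling/saving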

View File

@@ -4087,17 +4087,17 @@ def get_optimizer(args, trainable_params):
logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type.endswith("schedulefree".lower()):
try:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
if optimizer_type == "AdamWScheduleFree".lower():
optimizer_class = sf.AdamWScheduleFree
logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
elif optimizer_type == "SGDScheduleFree".lower():
optimizer_class = sf.SGDScheduleFree
logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
else:
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
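As a hedged usage example (not repository code), this is roughly how the schedulefree package is exercised once get_optimizer has selected it; the model and learning rate below are placeholders, and the train()/eval() calls are the ones the no-op wrappers delegate to.

import schedulefree
import torch.nn as nn

model = nn.Linear(8, 8)  # placeholder model
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)  # or schedulefree.SGDScheduleFree

optimizer.train()  # required before optimizer.step() calls
# ... forward / backward / optimizer.step() ...
optimizer.eval()   # required before evaluation, sampling, or saving a checkpoint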
@@ -4131,6 +4131,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
"""
Unified API to get any scheduler from its name.
"""
# support schedule-free optimizers
if args.optimizer_type.lower().endswith("schedulefree"):
# return a dummy scheduler: it has a 'step' method that does nothing
logger.info("use dummy scheduler for schedule free optimizer / schedule free optimizer用のダミースケジューラを使用します")
lr_scheduler = TYPE_TO_SCHEDULER_FUNCTION[SchedulerType.CONSTANT](optimizer)
lr_scheduler.step = lambda: None
return lr_scheduler
name = args.lr_scheduler
num_warmup_steps: Optional[int] = args.lr_warmup_steps
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
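The dummy-scheduler shortcut can be reproduced in isolation as a small sketch, using torch's LambdaLR in place of diffusers' constant schedule; make_dummy_scheduler is a hypothetical helper, not a function in this repository.

from torch.optim.lr_scheduler import LambdaLR

def make_dummy_scheduler(optimizer):
    # constant schedule (factor 1.0) whose step() is replaced by a no-op,
    # so the training loop can keep calling lr_scheduler.step() unconditionally
    scheduler = LambdaLR(optimizer, lambda _: 1.0)
    scheduler.step = lambda: None
    return scheduler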
@@ -4265,7 +4273,7 @@ def load_tokenizer(args: argparse.Namespace):
return tokenizer
def prepare_accelerator(args: argparse.Namespace):
def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
"""
this function also prepares deepspeed plugin
"""

View File

@@ -1,4 +1,4 @@
accelerate==0.25.0
accelerate==0.29.2
transformers==4.36.2
diffusers[torch]==0.25.0
ftfy==6.1.1

View File

@@ -407,6 +407,7 @@ def train(args):
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
@@ -415,14 +416,9 @@ def train(args):
text_encoder2=text_encoder2 if train_text_encoder2 else None,
)
# most ZeRO stages use optimizer partitioning, so we have to prepare the optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
if args.optimizer_type.lower().endswith("schedulefree"):
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
else:
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
training_models = [ds_model]
else:
@@ -433,10 +429,17 @@ def train(args):
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
if args.optimizer_type.lower().endswith("schedulefree"):
optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
else:
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
# move to CPU when caching the Text Encoder outputs
if args.cache_text_encoder_outputs:
@@ -511,8 +514,7 @@ def train(args):
m.train()
for step, batch in enumerate(train_dataloader):
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
@@ -592,7 +594,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -610,7 +614,9 @@ def train(args):
or args.masked_loss
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -626,7 +632,9 @@ def train(args):
loss = loss.mean() # mean over batch dimension
else:
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -636,12 +644,10 @@ def train(args):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
lr_scheduler.step()
lr_scheduler.step() # if schedule-free optimizer is used, this is a no-op
optimizer.zero_grad(set_to_none=True)
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -750,7 +756,7 @@ def train(args):
accelerator.end_training()
if args.save_state or args.save_state_on_train_end:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # delete this because memory is needed after this point
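The DeepSpeed branch follows the same shape; as the comment above notes, ZeRO partitions optimizer state, so the wrapped model and the optimizer must go through accelerator.prepare() together, while the (real) scheduler can still be prepared separately. A short sketch under the same placeholder assumptions:

ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
    lr_scheduler = accelerator.prepare(lr_scheduler)
training_models = [ds_model]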

View File

@@ -15,6 +15,7 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -286,19 +287,22 @@ def train(args):
unet.to(weight_dtype)
# the accelerator apparently handles things for us
if args.optimizer_type.lower().endswith("schedulefree"):
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
if args.gradient_checkpointing:
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
unet.train() # according to TI example in Diffusers, train is required -> can actually be removed now that the original U-Net is used
else:
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()
unet.eval()
# move to CPU when caching the Text Encoder outputs
@@ -398,8 +402,7 @@ def train(args):
current_epoch.value = epoch + 1
for step, batch in enumerate(train_dataloader):
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(unet):
with torch.no_grad():
@@ -449,7 +452,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -468,7 +473,9 @@ def train(args):
else:
target = noise
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # weight for each sample
@@ -491,12 +498,10 @@ def train(args):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
lr_scheduler.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:

View File

@@ -254,24 +254,27 @@ def train(args):
network.to(weight_dtype)
# the accelerator apparently handles things for us
if args.optimizer_type.lower().endswith("schedulefree"):
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
unet, network, optimizer, train_dataloader = accelerator.prepare(
unet, network, optimizer, train_dataloader
)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
network: control_net_lllite.ControlNetLLLite
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> can actually be removed now that the original U-Net is used
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
else:
unet.eval()
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()
network.prepare_grad_etc()
@@ -366,8 +369,7 @@ def train(args):
network.on_epoch_start() # train()
for step, batch in enumerate(train_dataloader):
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(network):
with torch.no_grad():
@@ -460,12 +462,10 @@ def train(args):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
lr_scheduler.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:

View File

@@ -13,6 +13,7 @@ from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -226,7 +227,7 @@ def train(args):
)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
if args.gradient_checkpointing:
@@ -276,14 +277,18 @@ def train(args):
controlnet.to(weight_dtype)
# the accelerator apparently handles things for us
if args.optimizer_type.lower().endswith("schedulefree"):
controlnet, optimizer, train_dataloader = accelerator.prepare(
controlnet, optimizer, train_dataloader
)
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
controlnet, optimizer, train_dataloader = accelerator.prepare(controlnet, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
controlnet, optimizer, train_dataloader, lr_scheduler
)
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
@@ -398,8 +403,7 @@ def train(args):
current_epoch.value = epoch + 1
for step, batch in enumerate(train_dataloader):
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(controlnet):
with torch.no_grad():
@@ -427,7 +431,9 @@ def train(args):
)
# Sample a random timestep for each image
timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device)
timesteps, huber_c = train_util.get_timesteps_and_huber_c(
args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device
)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
@@ -459,7 +465,9 @@ def train(args):
else:
target = noise
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # weight for each sample
@@ -479,8 +487,7 @@ def train(args):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:

View File

@@ -224,38 +224,36 @@ def train(args):
text_encoder.to(weight_dtype)
# the accelerator apparently handles things for us
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
if args.deepspeed:
if args.train_text_encoder:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
if args.optimizer_type.lower().endswith("schedulefree"):
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
else:
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
training_models = [ds_model]
else:
if train_text_encoder:
if args.optimizer_type.lower().endswith("schedulefree"):
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader
)
else:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader
)
training_models = [unet, text_encoder]
else:
if args.optimizer_type.lower().endswith("schedulefree"):
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
training_models = [unet]
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
@@ -320,8 +318,7 @@ def train(args):
text_encoder.train()
for step, batch in enumerate(train_dataloader):
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
optimizer_train_if_needed()
current_step.value = global_step
# stop training the Text Encoder at the specified number of steps
if global_step == args.stop_text_encoder_training:
@@ -361,7 +358,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
# Predict the noise residual
with accelerator.autocast():
@@ -373,7 +372,9 @@ def train(args):
else:
target = noise
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -399,12 +400,10 @@ def train(args):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
lr_scheduler.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:

View File

@@ -412,6 +412,7 @@ class NetworkTrainer:
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# the accelerator apparently handles things for us
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
@@ -420,14 +421,9 @@ class NetworkTrainer:
text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
network=network,
)
if args.optimizer_type.lower().endswith("schedulefree"):
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
else:
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
training_model = ds_model
else:
if train_unet:
@@ -442,21 +438,22 @@ class NetworkTrainer:
text_encoders = [text_encoder]
else:
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
if args.optimizer_type.lower().endswith("schedulefree"):
network, optimizer, train_dataloader = accelerator.prepare(
network, optimizer, train_dataloader
)
else:
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
network, optimizer, train_dataloader, lr_scheduler
)
network, optimizer, train_dataloader = accelerator.prepare(network, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
training_model = network
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
unet.train()
for t_enc in text_encoders:
@@ -467,8 +464,6 @@ class NetworkTrainer:
t_enc.text_model.embeddings.requires_grad_(True)
else:
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()
unet.eval()
for t_enc in text_encoders:
t_enc.eval()
@@ -819,8 +814,7 @@ class NetworkTrainer:
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
for step, batch in enumerate(train_dataloader):
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(training_model):
on_step_start(text_encoder, unet)
@@ -926,8 +920,7 @@ class NetworkTrainer:
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
lr_scheduler.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if args.scale_weight_norms:
@@ -938,8 +931,7 @@ class NetworkTrainer:
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:

View File

@@ -415,30 +415,28 @@ class TextualInversionTrainer:
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# the accelerator apparently handles things for us
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
if len(text_encoders) == 1:
if args.optimizer_type.lower().endswith("schedulefree"):
text_encoder_or_list, optimizer, train_dataloader = accelerator.prepare(
text_encoder_or_list, optimizer, train_dataloader
)
else:
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
)
text_encoder_or_list, optimizer, train_dataloader = accelerator.prepare(
text_encoder_or_list, optimizer, train_dataloader
)
elif len(text_encoders) == 2:
if args.optimizer_type.lower().endswith("schedulefree"):
text_encoder1, text_encoder2, optimizer, train_dataloader = accelerator.prepare(
text_encoders[0], text_encoders[1], optimizer, train_dataloader
)
else:
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
)
text_encoder1, text_encoder2, optimizer, train_dataloader = accelerator.prepare(
text_encoders[0], text_encoders[1], optimizer, train_dataloader
)
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
else:
raise NotImplementedError()
if not use_schedule_free_optimizer:
optimizer, lr_scheduler = accelerator.prepare(optimizer, lr_scheduler)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
index_no_updates_list = []
orig_embeds_params_list = []
@@ -462,12 +460,8 @@ class TextualInversionTrainer:
unet.to(accelerator.device, dtype=weight_dtype)
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
# TODO this should be unnecessary now that U-Net was replaced with the original one; verify later and remove
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
unet.train()
else:
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()
unet.eval()
if not cache_latents: # prepare the VAE because it is used when latents are not cached
@@ -571,8 +565,7 @@ class TextualInversionTrainer:
loss_total = 0
for step, batch in enumerate(train_dataloader):
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(text_encoders[0]):
with torch.no_grad():
@@ -604,7 +597,9 @@ class TextualInversionTrainer:
else:
target = noise
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -643,8 +638,7 @@ class TextualInversionTrainer:
index_no_updates
]
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:

View File

@@ -335,14 +335,20 @@ def train(args):
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# the accelerator apparently handles things for us
if args.optimizer_type.lower().endswith("schedulefree"):
text_encoder, optimizer, train_dataloader = accelerator.prepare(
text_encoder, optimizer, train_dataloader
)
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
text_encoder, optimizer, train_dataloader = accelerator.prepare(
text_encoder, optimizer, train_dataloader
)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, optimizer, train_dataloader, lr_scheduler
)
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
# logger.info(len(index_no_updates), torch.sum(index_no_updates))
@@ -359,12 +365,8 @@ def train(args):
unet.to(accelerator.device, dtype=weight_dtype)
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
unet.train()
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
else:
unet.eval()
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()
if not cache_latents:
vae.requires_grad_(False)
@@ -447,8 +449,7 @@ def train(args):
loss_total = 0
for step, batch in enumerate(train_dataloader):
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(text_encoder):
with torch.no_grad():
@@ -507,8 +508,7 @@ def train(args):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
lr_scheduler.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Let's make sure we don't update any embedding weights besides the newly added token
@@ -517,8 +517,7 @@ def train(args):
index_no_updates
]
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()
optimizer_eval_if_needed()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: