Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-10 15:00:23 +00:00

Compare commits: main ... scheduler- (4 commits)
| Author | SHA1 | Date |
|---|---|---|
| | f33e155c5b | |
| | c1ef6dcabc | |
| | 5fe9ded188 | |
| | c68712635c | |
fine_tune.py (42 lines changed)

@@ -250,23 +250,32 @@ def train(args):
         unet.to(weight_dtype)
         text_encoder.to(weight_dtype)
 
+    use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
     if args.deepspeed:
         if args.train_text_encoder:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
         else:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
+        if not use_schedule_free_optimizer:
+            lr_scheduler = accelerator.prepare(lr_scheduler)
         training_models = [ds_model]
     else:
         # acceleratorがなんかよろしくやってくれるらしい / apparently accelerator handles this for us
        if args.train_text_encoder:
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-            )
+            unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader)
         else:
-            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+            unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+        if not use_schedule_free_optimizer:
+            lr_scheduler = accelerator.prepare(lr_scheduler)
+
+    # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+    if use_schedule_free_optimizer:
+        optimizer_train_if_needed = lambda: optimizer.train()
+        optimizer_eval_if_needed = lambda: optimizer.eval()
+    else:
+        optimizer_train_if_needed = lambda: None
+        optimizer_eval_if_needed = lambda: None
 
     # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする / experimental feature: train in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
     if args.full_fp16:

@@ -324,6 +333,7 @@ def train(args):
             m.train()
 
         for step, batch in enumerate(train_dataloader):
+            optimizer_train_if_needed()
             current_step.value = global_step
             with accelerator.accumulate(*training_models):
                 with torch.no_grad():

@@ -354,7 +364,9 @@ def train(args):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
 
                 # Predict the noise residual
                 with accelerator.autocast():

@@ -368,7 +380,9 @@ def train(args):
 
                 if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
                     # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    )
                     loss = loss.mean([1, 2, 3])
 
                     if args.min_snr_gamma:

@@ -380,7 +394,9 @@ def train(args):
 
                     loss = loss.mean()  # mean over batch dimension
                 else:
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+                    )
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:

@@ -390,9 +406,11 @@ def train(args):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()
+                lr_scheduler.step()  # if schedule-free optimizer is used, this is a no-op
                 optimizer.zero_grad(set_to_none=True)
 
+            optimizer_eval_if_needed()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)

@@ -471,7 +489,7 @@ def train(args):
 
     accelerator.end_training()
 
     if is_main_process and (args.save_state or args.save_state_on_train_end):
         train_util.save_state_on_train_end(args, accelerator)
 
     del accelerator  # この後メモリを使うのでこれは消す / delete this because memory is needed afterwards
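The same pattern recurs in every trainer in this compare: the LR scheduler is excluded from accelerator.prepare() when a schedule-free optimizer is selected, and each training step is bracketed by optimizer.train() / optimizer.eval(). A minimal, self-contained sketch of that loop, assuming the schedulefree package is installed; the model, data, and hyperparameters below are illustrative, not taken from the diff:

    # Sketch of the schedule-free pattern introduced above (all names illustrative).
    import torch
    import schedulefree
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = torch.nn.Linear(16, 1)
    optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)
    dataset = torch.utils.data.TensorDataset(torch.randn(64, 16), torch.randn(64, 1))
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

    # no lr_scheduler is passed to prepare(): the schedule lives inside the optimizer
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    # accelerate may wrap the optimizer, so unwrap before calling train()/eval()
    inner_opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer

    for x, y in dataloader:
        inner_opt.train()  # schedule-free optimizers must be in train mode to step
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        inner_opt.eval()  # switch to the averaged weights before eval/saving

Schedule-free optimizers evaluate with a different (averaged) set of weights than they train with, which is why the mode switch around each step matters and why a conventional LR scheduler becomes unnecessary.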
library/train_util.py (file name inferred from the hunk context)

@@ -3087,7 +3087,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     )
     parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
     parser.add_argument(
-        "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする"
+        "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -4088,6 +4088,21 @@ def get_optimizer(args, trainable_params):
         optimizer_class = torch.optim.AdamW
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
+    elif optimizer_type.endswith("schedulefree".lower()):
+        try:
+            import schedulefree as sf
+        except ImportError:
+            raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
+        if optimizer_type == "AdamWScheduleFree".lower():
+            optimizer_class = sf.AdamWScheduleFree
+            logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
+        elif optimizer_type == "SGDScheduleFree".lower():
+            optimizer_class = sf.SGDScheduleFree
+            logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
+        else:
+            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
     if optimizer is None:
         # 任意のoptimizerを使う / use an arbitrary optimizer
         optimizer_type = args.optimizer_type  # lowerでないやつ(微妙) / the non-lowercased one (iffy)
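Selection is case-insensitive and keyed on the "schedulefree" suffix, so --optimizer_type AdamWScheduleFree (or adamwschedulefree) resolves to sf.AdamWScheduleFree, and extra constructor arguments flow through the existing --optimizer_args / optimizer_kwargs mechanism. Since the dispatch ultimately just instantiates the schedulefree classes, the equivalent direct construction looks like this (a sketch; warmup_steps and weight_decay are real AdamWScheduleFree parameters, the values are illustrative):

    import torch
    import schedulefree as sf

    params = [torch.nn.Parameter(torch.zeros(4))]
    # what get_optimizer() resolves to for --optimizer_type AdamWScheduleFree
    optimizer = sf.AdamWScheduleFree(params, lr=1e-3, warmup_steps=100, weight_decay=0.01)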
@@ -4116,6 +4131,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     """
     Unified API to get any scheduler from its name.
     """
+    # supports schedule free optimizer
+    if args.optimizer_type.lower().endswith("schedulefree"):
+        # return dummy scheduler: it has 'step' method but does nothing
+        logger.info("use dummy scheduler for schedule free optimizer / schedule free optimizer用のダミースケジューラを使用します")
+        lr_scheduler = TYPE_TO_SCHEDULER_FUNCTION[SchedulerType.CONSTANT](optimizer)
+        lr_scheduler.step = lambda: None
+        return lr_scheduler
+
     name = args.lr_scheduler
     num_warmup_steps: Optional[int] = args.lr_warmup_steps
     num_training_steps = args.max_train_steps * num_processes  # * args.gradient_accumulation_steps
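Because every training loop in the repository calls lr_scheduler.step() unconditionally, returning a patched constant scheduler is simpler than guarding each call site. A standalone sketch of the same trick (names illustrative):

    # Sketch: constant scheduler whose step() is monkey-patched into a no-op.
    import torch
    from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION

    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    lr_scheduler = TYPE_TO_SCHEDULER_FUNCTION[SchedulerType.CONSTANT](optimizer)
    lr_scheduler.step = lambda: None  # instance attribute shadows the real step()

    before = optimizer.param_groups[0]["lr"]
    for _ in range(10):
        lr_scheduler.step()
    assert optimizer.param_groups[0]["lr"] == before  # LR left untouched

A constant schedule would keep the LR fixed anyway; the patch additionally stops the scheduler from advancing its internal step counter, making step() a true no-op.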
@@ -4250,7 +4273,7 @@ def load_tokenizer(args: argparse.Namespace):
     return tokenizer
 
 
-def prepare_accelerator(args: argparse.Namespace):
+def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
     """
     this function also prepares deepspeed plugin
     """
requirements.txt (file name inferred from content)

@@ -1,4 +1,4 @@
-accelerate==0.25.0
+accelerate==0.30.0
 transformers==4.36.2
 diffusers[torch]==0.25.0
 ftfy==6.1.1

@@ -9,6 +9,7 @@ pytorch-lightning==1.9.0
 bitsandbytes==0.43.0
 prodigyopt==1.0
 lion-pytorch==0.0.6
+schedulefree==1.2.5
 tensorboard
 safetensors==0.4.2
 # gradio==3.16.2
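For an already-provisioned environment, the two dependency changes above can be applied directly with the pinned versions:

    pip install accelerate==0.30.0 schedulefree==1.2.5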
sdxl_train.py (inferred)

@@ -407,6 +407,7 @@ def train(args):
         text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
         text_encoder1.text_model.final_layer_norm.requires_grad_(False)
 
+    use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
     if args.deepspeed:
         ds_model = deepspeed_utils.prepare_deepspeed_model(
             args,

@@ -415,9 +416,9 @@ def train(args):
             text_encoder2=text_encoder2 if train_text_encoder2 else None,
         )
         # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
+        if not use_schedule_free_optimizer:
+            lr_scheduler = accelerator.prepare(lr_scheduler)
         training_models = [ds_model]
 
     else:

@@ -428,7 +429,17 @@ def train(args):
             text_encoder1 = accelerator.prepare(text_encoder1)
         if train_text_encoder2:
             text_encoder2 = accelerator.prepare(text_encoder2)
-        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+        if not use_schedule_free_optimizer:
+            lr_scheduler = accelerator.prepare(lr_scheduler)
+        optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
+
+    # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+    if use_schedule_free_optimizer:
+        optimizer_train_if_needed = lambda: optimizer.train()
+        optimizer_eval_if_needed = lambda: optimizer.eval()
+    else:
+        optimizer_train_if_needed = lambda: None
+        optimizer_eval_if_needed = lambda: None
 
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する / move to CPU when caching the TextEncoder outputs
     if args.cache_text_encoder_outputs:

@@ -503,6 +514,7 @@ def train(args):
             m.train()
 
         for step, batch in enumerate(train_dataloader):
+            optimizer_train_if_needed()
             current_step.value = global_step
             with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:

@@ -582,7 +594,9 @@ def train(args):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
 
                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

@@ -600,7 +614,9 @@ def train(args):
                     or args.masked_loss
                 ):
                     # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    )
                     if args.masked_loss:
                         loss = apply_masked_loss(loss, batch)
                     loss = loss.mean([1, 2, 3])

@@ -616,7 +632,9 @@ def train(args):
 
                     loss = loss.mean()  # mean over batch dimension
                 else:
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+                    )
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:

@@ -626,9 +644,11 @@ def train(args):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 optimizer.step()
-                lr_scheduler.step()
+                lr_scheduler.step()  # if schedule-free optimizer is used, this is a no-op
                 optimizer.zero_grad(set_to_none=True)
 
+            optimizer_eval_if_needed()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)

@@ -736,7 +756,7 @@ def train(args):
 
     accelerator.end_training()
 
     if args.save_state or args.save_state_on_train_end:
         train_util.save_state_on_train_end(args, accelerator)
 
     del accelerator  # この後メモリを使うのでこれは消す / delete this because memory is needed afterwards
sdxl_train_control_net_lllite.py (inferred)

@@ -15,6 +15,7 @@ from tqdm import tqdm
 
 import torch
 from library.device_utils import init_ipex, clean_memory_on_device
 
 init_ipex()
 
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -286,7 +287,18 @@ def train(args):
     unet.to(weight_dtype)
 
     # acceleratorがなんかよろしくやってくれるらしい / apparently accelerator handles this for us
-    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+    use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
+    unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+    if not use_schedule_free_optimizer:
+        lr_scheduler = accelerator.prepare(lr_scheduler)
+
+    # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+    if use_schedule_free_optimizer:
+        optimizer_train_if_needed = lambda: optimizer.train()
+        optimizer_eval_if_needed = lambda: optimizer.eval()
+    else:
+        optimizer_train_if_needed = lambda: None
+        optimizer_eval_if_needed = lambda: None
 
     if args.gradient_checkpointing:
         unet.train()  # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる / since this is now the original U-Net, this could actually be removed

@@ -390,6 +402,7 @@ def train(args):
         current_epoch.value = epoch + 1
 
         for step, batch in enumerate(train_dataloader):
+            optimizer_train_if_needed()
             current_step.value = global_step
             with accelerator.accumulate(unet):
                 with torch.no_grad():

@@ -439,7 +452,9 @@ def train(args):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
 
                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

@@ -458,7 +473,9 @@ def train(args):
                 else:
                     target = noise
 
-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                loss = train_util.conditional_loss(
+                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                )
                 loss = loss.mean([1, 2, 3])
 
                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight / per-sample weight

@@ -484,6 +501,8 @@ def train(args):
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+            optimizer_eval_if_needed()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
sdxl_train_control_net_lllite_old.py (inferred)

@@ -12,6 +12,7 @@ from tqdm import tqdm
 
 import torch
 from library.device_utils import init_ipex, clean_memory_on_device
 
 init_ipex()
 
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -254,9 +255,19 @@ def train(args):
     network.to(weight_dtype)
 
     # acceleratorがなんかよろしくやってくれるらしい / apparently accelerator handles this for us
-    unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, network, optimizer, train_dataloader, lr_scheduler
-    )
+    use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
+    unet, network, optimizer, train_dataloader = accelerator.prepare(unet, network, optimizer, train_dataloader)
+    if not use_schedule_free_optimizer:
+        lr_scheduler = accelerator.prepare(lr_scheduler)
+
+    # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+    if use_schedule_free_optimizer:
+        optimizer_train_if_needed = lambda: optimizer.train()
+        optimizer_eval_if_needed = lambda: optimizer.eval()
+    else:
+        optimizer_train_if_needed = lambda: None
+        optimizer_eval_if_needed = lambda: None
 
     network: control_net_lllite.ControlNetLLLite
 
     if args.gradient_checkpointing:

@@ -357,6 +368,7 @@ def train(args):
         network.on_epoch_start()  # train()
 
         for step, batch in enumerate(train_dataloader):
+            optimizer_train_if_needed()
             current_step.value = global_step
             with accelerator.accumulate(network):
                 with torch.no_grad():

@@ -406,7 +418,9 @@ def train(args):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
 
                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype

@@ -426,7 +440,9 @@ def train(args):
                 else:
                     target = noise
 
-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                loss = train_util.conditional_loss(
+                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                )
                 loss = loss.mean([1, 2, 3])
 
                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight / per-sample weight

@@ -452,6 +468,8 @@ def train(args):
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+            optimizer_eval_if_needed()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
train_controlnet.py (inferred)

@@ -13,6 +13,7 @@ from tqdm import tqdm
 
 import torch
 from library import deepspeed_utils
 from library.device_utils import init_ipex, clean_memory_on_device
 
 init_ipex()
 
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -226,7 +227,7 @@ def train(args):
     )
     vae.to("cpu")
     clean_memory_on_device(accelerator.device)
 
     accelerator.wait_for_everyone()
 
     if args.gradient_checkpointing:

@@ -276,9 +277,18 @@ def train(args):
     controlnet.to(weight_dtype)
 
     # acceleratorがなんかよろしくやってくれるらしい / apparently accelerator handles this for us
-    controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        controlnet, optimizer, train_dataloader, lr_scheduler
-    )
+    use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
+    controlnet, optimizer, train_dataloader = accelerator.prepare(controlnet, optimizer, train_dataloader)
+    if not use_schedule_free_optimizer:
+        lr_scheduler = accelerator.prepare(lr_scheduler)
+
+    # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+    if use_schedule_free_optimizer:
+        optimizer_train_if_needed = lambda: optimizer.train()
+        optimizer_eval_if_needed = lambda: optimizer.eval()
+    else:
+        optimizer_train_if_needed = lambda: None
+        optimizer_eval_if_needed = lambda: None
 
     unet.requires_grad_(False)
     text_encoder.requires_grad_(False)

@@ -393,6 +403,7 @@ def train(args):
         current_epoch.value = epoch + 1
 
         for step, batch in enumerate(train_dataloader):
+            optimizer_train_if_needed()
             current_step.value = global_step
             with accelerator.accumulate(controlnet):
                 with torch.no_grad():

@@ -420,7 +431,9 @@ def train(args):
                 )
 
                 # Sample a random timestep for each image
-                timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device)
+                timesteps, huber_c = train_util.get_timesteps_and_huber_c(
+                    args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device
+                )
 
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)

@@ -452,7 +465,9 @@ def train(args):
                 else:
                     target = noise
 
-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                loss = train_util.conditional_loss(
+                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                )
                 loss = loss.mean([1, 2, 3])
 
                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight / per-sample weight

@@ -472,6 +487,8 @@ def train(args):
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+            optimizer_eval_if_needed()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
train_db.py (34 lines changed)

@@ -224,25 +224,34 @@ def train(args):
         text_encoder.to(weight_dtype)
 
     # acceleratorがなんかよろしくやってくれるらしい / apparently accelerator handles this for us
+    use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
     if args.deepspeed:
         if args.train_text_encoder:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
         else:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
+        if not use_schedule_free_optimizer:
+            lr_scheduler = accelerator.prepare(lr_scheduler)
         training_models = [ds_model]
 
     else:
         if train_text_encoder:
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-            )
+            unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader)
             training_models = [unet, text_encoder]
         else:
-            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+            unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
             training_models = [unet]
+        if not use_schedule_free_optimizer:
+            lr_scheduler = accelerator.prepare(lr_scheduler)
+
+    # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+    if use_schedule_free_optimizer:
+        optimizer_train_if_needed = lambda: optimizer.train()
+        optimizer_eval_if_needed = lambda: optimizer.eval()
+    else:
+        optimizer_train_if_needed = lambda: None
+        optimizer_eval_if_needed = lambda: None
 
     if not train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)  # to avoid 'cpu' vs 'cuda' error

@@ -307,6 +316,7 @@ def train(args):
             text_encoder.train()
 
         for step, batch in enumerate(train_dataloader):
+            optimizer_train_if_needed()
             current_step.value = global_step
             # 指定したステップ数でText Encoderの学習を止める / stop training the Text Encoder at the specified step count
             if global_step == args.stop_text_encoder_training:

@@ -346,7 +356,9 @@ def train(args):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
 
                 # Predict the noise residual
                 with accelerator.autocast():

@@ -358,7 +370,9 @@ def train(args):
                 else:
                     target = noise
 
-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                loss = train_util.conditional_loss(
+                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                )
                 if args.masked_loss:
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])

@@ -387,6 +401,8 @@ def train(args):
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+            optimizer_eval_if_needed()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
train_network.py (inferred)

@@ -412,6 +412,7 @@ class NetworkTrainer:
             t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
 
         # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
+        use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
         if args.deepspeed:
             ds_model = deepspeed_utils.prepare_deepspeed_model(
                 args,

@@ -420,9 +421,9 @@ class NetworkTrainer:
                 text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
                 network=network,
             )
-            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                ds_model, optimizer, train_dataloader, lr_scheduler
-            )
+            ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
+            if not use_schedule_free_optimizer:
+                lr_scheduler = accelerator.prepare(lr_scheduler)
             training_model = ds_model
         else:
             if train_unet:

@@ -438,14 +439,23 @@ class NetworkTrainer:
             else:
                 pass  # if text_encoder is not trained, no need to prepare. and device and dtype are already set
 
-            network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                network, optimizer, train_dataloader, lr_scheduler
-            )
+            network, optimizer, train_dataloader = accelerator.prepare(network, optimizer, train_dataloader)
+            if not use_schedule_free_optimizer:
+                lr_scheduler = accelerator.prepare(lr_scheduler)
             training_model = network
 
+        # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+        if use_schedule_free_optimizer:
+            optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+            optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
+        else:
+            optimizer_train_if_needed = lambda: None
+            optimizer_eval_if_needed = lambda: None
+
         if args.gradient_checkpointing:
             # according to TI example in Diffusers, train is required
             unet.train()
 
             for t_enc in text_encoders:
                 t_enc.train()
 

@@ -804,6 +814,7 @@ class NetworkTrainer:
             accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
 
             for step, batch in enumerate(train_dataloader):
+                optimizer_train_if_needed()
                 current_step.value = global_step
                 with accelerator.accumulate(training_model):
                     on_step_start(text_encoder, unet)

@@ -920,6 +931,8 @@ class NetworkTrainer:
                     else:
                         keys_scaled, mean_norm, maximum_norm = None, None, None
 
+                optimizer_eval_if_needed()
+
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
                     progress_bar.update(1)
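Unlike the simpler trainers above, NetworkTrainer reaches through accelerate's optimizer wrapper before calling the schedule-free mode switches: after accelerator.prepare(), the optimizer is an AcceleratedOptimizer that keeps the real optimizer on its .optimizer attribute, and calling train()/eval() on the wrapper may fail, so the lambdas unwrap first. A minimal sketch of that unwrap, assuming schedulefree is installed (model and sizes are illustrative):

    # Sketch: switching a schedule-free optimizer's mode through accelerate's wrapper.
    import torch
    import schedulefree
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = torch.nn.Linear(8, 8)
    optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)
    model, optimizer = accelerator.prepare(model, optimizer)  # optimizer is now wrapped

    # the original optimizer is stored on .optimizer; fall back if not wrapped
    inner = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
    inner.train()  # before optimizer.step()
    inner.eval()   # before validation or saving

The hasattr guard keeps the lambdas safe in code paths where the optimizer is not wrapped.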
train_textual_inversion.py (inferred)

@@ -415,20 +415,28 @@ class TextualInversionTrainer:
         lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
         # acceleratorがなんかよろしくやってくれるらしい / apparently accelerator handles this for us
+        use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
         if len(text_encoders) == 1:
-            text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
-            )
+            text_encoder_or_list, optimizer, train_dataloader = accelerator.prepare(
+                text_encoder_or_list, optimizer, train_dataloader
+            )
 
         elif len(text_encoders) == 2:
-            text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
-            )
+            text_encoder1, text_encoder2, optimizer, train_dataloader = accelerator.prepare(
+                text_encoders[0], text_encoders[1], optimizer, train_dataloader
+            )
 
             text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
 
         else:
             raise NotImplementedError()
+        if not use_schedule_free_optimizer:
+            optimizer, lr_scheduler = accelerator.prepare(optimizer, lr_scheduler)
+
+        # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+        if use_schedule_free_optimizer:
+            optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
+            optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
+        else:
+            optimizer_train_if_needed = lambda: None
+            optimizer_eval_if_needed = lambda: None
 
         index_no_updates_list = []
         orig_embeds_params_list = []

@@ -557,6 +565,7 @@ class TextualInversionTrainer:
             loss_total = 0
 
             for step, batch in enumerate(train_dataloader):
+                optimizer_train_if_needed()
                 current_step.value = global_step
                 with accelerator.accumulate(text_encoders[0]):
                     with torch.no_grad():

@@ -588,7 +597,9 @@ class TextualInversionTrainer:
                     else:
                         target = noise
 
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    )
                     if args.masked_loss:
                         loss = apply_masked_loss(loss, batch)
                     loss = loss.mean([1, 2, 3])

@@ -627,6 +638,8 @@ class TextualInversionTrainer:
                             index_no_updates
                         ]
 
+                optimizer_eval_if_needed()
+
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
                     progress_bar.update(1)
train_textual_inversion_XTI.py (inferred)

@@ -335,9 +335,18 @@ def train(args):
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
     # acceleratorがなんかよろしくやってくれるらしい / apparently accelerator handles this for us
-    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, lr_scheduler
-    )
+    use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
+    text_encoder, optimizer, train_dataloader = accelerator.prepare(text_encoder, optimizer, train_dataloader)
+    if not use_schedule_free_optimizer:
+        lr_scheduler = accelerator.prepare(lr_scheduler)
+
+    # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+    if use_schedule_free_optimizer:
+        optimizer_train_if_needed = lambda: optimizer.train()
+        optimizer_eval_if_needed = lambda: optimizer.eval()
+    else:
+        optimizer_train_if_needed = lambda: None
+        optimizer_eval_if_needed = lambda: None
 
     index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
     # logger.info(len(index_no_updates), torch.sum(index_no_updates))

@@ -438,6 +447,7 @@ def train(args):
         loss_total = 0
 
         for step, batch in enumerate(train_dataloader):
+            optimizer_train_if_needed()
             current_step.value = global_step
             with accelerator.accumulate(text_encoder):
                 with torch.no_grad():

@@ -461,7 +471,9 @@ def train(args):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
 
                 # Predict the noise residual
                 with accelerator.autocast():

@@ -473,7 +485,9 @@ def train(args):
                 else:
                     target = noise
 
-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                loss = train_util.conditional_loss(
+                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                )
                 if args.masked_loss:
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])

@@ -505,6 +519,8 @@ def train(args):
                         index_no_updates
                     ]
 
+            optimizer_eval_if_needed()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)