mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Compare commits
3 Commits
scheduler-
...
resume-ste
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
faadc350a4 | ||
|
|
6d9338f8b5 | ||
|
|
5f0eebaa56 |
42
fine_tune.py
42
fine_tune.py
@@ -250,32 +250,23 @@ def train(args):
|
||||
unet.to(weight_dtype)
|
||||
text_encoder.to(weight_dtype)
|
||||
|
||||
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
|
||||
if args.deepspeed:
|
||||
if args.train_text_encoder:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
|
||||
else:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
|
||||
ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
|
||||
if not use_schedule_free_optimizer:
|
||||
lr_scheduler = accelerator.prepare(lr_scheduler)
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
training_models = [ds_model]
|
||||
else:
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if args.train_text_encoder:
|
||||
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader)
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
else:
|
||||
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
|
||||
if not use_schedule_free_optimizer:
|
||||
lr_scheduler = accelerator.prepare(lr_scheduler)
|
||||
|
||||
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
|
||||
if use_schedule_free_optimizer:
|
||||
optimizer_train_if_needed = lambda: optimizer.train()
|
||||
optimizer_eval_if_needed = lambda: optimizer.eval()
|
||||
else:
|
||||
optimizer_train_if_needed = lambda: None
|
||||
optimizer_eval_if_needed = lambda: None
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
@@ -333,7 +324,6 @@ def train(args):
|
||||
m.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
optimizer_train_if_needed()
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(*training_models):
|
||||
with torch.no_grad():
|
||||
@@ -364,9 +354,7 @@ def train(args):
|
||||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
||||
args, noise_scheduler, latents
|
||||
)
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
@@ -380,9 +368,7 @@ def train(args):
|
||||
|
||||
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
|
||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if args.min_snr_gamma:
|
||||
@@ -394,9 +380,7 @@ def train(args):
|
||||
|
||||
loss = loss.mean() # mean over batch dimension
|
||||
else:
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
@@ -406,11 +390,9 @@ def train(args):
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step() # if schedule-free optimizer is used, this is a no-op
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
optimizer_eval_if_needed()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
@@ -489,7 +471,7 @@ def train(args):
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
||||
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
@@ -649,8 +649,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
def set_current_epoch(self, epoch):
|
||||
if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする
|
||||
self.shuffle_buckets()
|
||||
self.current_epoch = epoch
|
||||
if epoch > self.current_epoch:
|
||||
logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
|
||||
num_epochs = epoch - self.current_epoch
|
||||
for _ in range(num_epochs):
|
||||
self.current_epoch += 1
|
||||
self.shuffle_buckets()
|
||||
else:
|
||||
logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
|
||||
self.current_epoch = epoch
|
||||
|
||||
def set_current_step(self, step):
|
||||
self.current_step = step
|
||||
@@ -941,7 +948,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self._length = len(self.buckets_indices)
|
||||
|
||||
def shuffle_buckets(self):
|
||||
# set random seed for this epoch
|
||||
# set random seed for this epoch: current_epoch is not incremented
|
||||
random.seed(self.seed + self.current_epoch)
|
||||
|
||||
random.shuffle(self.buckets_indices)
|
||||
@@ -2346,10 +2353,10 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
|
||||
|
||||
|
||||
def load_image(image_path):
|
||||
image = Image.open(image_path)
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
img = np.array(image, np.uint8)
|
||||
with Image.open(image_path) as image:
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
img = np.array(image, np.uint8)
|
||||
return img
|
||||
|
||||
|
||||
@@ -3087,7 +3094,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
|
||||
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
@@ -4088,21 +4095,6 @@ def get_optimizer(args, trainable_params):
|
||||
optimizer_class = torch.optim.AdamW
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type.endswith("schedulefree".lower()):
|
||||
try:
|
||||
import schedulefree as sf
|
||||
except ImportError:
|
||||
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
|
||||
if optimizer_type == "AdamWScheduleFree".lower():
|
||||
optimizer_class = sf.AdamWScheduleFree
|
||||
logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "SGDScheduleFree".lower():
|
||||
optimizer_class = sf.SGDScheduleFree
|
||||
logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
if optimizer is None:
|
||||
# 任意のoptimizerを使う
|
||||
optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
|
||||
@@ -4131,14 +4123,6 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
"""
|
||||
Unified API to get any scheduler from its name.
|
||||
"""
|
||||
# supports schedule free optimizer
|
||||
if args.optimizer_type.lower().endswith("schedulefree"):
|
||||
# return dummy scheduler: it has 'step' method but does nothing
|
||||
logger.info("use dummy scheduler for schedule free optimizer / schedule free optimizer用のダミースケジューラを使用します")
|
||||
lr_scheduler = TYPE_TO_SCHEDULER_FUNCTION[SchedulerType.CONSTANT](optimizer)
|
||||
lr_scheduler.step = lambda: None
|
||||
return lr_scheduler
|
||||
|
||||
name = args.lr_scheduler
|
||||
num_warmup_steps: Optional[int] = args.lr_warmup_steps
|
||||
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
|
||||
@@ -4273,7 +4257,7 @@ def load_tokenizer(args: argparse.Namespace):
|
||||
return tokenizer
|
||||
|
||||
|
||||
def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
|
||||
def prepare_accelerator(args: argparse.Namespace):
|
||||
"""
|
||||
this function also prepares deepspeed plugin
|
||||
"""
|
||||
@@ -5410,7 +5394,7 @@ class LossRecorder:
|
||||
self.loss_total: float = 0.0
|
||||
|
||||
def add(self, *, epoch: int, step: int, loss: float) -> None:
|
||||
if epoch == 0:
|
||||
if epoch == 0 or step >= len(self.loss_list):
|
||||
self.loss_list.append(loss)
|
||||
else:
|
||||
self.loss_total -= self.loss_list[step]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
accelerate==0.30.0
|
||||
accelerate==0.25.0
|
||||
transformers==4.36.2
|
||||
diffusers[torch]==0.25.0
|
||||
ftfy==6.1.1
|
||||
@@ -9,7 +9,6 @@ pytorch-lightning==1.9.0
|
||||
bitsandbytes==0.43.0
|
||||
prodigyopt==1.0
|
||||
lion-pytorch==0.0.6
|
||||
schedulefree==1.2.5
|
||||
tensorboard
|
||||
safetensors==0.4.2
|
||||
# gradio==3.16.2
|
||||
|
||||
@@ -407,7 +407,6 @@ def train(args):
|
||||
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
||||
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
||||
|
||||
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
|
||||
if args.deepspeed:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
||||
args,
|
||||
@@ -416,9 +415,9 @@ def train(args):
|
||||
text_encoder2=text_encoder2 if train_text_encoder2 else None,
|
||||
)
|
||||
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
|
||||
ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
|
||||
if not use_schedule_free_optimizer:
|
||||
lr_scheduler = accelerator.prepare(lr_scheduler)
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
training_models = [ds_model]
|
||||
|
||||
else:
|
||||
@@ -429,17 +428,7 @@ def train(args):
|
||||
text_encoder1 = accelerator.prepare(text_encoder1)
|
||||
if train_text_encoder2:
|
||||
text_encoder2 = accelerator.prepare(text_encoder2)
|
||||
if not use_schedule_free_optimizer:
|
||||
lr_scheduler = accelerator.prepare(lr_scheduler)
|
||||
optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
|
||||
|
||||
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
|
||||
if use_schedule_free_optimizer:
|
||||
optimizer_train_if_needed = lambda: optimizer.train()
|
||||
optimizer_eval_if_needed = lambda: optimizer.eval()
|
||||
else:
|
||||
optimizer_train_if_needed = lambda: None
|
||||
optimizer_eval_if_needed = lambda: None
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||
if args.cache_text_encoder_outputs:
|
||||
@@ -514,7 +503,6 @@ def train(args):
|
||||
m.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
optimizer_train_if_needed()
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(*training_models):
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
@@ -594,9 +582,7 @@ def train(args):
|
||||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
||||
args, noise_scheduler, latents
|
||||
)
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
|
||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
@@ -614,9 +600,7 @@ def train(args):
|
||||
or args.masked_loss
|
||||
):
|
||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
if args.masked_loss:
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
@@ -632,9 +616,7 @@ def train(args):
|
||||
|
||||
loss = loss.mean() # mean over batch dimension
|
||||
else:
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
@@ -644,11 +626,9 @@ def train(args):
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step() # if schedule-free optimizer is used, this is a no-op
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
optimizer_eval_if_needed()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
@@ -756,7 +736,7 @@ def train(args):
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state or args.save_state_on_train_end:
|
||||
if args.save_state or args.save_state_on_train_end:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
@@ -15,7 +15,6 @@ from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@@ -287,18 +286,7 @@ def train(args):
|
||||
unet.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
|
||||
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
|
||||
if not use_schedule_free_optimizer:
|
||||
lr_scheduler = accelerator.prepare(lr_scheduler)
|
||||
|
||||
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
|
||||
if use_schedule_free_optimizer:
|
||||
optimizer_train_if_needed = lambda: optimizer.train()
|
||||
optimizer_eval_if_needed = lambda: optimizer.eval()
|
||||
else:
|
||||
optimizer_train_if_needed = lambda: None
|
||||
optimizer_eval_if_needed = lambda: None
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
||||
@@ -402,7 +390,6 @@ def train(args):
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
optimizer_train_if_needed()
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(unet):
|
||||
with torch.no_grad():
|
||||
@@ -452,9 +439,7 @@ def train(args):
|
||||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
||||
args, noise_scheduler, latents
|
||||
)
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
|
||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
@@ -473,9 +458,7 @@ def train(args):
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
@@ -501,8 +484,6 @@ def train(args):
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
optimizer_eval_if_needed()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
@@ -12,7 +12,6 @@ from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@@ -255,19 +254,9 @@ def train(args):
|
||||
network.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
|
||||
unet, network, optimizer, train_dataloader = accelerator.prepare(unet, network, optimizer, train_dataloader)
|
||||
if not use_schedule_free_optimizer:
|
||||
lr_scheduler = accelerator.prepare(lr_scheduler)
|
||||
|
||||
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
|
||||
if use_schedule_free_optimizer:
|
||||
optimizer_train_if_needed = lambda: optimizer.train()
|
||||
optimizer_eval_if_needed = lambda: optimizer.eval()
|
||||
else:
|
||||
optimizer_train_if_needed = lambda: None
|
||||
optimizer_eval_if_needed = lambda: None
|
||||
|
||||
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
network: control_net_lllite.ControlNetLLLite
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
@@ -368,7 +357,6 @@ def train(args):
|
||||
network.on_epoch_start() # train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
optimizer_train_if_needed()
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(network):
|
||||
with torch.no_grad():
|
||||
@@ -418,9 +406,7 @@ def train(args):
|
||||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
||||
args, noise_scheduler, latents
|
||||
)
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
|
||||
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
||||
|
||||
@@ -440,9 +426,7 @@ def train(args):
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
@@ -468,8 +452,6 @@ def train(args):
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
optimizer_eval_if_needed()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
@@ -13,7 +13,6 @@ from tqdm import tqdm
|
||||
import torch
|
||||
from library import deepspeed_utils
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@@ -227,7 +226,7 @@ def train(args):
|
||||
)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
@@ -277,18 +276,9 @@ def train(args):
|
||||
controlnet.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
|
||||
controlnet, optimizer, train_dataloader = accelerator.prepare(controlnet, optimizer, train_dataloader)
|
||||
if not use_schedule_free_optimizer:
|
||||
lr_scheduler = accelerator.prepare(lr_scheduler)
|
||||
|
||||
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
|
||||
if use_schedule_free_optimizer:
|
||||
optimizer_train_if_needed = lambda: optimizer.train()
|
||||
optimizer_eval_if_needed = lambda: optimizer.eval()
|
||||
else:
|
||||
optimizer_train_if_needed = lambda: None
|
||||
optimizer_eval_if_needed = lambda: None
|
||||
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
controlnet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
unet.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
@@ -403,7 +393,6 @@ def train(args):
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
optimizer_train_if_needed()
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(controlnet):
|
||||
with torch.no_grad():
|
||||
@@ -431,9 +420,7 @@ def train(args):
|
||||
)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps, huber_c = train_util.get_timesteps_and_huber_c(
|
||||
args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device
|
||||
)
|
||||
timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
@@ -465,9 +452,7 @@ def train(args):
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
@@ -487,8 +472,6 @@ def train(args):
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
optimizer_eval_if_needed()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
34
train_db.py
34
train_db.py
@@ -224,34 +224,25 @@ def train(args):
|
||||
text_encoder.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
|
||||
if args.deepspeed:
|
||||
if args.train_text_encoder:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
|
||||
else:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
|
||||
ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
|
||||
if not use_schedule_free_optimizer:
|
||||
lr_scheduler = accelerator.prepare(lr_scheduler)
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
training_models = [ds_model]
|
||||
|
||||
else:
|
||||
if train_text_encoder:
|
||||
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader)
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
training_models = [unet, text_encoder]
|
||||
else:
|
||||
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
training_models = [unet]
|
||||
if not use_schedule_free_optimizer:
|
||||
lr_scheduler = accelerator.prepare(lr_scheduler)
|
||||
|
||||
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
|
||||
if use_schedule_free_optimizer:
|
||||
optimizer_train_if_needed = lambda: optimizer.train()
|
||||
optimizer_eval_if_needed = lambda: optimizer.eval()
|
||||
else:
|
||||
optimizer_train_if_needed = lambda: None
|
||||
optimizer_eval_if_needed = lambda: None
|
||||
|
||||
if not train_text_encoder:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
|
||||
@@ -316,7 +307,6 @@ def train(args):
|
||||
text_encoder.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
optimizer_train_if_needed()
|
||||
current_step.value = global_step
|
||||
# 指定したステップ数でText Encoderの学習を止める
|
||||
if global_step == args.stop_text_encoder_training:
|
||||
@@ -356,9 +346,7 @@ def train(args):
|
||||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
||||
args, noise_scheduler, latents
|
||||
)
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
@@ -370,9 +358,7 @@ def train(args):
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
if args.masked_loss:
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
@@ -401,8 +387,6 @@ def train(args):
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
optimizer_eval_if_needed()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
138
train_network.py
138
train_network.py
@@ -412,7 +412,6 @@ class NetworkTrainer:
|
||||
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
||||
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
|
||||
if args.deepspeed:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
||||
args,
|
||||
@@ -421,9 +420,9 @@ class NetworkTrainer:
|
||||
text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
|
||||
network=network,
|
||||
)
|
||||
ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
|
||||
if not use_schedule_free_optimizer:
|
||||
lr_scheduler = accelerator.prepare(lr_scheduler)
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
training_model = ds_model
|
||||
else:
|
||||
if train_unet:
|
||||
@@ -439,23 +438,14 @@ class NetworkTrainer:
|
||||
else:
|
||||
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
||||
|
||||
network, optimizer, train_dataloader = accelerator.prepare(network, optimizer, train_dataloader)
|
||||
if not use_schedule_free_optimizer:
|
||||
lr_scheduler = accelerator.prepare(lr_scheduler)
|
||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
training_model = network
|
||||
|
||||
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
|
||||
if use_schedule_free_optimizer:
|
||||
optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
|
||||
optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
|
||||
else:
|
||||
optimizer_train_if_needed = lambda: None
|
||||
optimizer_eval_if_needed = lambda: None
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
# according to TI example in Diffusers, train is required
|
||||
unet.train()
|
||||
|
||||
for t_enc in text_encoders:
|
||||
t_enc.train()
|
||||
|
||||
@@ -484,17 +474,24 @@ class NetworkTrainer:
|
||||
# before resuming make hook for saving/loading to save/load the network weights only
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
# pop weights of other models than network to save only network weights
|
||||
# only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
|
||||
if accelerator.is_main_process or args.deepspeed:
|
||||
if accelerator.is_main_process:
|
||||
remove_indices = []
|
||||
for i, model in enumerate(models):
|
||||
if not isinstance(model, type(accelerator.unwrap_model(network))):
|
||||
remove_indices.append(i)
|
||||
for i in reversed(remove_indices):
|
||||
if len(weights) > i:
|
||||
weights.pop(i)
|
||||
weights.pop(i)
|
||||
# print(f"save model hook: {len(weights)} weights will be saved")
|
||||
|
||||
# save current epoch and step
|
||||
train_state_file = os.path.join(output_dir, "train_state.json")
|
||||
# +1 is needed because the state is saved before current_step is set from global_step
|
||||
logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
|
||||
with open(train_state_file, "w", encoding="utf-8") as f:
|
||||
json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
|
||||
|
||||
steps_from_state = None
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
# remove models except network
|
||||
remove_indices = []
|
||||
@@ -505,6 +502,15 @@ class NetworkTrainer:
|
||||
models.pop(i)
|
||||
# print(f"load model hook: {len(models)} models will be loaded")
|
||||
|
||||
# load current epoch and step from the saved train state file, if it exists
|
||||
nonlocal steps_from_state
|
||||
train_state_file = os.path.join(input_dir, "train_state.json")
|
||||
if os.path.exists(train_state_file):
|
||||
with open(train_state_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
steps_from_state = data["current_step"] + 1 # because
|
||||
logger.info(f"load train state from {train_state_file}: {data}")
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
@@ -748,7 +754,52 @@ class NetworkTrainer:
|
||||
if key in metadata:
|
||||
minimum_metadata[key] = metadata[key]
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
# calculate steps to skip when resuming or starting from a specific step
|
||||
initial_step = 0
|
||||
if args.initial_epoch is not None or args.initial_step is not None:
|
||||
# if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
|
||||
if steps_from_state is not None:
|
||||
logger.warning(
|
||||
"steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
|
||||
)
|
||||
if args.initial_step is not None:
|
||||
initial_step = args.initial_step
|
||||
else:
|
||||
# num steps per epoch is calculated by num_processes and gradient_accumulation_steps
|
||||
initial_step = (args.initial_epoch - 1) * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
else:
|
||||
# if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
|
||||
if steps_from_state is not None:
|
||||
initial_step = steps_from_state
|
||||
steps_from_state = None
|
||||
|
||||
if initial_step > 0:
|
||||
assert (
|
||||
args.max_train_steps > initial_step
|
||||
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
|
||||
)
|
||||
|
||||
epoch_to_start = 0
|
||||
if initial_step > 0:
|
||||
if args.skip_until_initial_step:
|
||||
# if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
|
||||
if not args.resume:
|
||||
logger.info(
|
||||
f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
|
||||
)
|
||||
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
|
||||
else:
|
||||
# otherwise, only the starting epoch number is advanced, for informational purposes (no data is skipped)
|
||||
epoch_to_start = initial_step // math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
initial_step = 0 # do not skip
|
||||
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = DDPMScheduler(
|
||||
@@ -805,17 +856,35 @@ class NetworkTrainer:
|
||||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
if initial_step > 0:
|
||||
# set the starting global step calculated from initial_step, because skipping steps doesn't increment global_step
|
||||
global_step = initial_step // (accelerator.num_processes * args.gradient_accumulation_steps)
|
||||
|
||||
for epoch in range(epoch_to_start, num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
||||
if initial_step > steps_per_epoch:
|
||||
logger.info(f"skipping epoch {epoch+1} because initial_step (multiplied) is {initial_step}")
|
||||
initial_step -= steps_per_epoch
|
||||
continue
|
||||
|
||||
metadata["ss_epoch"] = str(epoch + 1)
|
||||
|
||||
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
optimizer_train_if_needed()
|
||||
active_dataloader = train_dataloader
|
||||
if initial_step > 0:
|
||||
logger.info(f"skipping {initial_step} batches in epoch {epoch+1}")
|
||||
active_dataloader = accelerator.skip_first_batches(
|
||||
train_dataloader, initial_step * args.gradient_accumulation_steps
|
||||
)
|
||||
initial_step = 0
|
||||
|
||||
for step, batch in enumerate(active_dataloader):
|
||||
current_step.value = global_step
|
||||
|
||||
with accelerator.accumulate(training_model):
|
||||
on_step_start(text_encoder, unet)
|
||||
|
||||
@@ -931,8 +1000,6 @@ class NetworkTrainer:
|
||||
else:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
|
||||
optimizer_eval_if_needed()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
@@ -1116,6 +1183,25 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_until_initial_step",
|
||||
action="store_true",
|
||||
help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial_epoch",
|
||||
type=int,
|
||||
default=None,
|
||||
help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`."
|
||||
+ " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial_step",
|
||||
type=int,
|
||||
default=None,
|
||||
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
|
||||
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
@@ -415,28 +415,20 @@ class TextualInversionTrainer:
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
|
||||
if len(text_encoders) == 1:
|
||||
text_encoder_or_list, optimizer, train_dataloader = accelerator.preparet(
|
||||
text_encoder_or_list, optimizer, train_dataloader
|
||||
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
elif len(text_encoders) == 2:
|
||||
text_encoder1, text_encoder2, optimizer, train_dataloader = accelerator.prepare(
|
||||
text_encoders[0], text_encoders[1], optimizer, train_dataloader
|
||||
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
|
||||
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
if not use_schedule_free_optimizer:
|
||||
optimizer, lr_scheduler = accelerator.prepare(optimizer, lr_scheduler)
|
||||
|
||||
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
|
||||
if use_schedule_free_optimizer:
|
||||
optimizer_train_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).train()
|
||||
optimizer_eval_if_needed = lambda: (optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer).eval()
|
||||
else:
|
||||
optimizer_train_if_needed = lambda: None
|
||||
optimizer_eval_if_needed = lambda: None
|
||||
|
||||
index_no_updates_list = []
|
||||
orig_embeds_params_list = []
|
||||
@@ -565,7 +557,6 @@ class TextualInversionTrainer:
|
||||
loss_total = 0
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
optimizer_train_if_needed()
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(text_encoders[0]):
|
||||
with torch.no_grad():
|
||||
@@ -597,9 +588,7 @@ class TextualInversionTrainer:
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
if args.masked_loss:
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
@@ -638,8 +627,6 @@ class TextualInversionTrainer:
|
||||
index_no_updates
|
||||
]
|
||||
|
||||
optimizer_eval_if_needed()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
@@ -335,18 +335,9 @@ def train(args):
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
|
||||
text_encoder, optimizer, train_dataloader = accelerator.prepare(text_encoder, optimizer, train_dataloader)
|
||||
if not use_schedule_free_optimizer:
|
||||
lr_scheduler = accelerator.prepare(lr_scheduler)
|
||||
|
||||
# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
|
||||
if use_schedule_free_optimizer:
|
||||
optimizer_train_if_needed = lambda: optimizer.train()
|
||||
optimizer_eval_if_needed = lambda: optimizer.eval()
|
||||
else:
|
||||
optimizer_train_if_needed = lambda: None
|
||||
optimizer_eval_if_needed = lambda: None
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
|
||||
# logger.info(len(index_no_updates), torch.sum(index_no_updates))
|
||||
@@ -447,7 +438,6 @@ def train(args):
|
||||
loss_total = 0
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
optimizer_train_if_needed()
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(text_encoder):
|
||||
with torch.no_grad():
|
||||
@@ -471,9 +461,7 @@ def train(args):
|
||||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
||||
args, noise_scheduler, latents
|
||||
)
|
||||
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
@@ -485,9 +473,7 @@ def train(args):
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = train_util.conditional_loss(
|
||||
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
||||
if args.masked_loss:
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
@@ -519,8 +505,6 @@ def train(args):
|
||||
index_no_updates
|
||||
]
|
||||
|
||||
optimizer_eval_if_needed()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
|
||||
Reference in New Issue
Block a user