add FLUX.1 LoRA training

This commit is contained in:
Kohya S
2024-08-09 22:56:48 +09:00
parent da4d0fe016
commit 36b2e6fc28
10 changed files with 2992 additions and 55 deletions

View File

@@ -100,6 +100,12 @@ class NetworkTrainer:
def load_target_model(self, args, weight_dtype, accelerator):
    """Load the base model and switch on the requested attention optimizations.

    The text encoder, VAE and U-Net come from ``train_util.load_target_model``;
    the fourth value it returns is discarded. The U-Net attention modules are
    then replaced according to ``args.mem_eff_attn``, ``args.xformers`` and
    ``args.sdpa``.

    Returns:
        A 4-tuple ``(model_version_str, text_encoder, vae, unet)``.
    """
    text_encoder, vae, unet, _unused = train_util.load_target_model(args, weight_dtype, accelerator)

    # Plug xformers / memory-efficient attention into the model.
    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)

    # Usable when the installed xformers supports PyTorch 2.0.0 or later.
    if torch.__version__ >= "2.0.0":
        vae.set_use_memory_efficient_attention_xformers(args.xformers)

    version_str = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization)
    return version_str, text_encoder, vae, unet
def get_tokenize_strategy(self, args):
@@ -147,6 +153,81 @@ class NetworkTrainer:
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoder, unet):
    """Delegate to ``train_util.sample_images``, passing only the first tokenizer."""
    first_tokenizer = tokenizers[0]
    train_util.sample_images(accelerator, args, epoch, global_step, device, vae, first_tokenizer, text_encoder, unet)
# region SD/SDXL
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
    """Create the noise scheduler used during training.

    A ``DDPMScheduler`` is built with a fixed scaled-linear beta schedule over
    1000 training timesteps and handed to
    ``prepare_scheduler_for_custom_training`` together with ``device``. When
    ``args.zero_terminal_snr`` is truthy, its betas are additionally adjusted
    by ``fix_noise_scheduler_betas_for_zero_terminal_snr``.

    Returns:
        The configured scheduler instance.
    """
    scheduler_config = {
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "num_train_timesteps": 1000,
        "clip_sample": False,
    }
    scheduler = DDPMScheduler(**scheduler_config)
    prepare_scheduler_for_custom_training(scheduler, device)

    if args.zero_terminal_snr:
        custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(scheduler)

    return scheduler
def encode_images_to_latents(self, args, accelerator, vae, images):
    """Encode ``images`` with ``vae`` and sample from the resulting latent distribution.

    ``args`` and ``accelerator`` are accepted but not used by this implementation.
    """
    encoder_output = vae.encode(images)
    latent_distribution = encoder_output.latent_dist
    return latent_distribution.sample()
def shift_scale_latents(self, args, latents):
    """Return ``latents`` multiplied by ``self.vae_scale_factor``.

    ``args`` is accepted but not used by this implementation.
    """
    scale = self.vae_scale_factor
    return latents * scale
def get_noise_pred_and_target(
    self,
    args,
    accelerator,
    noise_scheduler,
    latents,
    batch,
    text_encoder_conds,
    unet,
    network,
    weight_dtype,
    train_unet,
):
    """Noise the latents, run the U-Net, and pick the matching training target.

    ``network`` is accepted but not used by this implementation.

    Returns:
        A 5-tuple ``(noise_pred, target, timesteps, huber_c, None)``. ``target``
        is the scheduler velocity when ``args.v_parameterization`` is truthy,
        otherwise the sampled noise. The final element is always ``None`` here.
    """
    # Draw noise and a random timestep per image, then add the noise to the
    # latents (with noise offset and/or multires noise if specified).
    noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
        args, noise_scheduler, latents
    )

    # Ensure the hidden states will require grad.
    if args.gradient_checkpointing:
        for latent in noisy_latents:
            latent.requires_grad_(True)
        for cond in text_encoder_conds:
            cond.requires_grad_(True)

    # Predict the noise residual.
    with accelerator.autocast():
        unet_input = noisy_latents.requires_grad_(train_unet)
        noise_pred = self.call_unet(
            args, accelerator, unet, unet_input, timesteps, text_encoder_conds, batch, weight_dtype
        )

    # v-parameterization trains against the velocity instead of the raw noise.
    target = noise_scheduler.get_velocity(latents, noise, timesteps) if args.v_parameterization else noise

    return noise_pred, target, timesteps, huber_c, None
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
    """Apply the optional loss adjustments enabled on ``args``.

    Each step runs only when its option is truthy, in this fixed order:
    min-SNR-gamma weighting, scaling the v-prediction loss like noise
    prediction, adding the v-prediction-like loss, and debiased estimation.

    Returns:
        The adjusted loss, or ``loss`` unchanged when no option is enabled.
    """
    snr_gamma = args.min_snr_gamma
    if snr_gamma:
        loss = apply_snr_weight(loss, timesteps, noise_scheduler, snr_gamma, args.v_parameterization)

    if args.scale_v_pred_loss_like_noise_pred:
        loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)

    v_pred_like = args.v_pred_like_loss
    if v_pred_like:
        loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like)

    if args.debiased_estimation_loss:
        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

    return loss
# endregion
def train(self, args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
@@ -253,11 +334,6 @@ class NetworkTrainer:
# text_encoder is List[CLIPTextModel] or CLIPTextModel
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(args.xformers)
# 差分追加学習のためにモデルを読み込む
sys.path.append(os.path.dirname(__file__))
accelerator.print("import network module:", args.network_module)
@@ -445,16 +521,19 @@ class NetworkTrainer:
unet_weight_dtype = torch.float8_e4m3fn
te_weight_dtype = torch.float8_e4m3fn
unet.to(accelerator.device) # this makes faster `to(dtype)` below
unet.requires_grad_(False)
unet.to(dtype=unet_weight_dtype)
unet.to(dtype=unet_weight_dtype) # this takes long time and large memory
for t_enc in text_encoders:
t_enc.requires_grad_(False)
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
if t_enc.device.type != "cpu":
t_enc.to(dtype=te_weight_dtype)
# nn.Embedding not support FP8
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
if hasattr(t_enc.text_model, "embeddings"):
# nn.Embedding not support FP8
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if args.deepspeed:
@@ -851,12 +930,7 @@ class NetworkTrainer:
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
noise_scheduler = self.get_noise_scheduler(args, accelerator.device)
if accelerator.is_main_process:
init_kwargs = {}
@@ -913,6 +987,13 @@ class NetworkTrainer:
initial_step -= len(train_dataloader)
global_step = initial_step
# log device and dtype for each model
logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
for t_enc in text_encoders:
logger.info(f"text_encoder dtype: {te_weight_dtype}, device: {t_enc.device}")
clean_memory_on_device(accelerator.device)
for epoch in range(epoch_to_start, num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
@@ -940,13 +1021,15 @@ class NetworkTrainer:
else:
with torch.no_grad():
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype))
latents = latents.to(dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * self.vae_scale_factor
latents = self.shift_scale_latents(args, latents)
# get multiplier for each sample
if network_has_multiplier:
@@ -985,41 +1068,25 @@ class NetworkTrainer:
if args.full_fp16:
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
# sample noise, call unet, get target
noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet,
network,
weight_dtype,
train_unet,
)
# ensure the hidden state will require grad
if args.gradient_checkpointing:
for x in noisy_latents:
x.requires_grad_(True)
for t in text_encoder_conds:
t.requires_grad_(True)
# Predict the noise residual
with accelerator.autocast():
noise_pred = self.call_unet(
args,
accelerator,
unet,
noisy_latents.requires_grad_(train_unet),
timesteps,
text_encoder_conds,
batch,
weight_dtype,
)
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -1027,14 +1094,8 @@ class NetworkTrainer:
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
# min snr gamma, scale v pred loss like noise pred, v pred like loss, debiased estimation etc.
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし