From bbf6bbd5ea27231066cec98b8bf2a65f162cb18f Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 10:48:38 -0500 Subject: [PATCH] Use self.get_noise_pred_and_target and drop fixed timesteps --- flux_train_network.py | 7 ++- sd3_train_network.py | 3 +- train_network.py | 116 ++++++++++++------------------------------ 3 files changed, 40 insertions(+), 86 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 75e975ba..b3aebecc 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -339,6 +339,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): network, weight_dtype, train_unet, + is_train=True ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -375,7 +376,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): # if not args.split_mode: # normal forward - with accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = unet( img=img, @@ -420,7 +421,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): intermediate_txt.requires_grad_(True) vec.requires_grad_(True) pe.requires_grad_(True) - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + + with torch.set_grad_enabled(is_train and train_unet): + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) """ return model_pred diff --git a/sd3_train_network.py b/sd3_train_network.py index fb7711bd..c7417802 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -312,6 +312,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): network, weight_dtype, train_unet, + is_train=True ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -339,7 +340,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): t5_attn_mask = None # call model - with accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): # TODO support attention mask model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled) diff --git a/train_network.py b/train_network.py index 377ddf48..61e6369a 100644 --- a/train_network.py +++ b/train_network.py @@ -223,6 +223,7 @@ class NetworkTrainer: network, weight_dtype, train_unet, + is_train=True ): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified @@ -236,7 +237,7 @@ class NetworkTrainer: t.requires_grad_(True) # Predict the noise residual - with accelerator.autocast(): + with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): noise_pred = self.call_unet( args, accelerator, @@ -317,7 +318,7 @@ class NetworkTrainer: # endregion - def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor: + def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True) -> torch.Tensor: with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: @@ -372,91 +373,40 @@ class NetworkTrainer: batch_size = latents.shape[0] - # Sample noise, - noise = train_util.make_noise(args, latents) - def pick_timesteps_list() -> torch.IntTensor: - if timesteps_list is None or timesteps_list == []: - return typing.cast(torch.IntTensor, train_util.make_random_timesteps(args, noise_scheduler, batch_size, latents.device).unsqueeze(1)) - else: - return typing.cast(torch.IntTensor, torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(latents.device)) + # Predict the noise residual + # and add noise to the latents + # with noise offset and/or multires noise if specified - chosen_timesteps_list = pick_timesteps_list() - total_loss = torch.zeros((batch_size, 1)).to(latents.device) + # sample noise, call unet, get target + noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + is_train=is_train + ) - # Use input timesteps_list or use described timesteps above - for fixed_timesteps in chosen_timesteps_list: - fixed_timesteps = typing.cast(torch.IntTensor, fixed_timesteps) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) - # Predict the noise residual - # and add noise to the latents - # with noise offset and/or multires noise if specified - noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timesteps) + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights - # ensure the hidden state will require grad - if args.gradient_checkpointing: - for x in noisy_latents: - x.requires_grad_(True) - for t in text_encoder_conds: - t.requires_grad_(True) + loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast(): - noise_pred = self.call_unet( - args, - accelerator, - unet, - noisy_latents.requires_grad_(train_unet), - fixed_timesteps, - text_encoder_conds, - batch, - weight_dtype, - ) - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, fixed_timesteps) - else: - target = noise - - # differential output preservation - if "custom_attributes" in batch: - diff_output_pr_indices = [] - for i, custom_attributes in enumerate(batch["custom_attributes"]): - if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: - diff_output_pr_indices.append(i) - - if len(diff_output_pr_indices) > 0: - network.set_multiplier(0.0) - with torch.no_grad(), accelerator.autocast(): - noise_pred_prior = self.call_unet( - args, - accelerator, - unet, - noisy_latents, - fixed_timesteps, - text_encoder_conds, - batch, - weight_dtype, - indices=diff_output_pr_indices, - ) - network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step - target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - - huber_c = train_util.get_huber_threshold_if_needed(args, fixed_timesteps, noise_scheduler) - loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) - loss = loss.mean([1, 2, 3]) # 平均なのでbatch_sizeで割る必要なし - - if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - loss = apply_masked_loss(loss, batch) - - loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight - loss = loss * loss_weights - - loss = self.post_process_loss(loss, args, fixed_timesteps, noise_scheduler) - - total_loss += loss - - return total_loss / len(chosen_timesteps_list) + return loss.mean() def train(self, args): session_id = random.randint(0, 2**32) @@ -1416,7 +1366,7 @@ class NetworkTrainer: if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False) val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item()) val_progress_bar.update(1) @@ -1447,7 +1397,7 @@ class NetworkTrainer: if val_step >= validation_steps: break - loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990]) + loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False) current_loss = loss.detach().item() val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)