From 99338a204fd17925428fc35bdec7bdac4284b232 Mon Sep 17 00:00:00 2001 From: stepfunction83 <32859451+stepfunction83@users.noreply.github.com> Date: Sat, 25 Jan 2025 10:13:27 -0500 Subject: [PATCH] Revert "Create basic Flux calc for test and validation loss" This reverts commit 0b50630e6117af7c6ee8453b103136feb31bf7eb. --- flux_train.py | 264 +++++++++++++++++------------------------- library/train_util.py | 27 ----- 2 files changed, 107 insertions(+), 184 deletions(-) diff --git a/flux_train.py b/flux_train.py index 6fad0c31..2c1ec8f3 100644 --- a/flux_train.py +++ b/flux_train.py @@ -19,7 +19,6 @@ from multiprocessing import Value import time from typing import List, Optional, Tuple, Union import toml -import random from tqdm import tqdm @@ -45,8 +44,6 @@ logger = logging.getLogger(__name__) import library.config_util as config_util -from contextlib import nullcontext - # import library.sdxl_train_util as sdxl_train_util from library.config_util import ( ConfigSanitizer, @@ -54,6 +51,7 @@ from library.config_util import ( ) from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) @@ -579,123 +577,6 @@ def train(args): # log empty object to commit the sample images to wandb accelerator.log({}, step=0) - - ### PLACEHOLDERS ### - test_step_freq = 10 - val_step_freq = 25 - test_set_count = 5 - val_set_count = 5 - test_val_repeat_count = 2 - - logger.warning('CREATING TEST AND VALIDATION SETS') - test_set, val_set = train_util.create_test_val_set(train_dataloader, test_set_count, val_set_count) - - # TODO: Get arguments for step_freq values - # TODO: Get arguments for test_set_count, test_noise_iter - - def calculate_loss(step=step, batch=batch, state=None, accumulate_loss: bool=True, accelerator=accelerator): - - if state is not None: - noise, noisy_model_input, timesteps, sigmas = state - - with accelerator.accumulate(*training_models) if accumulate_loss else nullcontext(): # Only utilize the accumulate context if loss is marked to be accumulated, otherwise, just use a null context. This avoids the test and validation samples impacting the training. - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) - else: - with torch.no_grad(): - # encode images to latents. images are [-1, 1] - latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) - - text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) - if text_encoder_outputs_list is not None: - text_encoder_conds = text_encoder_outputs_list - else: - # not cached or training, so get from text encoders - tokens_and_masks = batch["input_ids_list"] - with torch.no_grad(): - input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] - text_encoder_conds = text_encoding_strategy.encode_tokens( - flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask - ) - if args.full_fp16: - text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] - - # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps - - bsz = latents.shape[0] - - # get noisy model input and timesteps - if state is None: # Only calculate if not using stored values for validation - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - - noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype - ) - - # pack latents and get img_ids - packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 - packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) - - # get guidance: ensure args.guidance_scale is float - guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) - - # call model - l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds - if not args.apply_t5_attn_mask: - t5_attn_mask = None - - if args.bypass_flux_guidance: - flux_utils.bypass_flux_guidance(flux) - - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = flux( - img=packed_noisy_model_input, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - - # unpack latents - model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - - if args.bypass_flux_guidance: - flux_utils.restore_flux_guidance(flux) - - # apply model prediction type - model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) - - # flow matching loss: this is different from SD3 - target = noise - latents - - # calculate loss - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) - if weighting is not None: - loss = loss * weighting - if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - loss = loss.mean() - - state = (noise, noisy_model_input, timesteps, sigmas) - - return loss, state - loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0 for epoch in range(num_train_epochs): @@ -706,53 +587,122 @@ def train(args): m.train() for step, batch in enumerate(train_dataloader): - if step in val_set['steps']: # Skip validation steps, don't increment global step - logger.warning('SKIPPING BATCH IN VALIDATION SET') - continue - current_step.value = global_step if args.blockwise_fused_optimizers: optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step - # CALCULATE LOSS ON TEST SET AT TEST SET FREQUENCY - if global_step==0: - test_fixed_states = [] - test_losses = [] - if global_step % test_step_freq == 0 and test_step_freq > 0: - test_loss, test_fixed_states = train_util.calc_test_val_loss(dataset=test_set, loss_func=calculate_loss, repeat_count=test_val_repeat_count, fixed_states=test_fixed_states, test=True) - test_losses.append(test_loss) + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) - # CALCULATE LOSS ON VALIDATION SET AT TEST SET FREQUENCY - if global_step==0: - val_fixed_states = [] - val_losses = [] - if global_step % val_step_freq == 0 and val_step_freq > 0: - val_loss, val_fixed_states = train_util.calc_test_val_loss(dataset=val_set, loss_func=calculate_loss, repeat_count=test_val_repeat_count, fixed_states=val_fixed_states, test=False) - val_losses.append(val_loss) + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) - # STANDARD LOSS CALCULATION - loss, _ = calculate_loss(step, batch, accumulate_loss=True) # Loss should be accumulated when not running the test/validation samples though + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] - # backward - accelerator.backward(loss) + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps - if not (args.fused_backward_pass or args.blockwise_fused_optimizers): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = [] - for m in training_models: - params_to_clip.extend(m.parameters()) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - else: - # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook - lr_scheduler.step() - if args.blockwise_fused_optimizers: - for i in range(1, len(optimizers)): - lr_schedulers[i].step() + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # call model + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + if args.bypass_flux_guidance: + flux_utils.bypass_flux_guidance(flux) + + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + if args.bypass_flux_guidance: + flux_utils.restore_flux_guidance(flux) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: diff --git a/library/train_util.py b/library/train_util.py index 465952d0..72b5b24d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6398,30 +6398,3 @@ class LossRecorder: @property def moving_average(self) -> float: return self.loss_total / len(self.loss_list) - -def calc_test_val_loss(dataset, loss_func, repeat_count, fixed_states=[], test=True): - test_val_ind = 'TEST' if test else 'VALIDATION' - # logger.warning(f'CALCULATING {test_val_ind} LOSS') - losses = [] - for step, batch in enumerate(dataset['batches'] * repeat_count): - if len(fixed_states) < len(dataset['batches']) * repeat_count: # If accumulating fixed states, calculate state as normal and return - loss, state = loss_func(step, batch, None, accumulate_loss=False) - fixed_states.append(state) - else: # Otherwise, recall the stored values and use those instead so the test loss is consistently calculated for each sample - state = fixed_states[step] - loss, _ = loss_func(step, batch, state, accumulate_loss=False) - losses.append(loss.detach().item()) - avg_loss = sum(losses) / len(losses) - logger.info(f'AVERAGE {test_val_ind} LOSS: {avg_loss:.6f}') - return avg_loss, fixed_states - -def create_test_val_set(dataloader, test_set_count, val_set_count): - test_set = test_set = {'steps':list(range(test_set_count)), 'batches':[]} - val_set = {'steps':list(range(test_set_count,test_set_count+val_set_count)), 'batches':[]} - for step, batch in enumerate(dataloader): - if step in test_set['steps']: - test_set['batches'].append(batch) - if step in val_set['steps']: - val_set['batches'].append(batch) - if step >= test_set_count + val_set_count: - return test_set, val_set \ No newline at end of file