diff --git a/flux_train.py b/flux_train.py index 2c1ec8f3..6fad0c31 100644 --- a/flux_train.py +++ b/flux_train.py @@ -19,6 +19,7 @@ from multiprocessing import Value import time from typing import List, Optional, Tuple, Union import toml +import random from tqdm import tqdm @@ -44,6 +45,8 @@ logger = logging.getLogger(__name__) import library.config_util as config_util +from contextlib import nullcontext + # import library.sdxl_train_util as sdxl_train_util from library.config_util import ( ConfigSanitizer, @@ -51,7 +54,6 @@ from library.config_util import ( ) from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments - def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) @@ -577,6 +579,123 @@ def train(args): # log empty object to commit the sample images to wandb accelerator.log({}, step=0) + + ### PLACEHOLDERS ### + test_step_freq = 10 + val_step_freq = 25 + test_set_count = 5 + val_set_count = 5 + test_val_repeat_count = 2 + + logger.warning('CREATING TEST AND VALIDATION SETS') + test_set, val_set = train_util.create_test_val_set(train_dataloader, test_set_count, val_set_count) + + # TODO: Get arguments for step_freq values + # TODO: Get arguments for test_set_count, test_noise_iter + + def calculate_loss(step=step, batch=batch, state=None, accumulate_loss: bool=True, accelerator=accelerator): + + if state is not None: + noise, noisy_model_input, timesteps, sigmas = state + + with accelerator.accumulate(*training_models) if accumulate_loss else nullcontext(): # Only utilize the accumulate context if loss is marked to be accumulated, otherwise, just use a null context. This avoids the test and validation samples impacting the training. + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + bsz = latents.shape[0] + + # get noisy model input and timesteps + if state is None: # Only calculate if not using stored values for validation + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # call model + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + if args.bypass_flux_guidance: + flux_utils.bypass_flux_guidance(flux) + + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + if args.bypass_flux_guidance: + flux_utils.restore_flux_guidance(flux) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + state = (noise, noisy_model_input, timesteps, sigmas) + + return loss, state + loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0 for epoch in range(num_train_epochs): @@ -587,122 +706,53 @@ def train(args): m.train() for step, batch in enumerate(train_dataloader): + if step in val_set['steps']: # Skip validation steps, don't increment global step + logger.warning('SKIPPING BATCH IN VALIDATION SET') + continue + current_step.value = global_step if args.blockwise_fused_optimizers: optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step - with accelerator.accumulate(*training_models): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) - else: - with torch.no_grad(): - # encode images to latents. images are [-1, 1] - latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) + # CALCULATE LOSS ON TEST SET AT TEST SET FREQUENCY + if global_step==0: + test_fixed_states = [] + test_losses = [] + if global_step % test_step_freq == 0 and test_step_freq > 0: + test_loss, test_fixed_states = train_util.calc_test_val_loss(dataset=test_set, loss_func=calculate_loss, repeat_count=test_val_repeat_count, fixed_states=test_fixed_states, test=True) + test_losses.append(test_loss) - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) + # CALCULATE LOSS ON VALIDATION SET AT TEST SET FREQUENCY + if global_step==0: + val_fixed_states = [] + val_losses = [] + if global_step % val_step_freq == 0 and val_step_freq > 0: + val_loss, val_fixed_states = train_util.calc_test_val_loss(dataset=val_set, loss_func=calculate_loss, repeat_count=test_val_repeat_count, fixed_states=val_fixed_states, test=False) + val_losses.append(val_loss) - text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) - if text_encoder_outputs_list is not None: - text_encoder_conds = text_encoder_outputs_list - else: - # not cached or training, so get from text encoders - tokens_and_masks = batch["input_ids_list"] - with torch.no_grad(): - input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] - text_encoder_conds = text_encoding_strategy.encode_tokens( - flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask - ) - if args.full_fp16: - text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + # STANDARD LOSS CALCULATION + loss, _ = calculate_loss(step, batch, accumulate_loss=True) # Loss should be accumulated when not running the test/validation samples though - # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + # backward + accelerator.backward(loss) - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - # get noisy model input and timesteps - noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype - ) - - # pack latents and get img_ids - packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 - packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) - - # get guidance: ensure args.guidance_scale is float - guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) - - # call model - l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds - if not args.apply_t5_attn_mask: - t5_attn_mask = None - - if args.bypass_flux_guidance: - flux_utils.bypass_flux_guidance(flux) - - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = flux( - img=packed_noisy_model_input, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - - # unpack latents - model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - - if args.bypass_flux_guidance: - flux_utils.restore_flux_guidance(flux) - - # apply model prediction type - model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) - - # flow matching loss: this is different from SD3 - target = noise - latents - - # calculate loss - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) - if weighting is not None: - loss = loss * weighting - if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - loss = loss.mean() - - # backward - accelerator.backward(loss) - - if not (args.fused_backward_pass or args.blockwise_fused_optimizers): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = [] - for m in training_models: - params_to_clip.extend(m.parameters()) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - else: - # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook - lr_scheduler.step() - if args.blockwise_fused_optimizers: - for i in range(1, len(optimizers)): - lr_schedulers[i].step() + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: diff --git a/library/train_util.py b/library/train_util.py index 72b5b24d..465952d0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6398,3 +6398,30 @@ class LossRecorder: @property def moving_average(self) -> float: return self.loss_total / len(self.loss_list) + +def calc_test_val_loss(dataset, loss_func, repeat_count, fixed_states=[], test=True): + test_val_ind = 'TEST' if test else 'VALIDATION' + # logger.warning(f'CALCULATING {test_val_ind} LOSS') + losses = [] + for step, batch in enumerate(dataset['batches'] * repeat_count): + if len(fixed_states) < len(dataset['batches']) * repeat_count: # If accumulating fixed states, calculate state as normal and return + loss, state = loss_func(step, batch, None, accumulate_loss=False) + fixed_states.append(state) + else: # Otherwise, recall the stored values and use those instead so the test loss is consistently calculated for each sample + state = fixed_states[step] + loss, _ = loss_func(step, batch, state, accumulate_loss=False) + losses.append(loss.detach().item()) + avg_loss = sum(losses) / len(losses) + logger.info(f'AVERAGE {test_val_ind} LOSS: {avg_loss:.6f}') + return avg_loss, fixed_states + +def create_test_val_set(dataloader, test_set_count, val_set_count): + test_set = test_set = {'steps':list(range(test_set_count)), 'batches':[]} + val_set = {'steps':list(range(test_set_count,test_set_count+val_set_count)), 'batches':[]} + for step, batch in enumerate(dataloader): + if step in test_set['steps']: + test_set['batches'].append(batch) + if step in val_set['steps']: + val_set['batches'].append(batch) + if step >= test_set_count + val_set_count: + return test_set, val_set \ No newline at end of file