From b203e318774e9a8a6d64a2aea9173cf259dc70ea Mon Sep 17 00:00:00 2001 From: stepfunction83 <32859451+stepfunction83@users.noreply.github.com> Date: Wed, 22 Jan 2025 19:13:25 -0500 Subject: [PATCH 1/7] Minimal Example of Flex Training --- flux_minimal_inference.py | 8 ++++++++ flux_train_network.py | 4 ++++ library/flux_utils.py | 27 +++++++++++++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 7ab224f1..b4021bd1 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -20,6 +20,8 @@ from library import device_utils from library.device_utils import init_ipex, get_preferred_device from networks import oft_flux +from library.flux_utils import bypass_flux_guidance, restore_flux_guidance + init_ipex() @@ -151,6 +153,9 @@ def do_sample( logger.info(f"num_steps: {num_steps}") timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) + # bypass guidance module + bypass_flux_guidance(model) + # denoise initial noise if accelerator: with accelerator.autocast(), torch.no_grad(): @@ -364,6 +369,9 @@ def generate_image( x = x.permute(0, 2, 3, 1) img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + # restore guidance module + restore_flux_guidance(model) + # save image output_dir = args.output_dir os.makedirs(output_dir, exist_ok=True) diff --git a/flux_train_network.py b/flux_train_network.py index 75e975ba..3035d716 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -425,6 +425,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return model_pred + flux_utils.bypass_flux_guidance(unet) + model_pred = call_dit( img=packed_noisy_model_input, img_ids=img_ids, @@ -439,6 +441,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + flux_utils.restore_flux_guidance(unet) + # apply model prediction type model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) diff --git a/library/flux_utils.py b/library/flux_utils.py index 8be1d63e..95df71cd 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -24,6 +24,32 @@ MODEL_VERSION_FLUX_V1 = "flux1" MODEL_NAME_DEV = "dev" MODEL_NAME_SCHNELL = "schnell" +def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) + pooled_projections = self.text_embedder(pooled_projection) + conditioning = timesteps_emb + pooled_projections + return conditioning + +# bypass the forward function +def bypass_flux_guidance(transformer): + if hasattr(transformer.time_text_embed, '_bfg_orig_forward'): + return + # dont bypass if it doesnt have the guidance embedding + if not hasattr(transformer.time_text_embed, 'guidance_embedder'): + return + transformer.time_text_embed._bfg_orig_forward = transformer.time_text_embed.forward + transformer.time_text_embed.forward = partial( + guidance_embed_bypass_forward, transformer.time_text_embed + ) + +# restore the forward function +def restore_flux_guidance(transformer): + if not hasattr(transformer.time_text_embed, '_bfg_orig_forward'): + return + transformer.time_text_embed.forward = transformer.time_text_embed._bfg_orig_forward + del transformer.time_text_embed._bfg_orig_forward def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: """ @@ -60,6 +86,7 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) + # is_schnell = True # check number of double and single blocks if not is_diffusers: From 05fd3f763f2338cf8cdf1860bb9ed465751a8016 Mon Sep 17 00:00:00 2001 From: stepfunction83 <32859451+stepfunction83@users.noreply.github.com> Date: Wed, 22 Jan 2025 19:32:39 -0500 Subject: [PATCH 2/7] Add command line argument for bypassing flux guidance --- flux_minimal_inference.py | 6 ++++-- flux_train_network.py | 9 +++++---- library/train_util.py | 8 ++++++++ 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index b4021bd1..c470bcb4 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -154,7 +154,8 @@ def do_sample( timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) # bypass guidance module - bypass_flux_guidance(model) + if args.bypass_flux_guidance: + bypass_flux_guidance(model) # denoise initial noise if accelerator: @@ -370,7 +371,8 @@ def generate_image( img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) # restore guidance module - restore_flux_guidance(model) + if args.bypass_flux_guidance: + restore_flux_guidance(model) # save image output_dir = args.output_dir diff --git a/flux_train_network.py b/flux_train_network.py index 3035d716..ed578168 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -424,8 +424,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): """ return model_pred - - flux_utils.bypass_flux_guidance(unet) + if args.bypass_flux_guidance: + flux_utils.bypass_flux_guidance(unet) model_pred = call_dit( img=packed_noisy_model_input, @@ -440,8 +440,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - - flux_utils.restore_flux_guidance(unet) + + if args.bypass_flux_guidance: + flux_utils.restore_flux_guidance(unet) # apply model prediction type model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) diff --git a/library/train_util.py b/library/train_util.py index 72b5b24d..3180d694 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4103,6 +4103,14 @@ def add_dit_training_arguments(parser: argparse.ArgumentParser): "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", ) + # bypass guidance module for flux + parser.add_argument( + "--bypass_flux_guidance" + , action="store_true" + , help="bypass flux guidance module for Flex.1-Alpha Training / Flex.1-Alpha トレーニング用バイパス フラックス ガイダンス モジュール" + ) + + def get_sanitized_config_or_none(args: argparse.Namespace): # if `--log_config` is enabled, return args for logging. if not, return None. From cafc5d78de681a021a29c3d176c9390ef23858af Mon Sep 17 00:00:00 2001 From: stepfunction83 <32859451+stepfunction83@users.noreply.github.com> Date: Wed, 22 Jan 2025 19:39:33 -0500 Subject: [PATCH 3/7] Move command line argument from train_util to flux_train_util --- library/flux_train_utils.py | 7 +++++++ library/train_util.py | 8 -------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f7f06c5c..3bbd1fad 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -617,3 +617,10 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): default=3.0, help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", ) + + # bypass guidance module for flux + parser.add_argument( + "--bypass_flux_guidance" + , action="store_true" + , help="bypass flux guidance module for Flex.1-Alpha Training / Flex.1-Alpha トレーニング用バイパス フラックス ガイダンス モジュール" + ) diff --git a/library/train_util.py b/library/train_util.py index 3180d694..72b5b24d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4103,14 +4103,6 @@ def add_dit_training_arguments(parser: argparse.ArgumentParser): "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。", ) - # bypass guidance module for flux - parser.add_argument( - "--bypass_flux_guidance" - , action="store_true" - , help="bypass flux guidance module for Flex.1-Alpha Training / Flex.1-Alpha トレーニング用バイパス フラックス ガイダンス モジュール" - ) - - def get_sanitized_config_or_none(args: argparse.Namespace): # if `--log_config` is enabled, return args for logging. if not, return None. From a768d53d77ebc5f531ec61ab93915a51106becb9 Mon Sep 17 00:00:00 2001 From: stepfunction83 <32859451+stepfunction83@users.noreply.github.com> Date: Thu, 23 Jan 2025 13:39:45 -0500 Subject: [PATCH 4/7] Add bypass flux guidance to flux_train.py --- flux_train.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/flux_train.py b/flux_train.py index fced3bef..2c1ec8f3 100644 --- a/flux_train.py +++ b/flux_train.py @@ -642,6 +642,9 @@ def train(args): l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds if not args.apply_t5_attn_mask: t5_attn_mask = None + + if args.bypass_flux_guidance: + flux_utils.bypass_flux_guidance(flux) with accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) @@ -659,6 +662,9 @@ def train(args): # unpack latents model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + if args.bypass_flux_guidance: + flux_utils.restore_flux_guidance(flux) + # apply model prediction type model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) From 1ade5825f7f02fe2add673bbaeab44b1da0403ba Mon Sep 17 00:00:00 2001 From: stepfunction83 <32859451+stepfunction83@users.noreply.github.com> Date: Thu, 23 Jan 2025 16:41:58 -0500 Subject: [PATCH 5/7] Updated guidance bypass mechanism to use built-in Flux.params.guidance_embed bool --- library/flux_utils.py | 26 +++----------------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/library/flux_utils.py b/library/flux_utils.py index 95df71cd..309f0772 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -24,32 +24,13 @@ MODEL_VERSION_FLUX_V1 = "flux1" MODEL_NAME_DEV = "dev" MODEL_NAME_SCHNELL = "schnell" -def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection): - timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder( - timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) - pooled_projections = self.text_embedder(pooled_projection) - conditioning = timesteps_emb + pooled_projections - return conditioning - -# bypass the forward function +# bypass guidance def bypass_flux_guidance(transformer): - if hasattr(transformer.time_text_embed, '_bfg_orig_forward'): - return - # dont bypass if it doesnt have the guidance embedding - if not hasattr(transformer.time_text_embed, 'guidance_embedder'): - return - transformer.time_text_embed._bfg_orig_forward = transformer.time_text_embed.forward - transformer.time_text_embed.forward = partial( - guidance_embed_bypass_forward, transformer.time_text_embed - ) + transformer.params.guidance_embed = False # restore the forward function def restore_flux_guidance(transformer): - if not hasattr(transformer.time_text_embed, '_bfg_orig_forward'): - return - transformer.time_text_embed.forward = transformer.time_text_embed._bfg_orig_forward - del transformer.time_text_embed._bfg_orig_forward + transformer.params.guidance_embed = True def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]: """ @@ -86,7 +67,6 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) - # is_schnell = True # check number of double and single blocks if not is_diffusers: From 0b50630e6117af7c6ee8453b103136feb31bf7eb Mon Sep 17 00:00:00 2001 From: stepfunction83 <32859451+stepfunction83@users.noreply.github.com> Date: Sat, 25 Jan 2025 09:54:25 -0500 Subject: [PATCH 6/7] Create basic Flux calc for test and validation loss --- flux_train.py | 264 +++++++++++++++++++++++++----------------- library/train_util.py | 27 +++++ 2 files changed, 184 insertions(+), 107 deletions(-) diff --git a/flux_train.py b/flux_train.py index 2c1ec8f3..6fad0c31 100644 --- a/flux_train.py +++ b/flux_train.py @@ -19,6 +19,7 @@ from multiprocessing import Value import time from typing import List, Optional, Tuple, Union import toml +import random from tqdm import tqdm @@ -44,6 +45,8 @@ logger = logging.getLogger(__name__) import library.config_util as config_util +from contextlib import nullcontext + # import library.sdxl_train_util as sdxl_train_util from library.config_util import ( ConfigSanitizer, @@ -51,7 +54,6 @@ from library.config_util import ( ) from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments - def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) @@ -577,6 +579,123 @@ def train(args): # log empty object to commit the sample images to wandb accelerator.log({}, step=0) + + ### PLACEHOLDERS ### + test_step_freq = 10 + val_step_freq = 25 + test_set_count = 5 + val_set_count = 5 + test_val_repeat_count = 2 + + logger.warning('CREATING TEST AND VALIDATION SETS') + test_set, val_set = train_util.create_test_val_set(train_dataloader, test_set_count, val_set_count) + + # TODO: Get arguments for step_freq values + # TODO: Get arguments for test_set_count, test_noise_iter + + def calculate_loss(step=step, batch=batch, state=None, accumulate_loss: bool=True, accelerator=accelerator): + + if state is not None: + noise, noisy_model_input, timesteps, sigmas = state + + with accelerator.accumulate(*training_models) if accumulate_loss else nullcontext(): # Only utilize the accumulate context if loss is marked to be accumulated, otherwise, just use a null context. This avoids the test and validation samples impacting the training. + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) + + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + + bsz = latents.shape[0] + + # get noisy model input and timesteps + if state is None: # Only calculate if not using stored values for validation + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # call model + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + if args.bypass_flux_guidance: + flux_utils.bypass_flux_guidance(flux) + + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + if args.bypass_flux_guidance: + flux_utils.restore_flux_guidance(flux) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + state = (noise, noisy_model_input, timesteps, sigmas) + + return loss, state + loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0 for epoch in range(num_train_epochs): @@ -587,122 +706,53 @@ def train(args): m.train() for step, batch in enumerate(train_dataloader): + if step in val_set['steps']: # Skip validation steps, don't increment global step + logger.warning('SKIPPING BATCH IN VALIDATION SET') + continue + current_step.value = global_step if args.blockwise_fused_optimizers: optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step - with accelerator.accumulate(*training_models): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) - else: - with torch.no_grad(): - # encode images to latents. images are [-1, 1] - latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) + # CALCULATE LOSS ON TEST SET AT TEST SET FREQUENCY + if global_step==0: + test_fixed_states = [] + test_losses = [] + if global_step % test_step_freq == 0 and test_step_freq > 0: + test_loss, test_fixed_states = train_util.calc_test_val_loss(dataset=test_set, loss_func=calculate_loss, repeat_count=test_val_repeat_count, fixed_states=test_fixed_states, test=True) + test_losses.append(test_loss) - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) + # CALCULATE LOSS ON VALIDATION SET AT TEST SET FREQUENCY + if global_step==0: + val_fixed_states = [] + val_losses = [] + if global_step % val_step_freq == 0 and val_step_freq > 0: + val_loss, val_fixed_states = train_util.calc_test_val_loss(dataset=val_set, loss_func=calculate_loss, repeat_count=test_val_repeat_count, fixed_states=val_fixed_states, test=False) + val_losses.append(val_loss) - text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) - if text_encoder_outputs_list is not None: - text_encoder_conds = text_encoder_outputs_list - else: - # not cached or training, so get from text encoders - tokens_and_masks = batch["input_ids_list"] - with torch.no_grad(): - input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] - text_encoder_conds = text_encoding_strategy.encode_tokens( - flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask - ) - if args.full_fp16: - text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] + # STANDARD LOSS CALCULATION + loss, _ = calculate_loss(step, batch, accumulate_loss=True) # Loss should be accumulated when not running the test/validation samples though - # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps + # backward + accelerator.backward(loss) - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - # get noisy model input and timesteps - noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype - ) - - # pack latents and get img_ids - packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 - packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) - - # get guidance: ensure args.guidance_scale is float - guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) - - # call model - l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds - if not args.apply_t5_attn_mask: - t5_attn_mask = None - - if args.bypass_flux_guidance: - flux_utils.bypass_flux_guidance(flux) - - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = flux( - img=packed_noisy_model_input, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - - # unpack latents - model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - - if args.bypass_flux_guidance: - flux_utils.restore_flux_guidance(flux) - - # apply model prediction type - model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) - - # flow matching loss: this is different from SD3 - target = noise - latents - - # calculate loss - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) - if weighting is not None: - loss = loss * weighting - if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - loss = loss.mean() - - # backward - accelerator.backward(loss) - - if not (args.fused_backward_pass or args.blockwise_fused_optimizers): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = [] - for m in training_models: - params_to_clip.extend(m.parameters()) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - else: - # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook - lr_scheduler.step() - if args.blockwise_fused_optimizers: - for i in range(1, len(optimizers)): - lr_schedulers[i].step() + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: diff --git a/library/train_util.py b/library/train_util.py index 72b5b24d..465952d0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6398,3 +6398,30 @@ class LossRecorder: @property def moving_average(self) -> float: return self.loss_total / len(self.loss_list) + +def calc_test_val_loss(dataset, loss_func, repeat_count, fixed_states=[], test=True): + test_val_ind = 'TEST' if test else 'VALIDATION' + # logger.warning(f'CALCULATING {test_val_ind} LOSS') + losses = [] + for step, batch in enumerate(dataset['batches'] * repeat_count): + if len(fixed_states) < len(dataset['batches']) * repeat_count: # If accumulating fixed states, calculate state as normal and return + loss, state = loss_func(step, batch, None, accumulate_loss=False) + fixed_states.append(state) + else: # Otherwise, recall the stored values and use those instead so the test loss is consistently calculated for each sample + state = fixed_states[step] + loss, _ = loss_func(step, batch, state, accumulate_loss=False) + losses.append(loss.detach().item()) + avg_loss = sum(losses) / len(losses) + logger.info(f'AVERAGE {test_val_ind} LOSS: {avg_loss:.6f}') + return avg_loss, fixed_states + +def create_test_val_set(dataloader, test_set_count, val_set_count): + test_set = test_set = {'steps':list(range(test_set_count)), 'batches':[]} + val_set = {'steps':list(range(test_set_count,test_set_count+val_set_count)), 'batches':[]} + for step, batch in enumerate(dataloader): + if step in test_set['steps']: + test_set['batches'].append(batch) + if step in val_set['steps']: + val_set['batches'].append(batch) + if step >= test_set_count + val_set_count: + return test_set, val_set \ No newline at end of file From 99338a204fd17925428fc35bdec7bdac4284b232 Mon Sep 17 00:00:00 2001 From: stepfunction83 <32859451+stepfunction83@users.noreply.github.com> Date: Sat, 25 Jan 2025 10:13:27 -0500 Subject: [PATCH 7/7] Revert "Create basic Flux calc for test and validation loss" This reverts commit 0b50630e6117af7c6ee8453b103136feb31bf7eb. --- flux_train.py | 264 +++++++++++++++++------------------------- library/train_util.py | 27 ----- 2 files changed, 107 insertions(+), 184 deletions(-) diff --git a/flux_train.py b/flux_train.py index 6fad0c31..2c1ec8f3 100644 --- a/flux_train.py +++ b/flux_train.py @@ -19,7 +19,6 @@ from multiprocessing import Value import time from typing import List, Optional, Tuple, Union import toml -import random from tqdm import tqdm @@ -45,8 +44,6 @@ logger = logging.getLogger(__name__) import library.config_util as config_util -from contextlib import nullcontext - # import library.sdxl_train_util as sdxl_train_util from library.config_util import ( ConfigSanitizer, @@ -54,6 +51,7 @@ from library.config_util import ( ) from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments + def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) @@ -579,123 +577,6 @@ def train(args): # log empty object to commit the sample images to wandb accelerator.log({}, step=0) - - ### PLACEHOLDERS ### - test_step_freq = 10 - val_step_freq = 25 - test_set_count = 5 - val_set_count = 5 - test_val_repeat_count = 2 - - logger.warning('CREATING TEST AND VALIDATION SETS') - test_set, val_set = train_util.create_test_val_set(train_dataloader, test_set_count, val_set_count) - - # TODO: Get arguments for step_freq values - # TODO: Get arguments for test_set_count, test_noise_iter - - def calculate_loss(step=step, batch=batch, state=None, accumulate_loss: bool=True, accelerator=accelerator): - - if state is not None: - noise, noisy_model_input, timesteps, sigmas = state - - with accelerator.accumulate(*training_models) if accumulate_loss else nullcontext(): # Only utilize the accumulate context if loss is marked to be accumulated, otherwise, just use a null context. This avoids the test and validation samples impacting the training. - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) - else: - with torch.no_grad(): - # encode images to latents. images are [-1, 1] - latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) - - # NaNが含まれていれば警告を表示し0に置き換える - if torch.any(torch.isnan(latents)): - accelerator.print("NaN found in latents, replacing with zeros") - latents = torch.nan_to_num(latents, 0, out=latents) - - text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) - if text_encoder_outputs_list is not None: - text_encoder_conds = text_encoder_outputs_list - else: - # not cached or training, so get from text encoders - tokens_and_masks = batch["input_ids_list"] - with torch.no_grad(): - input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] - text_encoder_conds = text_encoding_strategy.encode_tokens( - flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask - ) - if args.full_fp16: - text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] - - # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps - - bsz = latents.shape[0] - - # get noisy model input and timesteps - if state is None: # Only calculate if not using stored values for validation - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - - noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype - ) - - # pack latents and get img_ids - packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 - packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 - img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) - - # get guidance: ensure args.guidance_scale is float - guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) - - # call model - l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds - if not args.apply_t5_attn_mask: - t5_attn_mask = None - - if args.bypass_flux_guidance: - flux_utils.bypass_flux_guidance(flux) - - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = flux( - img=packed_noisy_model_input, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - - # unpack latents - model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) - - if args.bypass_flux_guidance: - flux_utils.restore_flux_guidance(flux) - - # apply model prediction type - model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) - - # flow matching loss: this is different from SD3 - target = noise - latents - - # calculate loss - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) - if weighting is not None: - loss = loss * weighting - if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - loss = loss.mean() - - state = (noise, noisy_model_input, timesteps, sigmas) - - return loss, state - loss_recorder = train_util.LossRecorder() epoch = 0 # avoid error when max_train_steps is 0 for epoch in range(num_train_epochs): @@ -706,53 +587,122 @@ def train(args): m.train() for step, batch in enumerate(train_dataloader): - if step in val_set['steps']: # Skip validation steps, don't increment global step - logger.warning('SKIPPING BATCH IN VALIDATION SET') - continue - current_step.value = global_step if args.blockwise_fused_optimizers: optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step - # CALCULATE LOSS ON TEST SET AT TEST SET FREQUENCY - if global_step==0: - test_fixed_states = [] - test_losses = [] - if global_step % test_step_freq == 0 and test_step_freq > 0: - test_loss, test_fixed_states = train_util.calc_test_val_loss(dataset=test_set, loss_func=calculate_loss, repeat_count=test_val_repeat_count, fixed_states=test_fixed_states, test=True) - test_losses.append(test_loss) + with accelerator.accumulate(*training_models): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) + else: + with torch.no_grad(): + # encode images to latents. images are [-1, 1] + latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype) - # CALCULATE LOSS ON VALIDATION SET AT TEST SET FREQUENCY - if global_step==0: - val_fixed_states = [] - val_losses = [] - if global_step % val_step_freq == 0 and val_step_freq > 0: - val_loss, val_fixed_states = train_util.calc_test_val_loss(dataset=val_set, loss_func=calculate_loss, repeat_count=test_val_repeat_count, fixed_states=val_fixed_states, test=False) - val_losses.append(val_loss) + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.nan_to_num(latents, 0, out=latents) - # STANDARD LOSS CALCULATION - loss, _ = calculate_loss(step, batch, accumulate_loss=True) # Loss should be accumulated when not running the test/validation samples though + text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + if text_encoder_outputs_list is not None: + text_encoder_conds = text_encoder_outputs_list + else: + # not cached or training, so get from text encoders + tokens_and_masks = batch["input_ids_list"] + with torch.no_grad(): + input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] + text_encoder_conds = text_encoding_strategy.encode_tokens( + flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask + ) + if args.full_fp16: + text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds] - # backward - accelerator.backward(loss) + # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps - if not (args.fused_backward_pass or args.blockwise_fused_optimizers): - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = [] - for m in training_models: - params_to_clip.extend(m.parameters()) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - else: - # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook - lr_scheduler.step() - if args.blockwise_fused_optimizers: - for i in range(1, len(optimizers)): - lr_schedulers[i].step() + # get noisy model input and timesteps + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype + ) + + # pack latents and get img_ids + packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 + packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 + img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device) + + # get guidance: ensure args.guidance_scale is float + guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device) + + # call model + l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds + if not args.apply_t5_attn_mask: + t5_attn_mask = None + + if args.bypass_flux_guidance: + flux_utils.bypass_flux_guidance(flux) + + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = flux( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # unpack latents + model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width) + + if args.bypass_flux_guidance: + flux_utils.restore_flux_guidance(flux) + + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) + + # flow matching loss: this is different from SD3 + target = noise - latents + + # calculate loss + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + loss = loss.mean() + + # backward + accelerator.backward(loss) + + if not (args.fused_backward_pass or args.blockwise_fused_optimizers): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.blockwise_fused_optimizers: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: diff --git a/library/train_util.py b/library/train_util.py index 465952d0..72b5b24d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6398,30 +6398,3 @@ class LossRecorder: @property def moving_average(self) -> float: return self.loss_total / len(self.loss_list) - -def calc_test_val_loss(dataset, loss_func, repeat_count, fixed_states=[], test=True): - test_val_ind = 'TEST' if test else 'VALIDATION' - # logger.warning(f'CALCULATING {test_val_ind} LOSS') - losses = [] - for step, batch in enumerate(dataset['batches'] * repeat_count): - if len(fixed_states) < len(dataset['batches']) * repeat_count: # If accumulating fixed states, calculate state as normal and return - loss, state = loss_func(step, batch, None, accumulate_loss=False) - fixed_states.append(state) - else: # Otherwise, recall the stored values and use those instead so the test loss is consistently calculated for each sample - state = fixed_states[step] - loss, _ = loss_func(step, batch, state, accumulate_loss=False) - losses.append(loss.detach().item()) - avg_loss = sum(losses) / len(losses) - logger.info(f'AVERAGE {test_val_ind} LOSS: {avg_loss:.6f}') - return avg_loss, fixed_states - -def create_test_val_set(dataloader, test_set_count, val_set_count): - test_set = test_set = {'steps':list(range(test_set_count)), 'batches':[]} - val_set = {'steps':list(range(test_set_count,test_set_count+val_set_count)), 'batches':[]} - for step, batch in enumerate(dataloader): - if step in test_set['steps']: - test_set['batches'].append(batch) - if step in val_set['steps']: - val_set['batches'].append(batch) - if step >= test_set_count + val_set_count: - return test_set, val_set \ No newline at end of file