Revert "Create basic Flux calc for test and validation loss"

This reverts commit 0b50630e61.
This commit is contained in:
stepfunction83
2025-01-25 10:13:27 -05:00
parent 0b50630e61
commit 99338a204f
2 changed files with 107 additions and 184 deletions

View File

@@ -19,7 +19,6 @@ from multiprocessing import Value
import time
from typing import List, Optional, Tuple, Union
import toml
import random
from tqdm import tqdm
@@ -45,8 +44,6 @@ logger = logging.getLogger(__name__)
import library.config_util as config_util
from contextlib import nullcontext
# import library.sdxl_train_util as sdxl_train_util
from library.config_util import (
ConfigSanitizer,
@@ -54,6 +51,7 @@ from library.config_util import (
)
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
@@ -579,123 +577,6 @@ def train(args):
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
### PLACEHOLDERS ###
test_step_freq = 10
val_step_freq = 25
test_set_count = 5
val_set_count = 5
test_val_repeat_count = 2
logger.warning('CREATING TEST AND VALIDATION SETS')
test_set, val_set = train_util.create_test_val_set(train_dataloader, test_set_count, val_set_count)
# TODO: Get arguments for step_freq values
# TODO: Get arguments for test_set_count, test_noise_iter
def calculate_loss(step=step, batch=batch, state=None, accumulate_loss: bool=True, accelerator=accelerator):
if state is not None:
noise, noisy_model_input, timesteps, sigmas = state
with accelerator.accumulate(*training_models) if accumulate_loss else nullcontext(): # Only utilize the accumulate context if loss is marked to be accumulated, otherwise, just use a null context. This avoids the test and validation samples impacting the training.
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
else:
with torch.no_grad():
# encode images to latents. images are [-1, 1]
latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list
else:
# not cached or training, so get from text encoders
tokens_and_masks = batch["input_ids_list"]
with torch.no_grad():
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
text_encoder_conds = text_encoding_strategy.encode_tokens(
flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
)
if args.full_fp16:
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
bsz = latents.shape[0]
# get noisy model input and timesteps
if state is None: # Only calculate if not using stored values for validation
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
)
# pack latents and get img_ids
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
# get guidance: ensure args.guidance_scale is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
# call model
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None
if args.bypass_flux_guidance:
flux_utils.bypass_flux_guidance(flux)
with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = flux(
img=packed_noisy_model_input,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
if args.bypass_flux_guidance:
flux_utils.restore_flux_guidance(flux)
# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
# flow matching loss: this is different from SD3
target = noise - latents
# calculate loss
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
loss = loss.mean()
state = (noise, noisy_model_input, timesteps, sigmas)
return loss, state
loss_recorder = train_util.LossRecorder()
epoch = 0 # avoid error when max_train_steps is 0
for epoch in range(num_train_epochs):
@@ -706,53 +587,122 @@ def train(args):
m.train()
for step, batch in enumerate(train_dataloader):
if step in val_set['steps']: # Skip validation steps, don't increment global step
logger.warning('SKIPPING BATCH IN VALIDATION SET')
continue
current_step.value = global_step
if args.blockwise_fused_optimizers:
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
# CALCULATE LOSS ON TEST SET AT TEST SET FREQUENCY
if global_step==0:
test_fixed_states = []
test_losses = []
if global_step % test_step_freq == 0 and test_step_freq > 0:
test_loss, test_fixed_states = train_util.calc_test_val_loss(dataset=test_set, loss_func=calculate_loss, repeat_count=test_val_repeat_count, fixed_states=test_fixed_states, test=True)
test_losses.append(test_loss)
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
else:
with torch.no_grad():
# encode images to latents. images are [-1, 1]
latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)
# CALCULATE LOSS ON VALIDATION SET AT TEST SET FREQUENCY
if global_step==0:
val_fixed_states = []
val_losses = []
if global_step % val_step_freq == 0 and val_step_freq > 0:
val_loss, val_fixed_states = train_util.calc_test_val_loss(dataset=val_set, loss_func=calculate_loss, repeat_count=test_val_repeat_count, fixed_states=val_fixed_states, test=False)
val_losses.append(val_loss)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
# STANDARD LOSS CALCULATION
loss, _ = calculate_loss(step, batch, accumulate_loss=True) # Loss should be accumulated when not running the test/validation samples though
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list
else:
# not cached or training, so get from text encoders
tokens_and_masks = batch["input_ids_list"]
with torch.no_grad():
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
text_encoder_conds = text_encoding_strategy.encode_tokens(
flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
)
if args.full_fp16:
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
# backward
accelerator.backward(loss)
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if args.blockwise_fused_optimizers:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
)
# pack latents and get img_ids
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
# get guidance: ensure args.guidance_scale is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
# call model
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None
if args.bypass_flux_guidance:
flux_utils.bypass_flux_guidance(flux)
with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = flux(
img=packed_noisy_model_input,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
if args.bypass_flux_guidance:
flux_utils.restore_flux_guidance(flux)
# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
# flow matching loss: this is different from SD3
target = noise - latents
# calculate loss
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
loss = loss.mean()
# backward
accelerator.backward(loss)
if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if args.blockwise_fused_optimizers:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:

View File

@@ -6398,30 +6398,3 @@ class LossRecorder:
@property
def moving_average(self) -> float:
return self.loss_total / len(self.loss_list)
def calc_test_val_loss(dataset, loss_func, repeat_count, fixed_states=[], test=True):
test_val_ind = 'TEST' if test else 'VALIDATION'
# logger.warning(f'CALCULATING {test_val_ind} LOSS')
losses = []
for step, batch in enumerate(dataset['batches'] * repeat_count):
if len(fixed_states) < len(dataset['batches']) * repeat_count: # If accumulating fixed states, calculate state as normal and return
loss, state = loss_func(step, batch, None, accumulate_loss=False)
fixed_states.append(state)
else: # Otherwise, recall the stored values and use those instead so the test loss is consistently calculated for each sample
state = fixed_states[step]
loss, _ = loss_func(step, batch, state, accumulate_loss=False)
losses.append(loss.detach().item())
avg_loss = sum(losses) / len(losses)
logger.info(f'AVERAGE {test_val_ind} LOSS: {avg_loss:.6f}')
return avg_loss, fixed_states
def create_test_val_set(dataloader, test_set_count, val_set_count):
test_set = test_set = {'steps':list(range(test_set_count)), 'batches':[]}
val_set = {'steps':list(range(test_set_count,test_set_count+val_set_count)), 'batches':[]}
for step, batch in enumerate(dataloader):
if step in test_set['steps']:
test_set['batches'].append(batch)
if step in val_set['steps']:
val_set['batches'].append(batch)
if step >= test_set_count + val_set_count:
return test_set, val_set