Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-14 16:22:28 +00:00
Revert "Create basic Flux calc for test and validation loss"
This reverts commit 0b50630e61.
flux_train.py (264 changed lines)
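For orientation before the diff: the reverted commit computed test and validation losses at fixed step intervals and made them comparable across evaluations by caching the random draws (noise, noisy model input, timesteps, sigmas) from the first pass and replaying them on later passes. A minimal sketch of that fixed-state idea (eval_fixed_state_loss and loss_fn are illustrative names, not the sd-scripts API):

def eval_fixed_state_loss(batches, loss_fn, fixed_states):
    # First evaluation: draw fresh randomness per sample and cache it.
    # Later evaluations: replay the cached state so losses are comparable.
    losses = []
    for i, batch in enumerate(batches):
        if i < len(fixed_states):
            loss, _ = loss_fn(batch, state=fixed_states[i])  # replay cached state
        else:
            loss, state = loss_fn(batch, state=None)  # fresh noise/timesteps
            fixed_states.append(state)
        losses.append(float(loss))
    return sum(losses) / len(losses), fixed_states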
@@ -19,7 +19,6 @@ from multiprocessing import Value
 import time
 from typing import List, Optional, Tuple, Union
 import toml
-import random
 
 from tqdm import tqdm
 
@@ -45,8 +44,6 @@ logger = logging.getLogger(__name__)
 
 import library.config_util as config_util
 
-from contextlib import nullcontext
-
 # import library.sdxl_train_util as sdxl_train_util
 from library.config_util import (
     ConfigSanitizer,
@@ -54,6 +51,7 @@ from library.config_util import (
 )
 from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
 
+
 def train(args):
     train_util.verify_training_args(args)
     train_util.prepare_dataset_args(args, True)
@@ -579,123 +577,6 @@ def train(args):
         # log empty object to commit the sample images to wandb
         accelerator.log({}, step=0)
 
-    ### PLACEHOLDERS ###
-    test_step_freq = 10
-    val_step_freq = 25
-    test_set_count = 5
-    val_set_count = 5
-    test_val_repeat_count = 2
-
-    logger.warning('CREATING TEST AND VALIDATION SETS')
-    test_set, val_set = train_util.create_test_val_set(train_dataloader, test_set_count, val_set_count)
-
-    # TODO: Get arguments for step_freq values
-    # TODO: Get arguments for test_set_count, test_noise_iter
-
-    def calculate_loss(step=step, batch=batch, state=None, accumulate_loss: bool = True, accelerator=accelerator):
-        if state is not None:
-            noise, noisy_model_input, timesteps, sigmas = state
-
-        # Only use the accumulate context if the loss is marked to be accumulated; otherwise use a null
-        # context. This keeps the test and validation samples from impacting the training.
-        with accelerator.accumulate(*training_models) if accumulate_loss else nullcontext():
-            if "latents" in batch and batch["latents"] is not None:
-                latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
-            else:
-                with torch.no_grad():
-                    # encode images to latents. images are [-1, 1]
-                    latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)
-
-            # if NaN is found, show a warning and replace it with zeros
-            if torch.any(torch.isnan(latents)):
-                accelerator.print("NaN found in latents, replacing with zeros")
-                latents = torch.nan_to_num(latents, 0, out=latents)
-
-            text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
-            if text_encoder_outputs_list is not None:
-                text_encoder_conds = text_encoder_outputs_list
-            else:
-                # not cached or training, so get from text encoders
-                tokens_and_masks = batch["input_ids_list"]
-                with torch.no_grad():
-                    input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
-                    text_encoder_conds = text_encoding_strategy.encode_tokens(
-                        flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
-                    )
-                    if args.full_fp16:
-                        text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
-
-            # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
-
-            bsz = latents.shape[0]
-
-            # get noisy model input and timesteps
-            if state is None:  # only calculate if not replaying stored values for test/validation
-                # Sample noise that we'll add to the latents
-                noise = torch.randn_like(latents)
-
-                noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
-                    args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
-                )
-
-            # pack latents and get img_ids
-            packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
-            packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
-            img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
-
-            # get guidance: ensure args.guidance_scale is float
-            guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
-
-            # call model
-            l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
-            if not args.apply_t5_attn_mask:
-                t5_attn_mask = None
-
-            if args.bypass_flux_guidance:
-                flux_utils.bypass_flux_guidance(flux)
-
-            with accelerator.autocast():
-                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
-                model_pred = flux(
-                    img=packed_noisy_model_input,
-                    img_ids=img_ids,
-                    txt=t5_out,
-                    txt_ids=txt_ids,
-                    y=l_pooled,
-                    timesteps=timesteps / 1000,
-                    guidance=guidance_vec,
-                    txt_attention_mask=t5_attn_mask,
-                )
-
-            # unpack latents
-            model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
-
-            if args.bypass_flux_guidance:
-                flux_utils.restore_flux_guidance(flux)
-
-            # apply model prediction type
-            model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
-
-            # flow matching loss: this is different from SD3
-            target = noise - latents
-
-            # calculate loss
-            huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
-            loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
-            if weighting is not None:
-                loss = loss * weighting
-            if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
-                loss = apply_masked_loss(loss, batch)
-            loss = loss.mean([1, 2, 3])
-
-            loss_weights = batch["loss_weights"]  # weight for each sample
-            loss = loss * loss_weights
-            loss = loss.mean()
-
-            state = (noise, noisy_model_input, timesteps, sigmas)
-
-            return loss, state
-
     loss_recorder = train_util.LossRecorder()
     epoch = 0  # avoid error when max_train_steps is 0
     for epoch in range(num_train_epochs):
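Two implementation notes on the calculate_loss helper removed above. First, its signature, def calculate_loss(step=step, batch=batch, ..., accelerator=accelerator), binds the default values at definition time; every call site passes step and batch explicitly, so those defaults are dead weight, and they would raise a NameError if step and batch were not already bound when the def statement runs. Second, the accumulate_loss flag switches between accelerator.accumulate(...) and contextlib.nullcontext() so that evaluation passes stay out of gradient-accumulation bookkeeping. A minimal sketch of that pattern in isolation (forward_pass and compute_loss are illustrative names, not sd-scripts API):

from contextlib import nullcontext

def forward_pass(accelerator, models, compute_loss, training: bool):
    # Enter the gradient-accumulation context only for real training passes;
    # evaluation passes run under a do-nothing context instead.
    ctx = accelerator.accumulate(*models) if training else nullcontext()
    with ctx:
        return compute_loss()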
@@ -706,53 +587,122 @@ def train(args):
             m.train()
 
         for step, batch in enumerate(train_dataloader):
-            if step in val_set['steps']:  # skip batches reserved for validation; do not increment global step
-                logger.warning('SKIPPING BATCH IN VALIDATION SET')
-                continue
-
             current_step.value = global_step
 
             if args.blockwise_fused_optimizers:
                 optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # reset counter for each step
 
-            # CALCULATE LOSS ON TEST SET AT TEST SET FREQUENCY
-            if global_step == 0:
-                test_fixed_states = []
-                test_losses = []
-            if global_step % test_step_freq == 0 and test_step_freq > 0:
-                test_loss, test_fixed_states = train_util.calc_test_val_loss(dataset=test_set, loss_func=calculate_loss, repeat_count=test_val_repeat_count, fixed_states=test_fixed_states, test=True)
-                test_losses.append(test_loss)
+            with accelerator.accumulate(*training_models):
+                if "latents" in batch and batch["latents"] is not None:
+                    latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
+                else:
+                    with torch.no_grad():
+                        # encode images to latents. images are [-1, 1]
+                        latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)
 
-            # CALCULATE LOSS ON VALIDATION SET AT VALIDATION SET FREQUENCY
-            if global_step == 0:
-                val_fixed_states = []
-                val_losses = []
-            if global_step % val_step_freq == 0 and val_step_freq > 0:
-                val_loss, val_fixed_states = train_util.calc_test_val_loss(dataset=val_set, loss_func=calculate_loss, repeat_count=test_val_repeat_count, fixed_states=val_fixed_states, test=False)
-                val_losses.append(val_loss)
+                # if NaN is found, show a warning and replace it with zeros
+                if torch.any(torch.isnan(latents)):
+                    accelerator.print("NaN found in latents, replacing with zeros")
+                    latents = torch.nan_to_num(latents, 0, out=latents)
 
-            # STANDARD LOSS CALCULATION
-            loss, _ = calculate_loss(step, batch, accumulate_loss=True)  # accumulate only for the training pass, not for the test/validation passes
+                text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+                if text_encoder_outputs_list is not None:
+                    text_encoder_conds = text_encoder_outputs_list
+                else:
+                    # not cached or training, so get from text encoders
+                    tokens_and_masks = batch["input_ids_list"]
+                    with torch.no_grad():
+                        input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
+                        text_encoder_conds = text_encoding_strategy.encode_tokens(
+                            flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
+                        )
+                        if args.full_fp16:
+                            text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
 
-            # backward
-            accelerator.backward(loss)
+                # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
 
-            if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
-                if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                    params_to_clip = []
-                    for m in training_models:
-                        params_to_clip.extend(m.parameters())
-                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+                # Sample noise that we'll add to the latents
+                noise = torch.randn_like(latents)
+                bsz = latents.shape[0]
 
-                optimizer.step()
-                lr_scheduler.step()
-                optimizer.zero_grad(set_to_none=True)
-            else:
-                # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
-                lr_scheduler.step()
-                if args.blockwise_fused_optimizers:
-                    for i in range(1, len(optimizers)):
-                        lr_schedulers[i].step()
+                # get noisy model input and timesteps
+                noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
+                    args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
+                )
+
+                # pack latents and get img_ids
+                packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
+                packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
+                img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
+
+                # get guidance: ensure args.guidance_scale is float
+                guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
+
+                # call model
+                l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
+                if not args.apply_t5_attn_mask:
+                    t5_attn_mask = None
+
+                if args.bypass_flux_guidance:
+                    flux_utils.bypass_flux_guidance(flux)
+
+                with accelerator.autocast():
+                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+                    model_pred = flux(
+                        img=packed_noisy_model_input,
+                        img_ids=img_ids,
+                        txt=t5_out,
+                        txt_ids=txt_ids,
+                        y=l_pooled,
+                        timesteps=timesteps / 1000,
+                        guidance=guidance_vec,
+                        txt_attention_mask=t5_attn_mask,
+                    )
+
+                # unpack latents
+                model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
+
+                if args.bypass_flux_guidance:
+                    flux_utils.restore_flux_guidance(flux)
+
+                # apply model prediction type
+                model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
+
+                # flow matching loss: this is different from SD3
+                target = noise - latents
+
+                # calculate loss
+                huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+                loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
+                if weighting is not None:
+                    loss = loss * weighting
+                if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
+                    loss = apply_masked_loss(loss, batch)
+                loss = loss.mean([1, 2, 3])
+
+                loss_weights = batch["loss_weights"]  # weight for each sample
+                loss = loss * loss_weights
+                loss = loss.mean()
+
+                # backward
+                accelerator.backward(loss)
+
+                if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
+                    if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                        params_to_clip = []
+                        for m in training_models:
+                            params_to_clip.extend(m.parameters())
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                    optimizer.step()
+                    lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)
+                else:
+                    # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
+                    lr_scheduler.step()
+                    if args.blockwise_fused_optimizers:
+                        for i in range(1, len(optimizers)):
+                            lr_schedulers[i].step()
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
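Two observations on the evaluation scheduling removed above. The training loop skips batches in val_set['steps'] but not those in test_set['steps'], so the test batches were still trained on. And the guards are written as global_step % test_step_freq == 0 and test_step_freq > 0, evaluating the modulo before the zero check; harmless with the hard-coded placeholders, but once the frequencies become CLI arguments (per the TODOs), a frequency of 0 would raise ZeroDivisionError. A safer ordering, as a sketch (should_eval is an illustrative name):

def should_eval(global_step: int, freq: int) -> bool:
    # Test the frequency before taking the modulo so freq == 0 cleanly disables evaluation.
    return freq > 0 and global_step % freq == 0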
library/train_util.py
@@ -6398,30 +6398,3 @@ class LossRecorder:
     @property
     def moving_average(self) -> float:
         return self.loss_total / len(self.loss_list)
-
-
-def calc_test_val_loss(dataset, loss_func, repeat_count, fixed_states=[], test=True):
-    test_val_ind = 'TEST' if test else 'VALIDATION'
-    # logger.warning(f'CALCULATING {test_val_ind} LOSS')
-    losses = []
-    for step, batch in enumerate(dataset['batches'] * repeat_count):
-        if len(fixed_states) < len(dataset['batches']) * repeat_count:  # while still accumulating fixed states, calculate the state as normal and store it
-            loss, state = loss_func(step, batch, None, accumulate_loss=False)
-            fixed_states.append(state)
-        else:  # otherwise, replay the stored values so the test loss is calculated consistently for each sample
-            state = fixed_states[step]
-            loss, _ = loss_func(step, batch, state, accumulate_loss=False)
-        losses.append(loss.detach().item())
-    avg_loss = sum(losses) / len(losses)
-    logger.info(f'AVERAGE {test_val_ind} LOSS: {avg_loss:.6f}')
-    return avg_loss, fixed_states
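One caution about the calc_test_val_loss helper removed above: the mutable default fixed_states=[] is the classic Python default-argument pitfall, since the same list object is shared across all calls that omit the argument. The removed flux_train.py code sidesteps this by always passing the list in and threading the returned list back, as in this sketch assembled from the diff above (all names come from the removed code):

# test_fixed_states starts empty at global_step 0 and is threaded through every
# evaluation, so the same cached (noise, timesteps, sigmas) are replayed each time.
test_fixed_states = []
test_loss, test_fixed_states = train_util.calc_test_val_loss(
    dataset=test_set, loss_func=calculate_loss, repeat_count=test_val_repeat_count, fixed_states=test_fixed_states, test=True
)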
-
-
-def create_test_val_set(dataloader, test_set_count, val_set_count):
-    test_set = {'steps': list(range(test_set_count)), 'batches': []}
-    val_set = {'steps': list(range(test_set_count, test_set_count + val_set_count)), 'batches': []}
-    for step, batch in enumerate(dataloader):
-        if step in test_set['steps']:
-            test_set['batches'].append(batch)
-        if step in val_set['steps']:
-            val_set['batches'].append(batch)
-        if step >= test_set_count + val_set_count:
-            return test_set, val_set
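Note that create_test_val_set returns only from inside the loop, once step reaches test_set_count + val_set_count; if the dataloader yields fewer batches than that, the function falls off the end and returns None. A defensive variant (a sketch, not the committed code) returns whatever was collected:

def create_test_val_set(dataloader, test_set_count, val_set_count):
    # Reserve the first test_set_count batches for test, the next val_set_count for validation.
    test_set = {'steps': list(range(test_set_count)), 'batches': []}
    val_set = {'steps': list(range(test_set_count, test_set_count + val_set_count)), 'batches': []}
    for step, batch in enumerate(dataloader):
        if step in test_set['steps']:
            test_set['batches'].append(batch)
        elif step in val_set['steps']:
            val_set['batches'].append(batch)
        elif step >= test_set_count + val_set_count:
            break
    return test_set, val_set  # also covers dataloaders shorter than the reserved count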