mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Use self.get_noise_pred_and_target and drop fixed timesteps
This commit is contained in:
@@ -339,6 +339,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True
|
||||
):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
@@ -375,7 +376,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
||||
# if not args.split_mode:
|
||||
# normal forward
|
||||
with accelerator.autocast():
|
||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = unet(
|
||||
img=img,
|
||||
@@ -420,6 +421,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
intermediate_txt.requires_grad_(True)
|
||||
vec.requires_grad_(True)
|
||||
pe.requires_grad_(True)
|
||||
|
||||
with torch.set_grad_enabled(is_train and train_unet):
|
||||
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
||||
"""
|
||||
|
||||
|
||||
@@ -312,6 +312,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True
|
||||
):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
@@ -339,7 +340,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
t5_attn_mask = None
|
||||
|
||||
# call model
|
||||
with accelerator.autocast():
|
||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
||||
# TODO support attention mask
|
||||
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)
|
||||
|
||||
|
||||
@@ -223,6 +223,7 @@ class NetworkTrainer:
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True
|
||||
):
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
@@ -236,7 +237,7 @@ class NetworkTrainer:
|
||||
t.requires_grad_(True)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args,
|
||||
accelerator,
|
||||
@@ -317,7 +318,7 @@ class NetworkTrainer:
|
||||
|
||||
# endregion
|
||||
|
||||
def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor:
|
||||
def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True) -> torch.Tensor:
|
||||
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
@@ -372,91 +373,40 @@ class NetworkTrainer:
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
# Sample noise,
|
||||
noise = train_util.make_noise(args, latents)
|
||||
|
||||
def pick_timesteps_list() -> torch.IntTensor:
|
||||
if timesteps_list is None or timesteps_list == []:
|
||||
return typing.cast(torch.IntTensor, train_util.make_random_timesteps(args, noise_scheduler, batch_size, latents.device).unsqueeze(1))
|
||||
else:
|
||||
return typing.cast(torch.IntTensor, torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(latents.device))
|
||||
|
||||
chosen_timesteps_list = pick_timesteps_list()
|
||||
total_loss = torch.zeros((batch_size, 1)).to(latents.device)
|
||||
|
||||
# Use input timesteps_list or use described timesteps above
|
||||
for fixed_timesteps in chosen_timesteps_list:
|
||||
fixed_timesteps = typing.cast(torch.IntTensor, fixed_timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
# and add noise to the latents
|
||||
# with noise offset and/or multires noise if specified
|
||||
noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timesteps)
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
if args.gradient_checkpointing:
|
||||
for x in noisy_latents:
|
||||
x.requires_grad_(True)
|
||||
for t in text_encoder_conds:
|
||||
t.requires_grad_(True)
|
||||
|
||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
unet,
|
||||
noisy_latents.requires_grad_(train_unet),
|
||||
fixed_timesteps,
|
||||
text_encoder_conds,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=is_train
|
||||
)
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, fixed_timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
# differential output preservation
|
||||
if "custom_attributes" in batch:
|
||||
diff_output_pr_indices = []
|
||||
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
||||
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
||||
diff_output_pr_indices.append(i)
|
||||
|
||||
if len(diff_output_pr_indices) > 0:
|
||||
network.set_multiplier(0.0)
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
noise_pred_prior = self.call_unet(
|
||||
args,
|
||||
accelerator,
|
||||
unet,
|
||||
noisy_latents,
|
||||
fixed_timesteps,
|
||||
text_encoder_conds,
|
||||
batch,
|
||||
weight_dtype,
|
||||
indices=diff_output_pr_indices,
|
||||
)
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
||||
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, fixed_timesteps, noise_scheduler)
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
loss = loss.mean([1, 2, 3]) # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
loss = self.post_process_loss(loss, args, fixed_timesteps, noise_scheduler)
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||
|
||||
total_loss += loss
|
||||
|
||||
return total_loss / len(chosen_timesteps_list)
|
||||
return loss.mean()
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
@@ -1416,7 +1366,7 @@ class NetworkTrainer:
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False)
|
||||
|
||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item())
|
||||
val_progress_bar.update(1)
|
||||
@@ -1447,7 +1397,7 @@ class NetworkTrainer:
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
|
||||
Reference in New Issue
Block a user