Mirror of https://github.com/kohya-ss/sd-scripts.git
Synced 2026-04-09 06:45:09 +00:00
Use self.get_noise_pred_and_target and drop fixed timesteps
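The commit threads an is_train flag through each trainer's noise-prediction path and replaces the hand-rolled prediction/loss code in process_batch with the shared self.get_noise_pred_and_target hook, so validation (is_train=False) reuses the training path without recording autograd graphs or pinning fixed timesteps. A minimal sketch of the grad-gating pattern the diff applies (the model and inputs are placeholders; only the is_train and train_unet names come from the diff):

    import torch
    import torch.nn as nn

    def predict(model: nn.Module, x: torch.Tensor, is_train: bool, train_unet: bool) -> torch.Tensor:
        # Record a graph only when this is a training pass AND the model's
        # weights are actually being trained; validation runs grad-free.
        with torch.set_grad_enabled(is_train and train_unet):
            return model(x)

    model = nn.Linear(4, 4)
    out = predict(model, torch.randn(2, 4), is_train=False, train_unet=True)
    print(out.requires_grad)  # False: no graph was recorded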
@@ -339,6 +339,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         network,
         weight_dtype,
         train_unet,
+        is_train=True
     ):
         # Sample noise that we'll add to the latents
         noise = torch.randn_like(latents)
@@ -375,7 +376,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
             # if not args.split_mode:
             # normal forward
-            with accelerator.autocast():
+            with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                 model_pred = unet(
                     img=img,
@@ -420,7 +421,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
                 intermediate_txt.requires_grad_(True)
                 vec.requires_grad_(True)
                 pe.requires_grad_(True)
-                model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
+
+                with torch.set_grad_enabled(is_train and train_unet):
+                    model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
             """

             return model_pred
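The second Flux path above marks intermediate activations with requires_grad_(True) for gradient checkpointing, so the final transformer call gets its own torch.set_grad_enabled wrapper: grad mode, not a tensor's requires_grad flag, decides whether operations are recorded. A standalone demonstration of that PyTorch behavior (not repo code):

    import torch

    x = torch.ones(2, 2).requires_grad_(True)

    with torch.set_grad_enabled(False):
        y = x * 2  # runs in no-grad mode despite x.requires_grad being True

    print(x.requires_grad)  # True
    print(y.requires_grad)  # False: y is detached from the graph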
@@ -312,6 +312,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
         network,
         weight_dtype,
         train_unet,
+        is_train=True
     ):
         # Sample noise that we'll add to the latents
         noise = torch.randn_like(latents)
@@ -339,7 +340,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
             t5_attn_mask = None

         # call model
-        with accelerator.autocast():
+        with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
             # TODO support attention mask
             model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)

train_network.py (116 lines changed)
@@ -223,6 +223,7 @@ class NetworkTrainer:
         network,
         weight_dtype,
         train_unet,
+        is_train=True
     ):
         # Sample noise, sample a random timestep for each image, and add noise to the latents,
         # with noise offset and/or multires noise if specified
@@ -236,7 +237,7 @@ class NetworkTrainer:
                 t.requires_grad_(True)

         # Predict the noise residual
-        with accelerator.autocast():
+        with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
             noise_pred = self.call_unet(
                 args,
                 accelerator,
@@ -317,7 +318,7 @@ class NetworkTrainer:

     # endregion

-    def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor:
+    def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True) -> torch.Tensor:

         with torch.no_grad():
             if "latents" in batch and batch["latents"] is not None:
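With timesteps_list gone from the process_batch signature, timestep selection moves inside get_noise_pred_and_target, which samples randomly per batch. A sketch of uniform sampling in the spirit of the train_util.make_random_timesteps call in the removed code (the helper name comes from the old code; this body and signature are illustrative):

    import torch

    def make_random_timesteps(num_train_timesteps: int, batch_size: int, device) -> torch.Tensor:
        # One integer timestep in [0, num_train_timesteps) per sample.
        return torch.randint(0, num_train_timesteps, (batch_size,), device=device, dtype=torch.long)

    timesteps = make_random_timesteps(1000, batch_size=4, device="cpu")
    print(timesteps.shape)  # torch.Size([4])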
@@ -372,91 +373,40 @@ class NetworkTrainer:

         batch_size = latents.shape[0]

-        # Sample noise,
-        noise = train_util.make_noise(args, latents)
-
-        def pick_timesteps_list() -> torch.IntTensor:
-            if timesteps_list is None or timesteps_list == []:
-                return typing.cast(torch.IntTensor, train_util.make_random_timesteps(args, noise_scheduler, batch_size, latents.device).unsqueeze(1))
-            else:
-                return typing.cast(torch.IntTensor, torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(latents.device))
-
-        chosen_timesteps_list = pick_timesteps_list()
-        total_loss = torch.zeros((batch_size, 1)).to(latents.device)
-
-        # Use input timesteps_list or use described timesteps above
-        for fixed_timesteps in chosen_timesteps_list:
-            fixed_timesteps = typing.cast(torch.IntTensor, fixed_timesteps)
-
-            # Predict the noise residual
-            # and add noise to the latents
-            # with noise offset and/or multires noise if specified
-            noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timesteps)
-
-            # ensure the hidden state will require grad
-            if args.gradient_checkpointing:
-                for x in noisy_latents:
-                    x.requires_grad_(True)
-                for t in text_encoder_conds:
-                    t.requires_grad_(True)
-
-            with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
-                noise_pred = self.call_unet(
-                    args,
-                    accelerator,
-                    unet,
-                    noisy_latents.requires_grad_(train_unet),
-                    fixed_timesteps,
-                    text_encoder_conds,
-                    batch,
-                    weight_dtype,
-                )
-
-            if args.v_parameterization:
-                # v-parameterization training
-                target = noise_scheduler.get_velocity(latents, noise, fixed_timesteps)
-            else:
-                target = noise
-
-            # differential output preservation
-            if "custom_attributes" in batch:
-                diff_output_pr_indices = []
-                for i, custom_attributes in enumerate(batch["custom_attributes"]):
-                    if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
-                        diff_output_pr_indices.append(i)
-
-                if len(diff_output_pr_indices) > 0:
-                    network.set_multiplier(0.0)
-                    with torch.no_grad(), accelerator.autocast():
-                        noise_pred_prior = self.call_unet(
-                            args,
-                            accelerator,
-                            unet,
-                            noisy_latents,
-                            fixed_timesteps,
-                            text_encoder_conds,
-                            batch,
-                            weight_dtype,
-                            indices=diff_output_pr_indices,
-                        )
-                    network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step
-                    target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
-
-            huber_c = train_util.get_huber_threshold_if_needed(args, fixed_timesteps, noise_scheduler)
-            loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
-            loss = loss.mean([1, 2, 3])  # this is a mean, so no need to divide by batch_size
-
-            if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
-                loss = apply_masked_loss(loss, batch)
-
-            loss_weights = batch["loss_weights"].to(accelerator.device)  # weight for each sample
-            loss = loss * loss_weights
-
-            loss = self.post_process_loss(loss, args, fixed_timesteps, noise_scheduler)
-
-            total_loss += loss
-
-        return total_loss / len(chosen_timesteps_list)
+        # Predict the noise residual
+        # and add noise to the latents
+        # with noise offset and/or multires noise if specified
+
+        # sample noise, call unet, get target
+        noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
+            args,
+            accelerator,
+            noise_scheduler,
+            latents,
+            batch,
+            text_encoder_conds,
+            unet,
+            network,
+            weight_dtype,
+            train_unet,
+            is_train=is_train
+        )
+
+        huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
+        loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
+        if weighting is not None:
+            loss = loss * weighting
+
+        if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
+            loss = apply_masked_loss(loss, batch)
+        loss = loss.mean([1, 2, 3])
+
+        loss_weights = batch["loss_weights"]  # weight for each sample
+        loss = loss * loss_weights
+
+        loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
+
+        return loss.mean()

     def train(self, args):
         session_id = random.randint(0, 2**32)
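The new process_batch body reduces the elementwise loss in a fixed order: the optional weighting tensor returned by the hook first, then masking, then the spatial mean, then the dataset's per-sample loss_weights. A runnable sketch of that reduction order (masking step omitted; shapes follow the diff):

    import torch

    def reduce_loss(loss: torch.Tensor, weighting, loss_weights: torch.Tensor) -> torch.Tensor:
        # loss: (B, C, H, W) elementwise values from conditional_loss(..., "none")
        if weighting is not None:      # e.g. model-specific timestep weighting
            loss = loss * weighting    # applied before any reduction
        loss = loss.mean([1, 2, 3])    # -> (B,): one loss value per sample
        loss = loss * loss_weights     # dataset-supplied per-sample weights
        return loss.mean()             # scalar, ready for backward()

    print(reduce_loss(torch.rand(2, 4, 8, 8), None, torch.ones(2)))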
@@ -1416,7 +1366,7 @@ class NetworkTrainer:
                         if val_step >= validation_steps:
                             break

-                        loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
+                        loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False)

                         val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item())
                         val_progress_bar.update(1)
@@ -1447,7 +1397,7 @@ class NetworkTrainer:
                     if val_step >= validation_steps:
                         break

-                    loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
+                    loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False)

                     current_loss = loss.detach().item()
                     val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
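Both validation loops previously pinned timesteps_list=[10, 350, 500, 650, 990]; the removed pick_timesteps_list expanded each fixed timestep across the batch, and process_batch averaged one loss per timestep. A sketch of that old expansion, for contrast (following the removed code, with typing.cast dropped for brevity):

    import torch

    def expand_fixed_timesteps(timesteps_list, batch_size: int, device) -> torch.Tensor:
        # (len(timesteps_list), batch_size): one row of identical timesteps
        # per forward pass, mirroring the removed pick_timesteps_list branch.
        return torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(device)

    rows = expand_fixed_timesteps([10, 350, 500, 650, 990], batch_size=4, device="cpu")
    print(rows.shape)  # torch.Size([5, 4])

After this change validation shares the random-timestep path with training, so per-step validation loss is noisier but measures the same objective.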