mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Use self.get_noise_pred_and_target and drop fixed timesteps
This commit is contained in:
116
train_network.py
116
train_network.py
@@ -223,6 +223,7 @@ class NetworkTrainer:
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True
|
||||
):
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
@@ -236,7 +237,7 @@ class NetworkTrainer:
|
||||
t.requires_grad_(True)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args,
|
||||
accelerator,
|
||||
@@ -317,7 +318,7 @@ class NetworkTrainer:
|
||||
|
||||
# endregion
|
||||
|
||||
def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor:
|
||||
def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True) -> torch.Tensor:
|
||||
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
@@ -372,91 +373,40 @@ class NetworkTrainer:
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
# Sample noise,
|
||||
noise = train_util.make_noise(args, latents)
|
||||
|
||||
def pick_timesteps_list() -> torch.IntTensor:
|
||||
if timesteps_list is None or timesteps_list == []:
|
||||
return typing.cast(torch.IntTensor, train_util.make_random_timesteps(args, noise_scheduler, batch_size, latents.device).unsqueeze(1))
|
||||
else:
|
||||
return typing.cast(torch.IntTensor, torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(latents.device))
|
||||
# Predict the noise residual
|
||||
# and add noise to the latents
|
||||
# with noise offset and/or multires noise if specified
|
||||
|
||||
chosen_timesteps_list = pick_timesteps_list()
|
||||
total_loss = torch.zeros((batch_size, 1)).to(latents.device)
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=is_train
|
||||
)
|
||||
|
||||
# Use input timesteps_list or use described timesteps above
|
||||
for fixed_timesteps in chosen_timesteps_list:
|
||||
fixed_timesteps = typing.cast(torch.IntTensor, fixed_timesteps)
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
# Predict the noise residual
|
||||
# and add noise to the latents
|
||||
# with noise offset and/or multires noise if specified
|
||||
noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timesteps)
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
if args.gradient_checkpointing:
|
||||
for x in noisy_latents:
|
||||
x.requires_grad_(True)
|
||||
for t in text_encoder_conds:
|
||||
t.requires_grad_(True)
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||
|
||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args,
|
||||
accelerator,
|
||||
unet,
|
||||
noisy_latents.requires_grad_(train_unet),
|
||||
fixed_timesteps,
|
||||
text_encoder_conds,
|
||||
batch,
|
||||
weight_dtype,
|
||||
)
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, fixed_timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
# differential output preservation
|
||||
if "custom_attributes" in batch:
|
||||
diff_output_pr_indices = []
|
||||
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
||||
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
||||
diff_output_pr_indices.append(i)
|
||||
|
||||
if len(diff_output_pr_indices) > 0:
|
||||
network.set_multiplier(0.0)
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
noise_pred_prior = self.call_unet(
|
||||
args,
|
||||
accelerator,
|
||||
unet,
|
||||
noisy_latents,
|
||||
fixed_timesteps,
|
||||
text_encoder_conds,
|
||||
batch,
|
||||
weight_dtype,
|
||||
indices=diff_output_pr_indices,
|
||||
)
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
||||
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, fixed_timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
loss = loss.mean([1, 2, 3]) # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
|
||||
loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
loss = self.post_process_loss(loss, args, fixed_timesteps, noise_scheduler)
|
||||
|
||||
total_loss += loss
|
||||
|
||||
return total_loss / len(chosen_timesteps_list)
|
||||
return loss.mean()
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
@@ -1416,7 +1366,7 @@ class NetworkTrainer:
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False)
|
||||
|
||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item())
|
||||
val_progress_bar.update(1)
|
||||
@@ -1447,7 +1397,7 @@ class NetworkTrainer:
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
|
||||
Reference in New Issue
Block a user