Differential Output Preservation loss for SDXL

This commit is contained in:
Kohya S
2024-10-18 20:57:13 +09:00
parent 2500f5a798
commit 3cc5b8db99
4 changed files with 67 additions and 22 deletions

View File

@@ -143,7 +143,7 @@ class NetworkTrainer:
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs):
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
return noise_pred
@@ -218,6 +218,30 @@ class NetworkTrainer:
else:
target = noise
# differential output preservation
if "custom_attributes" in batch:
diff_output_pr_indices = []
for i, custom_attributes in enumerate(batch["custom_attributes"]):
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
diff_output_pr_indices.append(i)
if len(diff_output_pr_indices) > 0:
network.set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
noise_pred_prior = self.call_unet(
args,
accelerator,
unet,
noisy_latents,
timesteps,
text_encoder_conds,
batch,
weight_dtype,
indices=diff_output_pr_indices,
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
return noise_pred, target, timesteps, huber_c, None
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
@@ -1123,15 +1147,6 @@ class NetworkTrainer:
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
# # SD only
# encoded_text_encoder_conds = get_weighted_text_embeddings(
# tokenizers[0],
# text_encoder,
# batch["captions"],
# accelerator.device,
# args.max_token_length // 75 if args.max_token_length else 1,
# clip_skip=args.clip_skip,
# )
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,