Diff Output Preservation loss for SDXL
@@ -143,7 +143,7 @@ class NetworkTrainer:
         for t_enc in text_encoders:
             t_enc.to(accelerator.device, dtype=weight_dtype)

-    def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
+    def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs):
         noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
         return noise_pred

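The only change in this hunk is the `**kwargs` added to the base `call_unet`: it lets the preservation code below pass extra keyword arguments (here `indices=...`) through to trainers that override this method, while the base implementation simply ignores them. A minimal sketch of how an SDXL-style override might consume that keyword; the slicing logic and class name are assumptions for illustration, not code from this commit:

    import torch
    import train_network  # sd-scripts module defining NetworkTrainer

    class SdxlCallUnetSketch(train_network.NetworkTrainer):
        def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, indices=None, **kwargs):
            # Hypothetical handling: when the preservation path passes `indices`,
            # run the UNet only on those batch items; otherwise behave as before.
            if indices is not None:
                noisy_latents = noisy_latents[indices]
                timesteps = timesteps[indices]
                text_conds = [c[indices] for c in text_conds]
            noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
            return noise_pred
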
@@ -218,6 +218,30 @@ class NetworkTrainer:
         else:
             target = noise

+        # differential output preservation
+        if "custom_attributes" in batch:
+            diff_output_pr_indices = []
+            for i, custom_attributes in enumerate(batch["custom_attributes"]):
+                if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
+                    diff_output_pr_indices.append(i)
+
+            if len(diff_output_pr_indices) > 0:
+                network.set_multiplier(0.0)
+                with torch.no_grad(), accelerator.autocast():
+                    noise_pred_prior = self.call_unet(
+                        args,
+                        accelerator,
+                        unet,
+                        noisy_latents,
+                        timesteps,
+                        text_encoder_conds,
+                        batch,
+                        weight_dtype,
+                        indices=diff_output_pr_indices,
+                    )
+                network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step
+                target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
+
         return noise_pred, target, timesteps, huber_c, None

     def post_process_loss(self, loss, args, timesteps, noise_scheduler):
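This hunk is the heart of the change: for batch items flagged with the `diff_output_preservation` custom attribute, the training target is swapped from the sampled noise to the base model's own prediction, obtained by zeroing the LoRA multiplier. For those items the loss becomes roughly ||eps_(base+LoRA) - eps_base||^2, so the LoRA is penalized for changing the base model's output on preservation (regularization) images, while ordinary items still train against the noise. A self-contained toy sketch of the mechanism, with all names illustrative rather than the trainer's actual helpers:

    import torch

    def preserve_base_output(network, unet_fn, target, noisy_latents, timesteps, conds, indices):
        # `network` is any adapter exposing set_multiplier(); `unet_fn` stands in
        # for the UNet forward pass. Both are placeholders for this sketch.
        network.set_multiplier(0.0)                  # LoRA off -> base-model behavior
        with torch.no_grad():
            prior = unet_fn(noisy_latents[indices], timesteps[indices], conds[indices])
        network.set_multiplier(1.0)                  # LoRA back on for the training step
        target = target.clone()
        target[indices] = prior.to(target.dtype)     # flagged items now regress to the base output
        return target

Note that in the actual diff the prior prediction reuses self.call_unet with indices=diff_output_pr_indices, so only the flagged sub-batch pays the extra no-grad forward pass.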
@@ -1123,15 +1147,6 @@ class NetworkTrainer:
                 with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
                     # Get the text embedding for conditioning
                     if args.weighted_captions:
-                        # # SD only
-                        # encoded_text_encoder_conds = get_weighted_text_embeddings(
-                        #     tokenizers[0],
-                        #     text_encoder,
-                        #     batch["captions"],
-                        #     accelerator.device,
-                        #     args.max_token_length // 75 if args.max_token_length else 1,
-                        #     clip_skip=args.clip_skip,
-                        # )
                         input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
                         encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
                             tokenize_strategy,
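The last hunk only deletes the commented-out, SD-only `get_weighted_text_embeddings` path; weighted captions now go through the tokenize/encode strategy objects, which also cover SDXL's two text encoders (the diff view is cut off mid-call here). For context, "weighted captions" are prompts with emphasis syntax such as "a (red:1.3) car", where tokenization yields a per-token weight (1.3 on the tokens of "red", 1.0 elsewhere) that scales the corresponding embeddings. A toy sketch of the weighting idea only, independent of sd-scripts' actual strategy classes and using the re-normalization trick common in long-prompt-weighting pipelines:

    import torch

    def apply_token_weights(embeddings: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        # embeddings: (batch, tokens, dim); weights: (batch, tokens)
        # Scale each token embedding by its parsed emphasis weight, then
        # re-normalize so the overall mean magnitude is unchanged.
        previous_mean = embeddings.float().mean(dim=(-2, -1), keepdim=True)
        weighted = embeddings * weights.unsqueeze(-1)
        current_mean = weighted.float().mean(dim=(-2, -1), keepdim=True)
        return (weighted * (previous_mean / current_mean)).to(embeddings.dtype)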