mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix SD3 LoRA training to work (WIP)
This commit is contained in:
@@ -1151,6 +1151,17 @@ class NetworkTrainer:
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs
|
||||
|
||||
# if text_encoder_outputs_list is not None:
|
||||
# lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list
|
||||
# for i in range(len(lg_out)):
|
||||
# print(
|
||||
# f"[{i}] cached L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, cached G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, "
|
||||
# f"cached T5: {t5_out[i].max()}, "
|
||||
# f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0},"
|
||||
# f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}"
|
||||
# )
|
||||
|
||||
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
|
||||
# TODO: this does not work when some text encoders are being trained while others are neither trained nor cached
|
||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||
@@ -1182,6 +1193,15 @@ class NetworkTrainer:
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
# lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds
|
||||
# for i in range(len(lg_out)):
|
||||
# print(
|
||||
# f"[{i}] train L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, train G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, "
|
||||
# f"train T5: {t5_out[i].max()}, "
|
||||
# f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0},"
|
||||
# f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}"
|
||||
# )
|
||||
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
|
||||
Reference in New Issue
Block a user