Fix dropout_rate to work for TEs (text encoders)
@@ -363,7 +363,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         if args.gradient_checkpointing:
             noisy_model_input.requires_grad_(True)
             for t in text_encoder_conds:
-                if t.dtype.is_floating_point:
+                if t is not None and t.dtype.is_floating_point:
                     t.requires_grad_(True)
             img_ids.requires_grad_(True)
             guidance_vec.requires_grad_(True)
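Note on the guard added above: with cached outputs, `text_encoder_conds` can contain `None` entries (an output that was never produced), and it can also hold non-float tensors such as token ids, so only real floating-point tensors should get `requires_grad_` under gradient checkpointing. A minimal sketch of the pattern; the tensor shapes below are illustrative assumptions, not values from the trainer:

import torch

# hypothetical conditioning list: one cached embedding, one missing entry,
# and one integer tensor (token ids) that must not receive gradients
text_encoder_conds = [torch.randn(2, 77, 768), None, torch.randint(0, 1000, (2, 77))]

for t in text_encoder_conds:
    # skip missing outputs and non-float tensors before enabling grads
    if t is not None and t.dtype.is_floating_point:
        t.requires_grad_(True)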
@@ -190,6 +190,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
                     apply_t5_attn_mask=apply_t5_attn_mask_i,
                 )
             else:
+                # it's fine that attn mask is not None. it's overwritten before calling the model if necessary
                 info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
@@ -89,19 +89,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
         if apply_t5_attn_mask is None:
             apply_t5_attn_mask = self.apply_t5_attn_mask

-        l_tokens, g_tokens, t5_tokens = tokens[:3]
+        l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens

-        if len(tokens) > 3:
-            l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:]
-            if not apply_lg_attn_mask:
-                l_attn_mask = None
-                g_attn_mask = None
-            if not apply_t5_attn_mask:
-                t5_attn_mask = None
-        else:
-            l_attn_mask = None
-            g_attn_mask = None
-            t5_attn_mask = None
-
         # dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings
@@ -109,47 +97,114 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
             assert g_tokens is None, "g_tokens must be None if l_tokens is None"
             lg_out = None
             lg_pooled = None
+            l_attn_mask = None
+            g_attn_mask = None
         else:
             assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"

-            drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
-            if drop_l:
-                l_pooled = torch.zeros((l_tokens.shape[0], 768), device=clip_l.device, dtype=clip_l.dtype)
-                l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=clip_l.device, dtype=clip_l.dtype)
-                if l_attn_mask is not None:
-                    l_attn_mask = torch.zeros_like(l_attn_mask, device=clip_l.device)
-            else:
-                l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None
-                prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True)
-                l_pooled = prompt_embeds[0]
-                l_out = prompt_embeds.hidden_states[-2]
-
-            drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
-            if drop_g:
-                g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=clip_g.device, dtype=clip_g.dtype)
-                g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=clip_g.device, dtype=clip_g.dtype)
-                if g_attn_mask is not None:
-                    g_attn_mask = torch.zeros_like(g_attn_mask, device=clip_g.device)
-            else:
-                g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None
-                prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True)
-                g_pooled = prompt_embeds[0]
-                g_out = prompt_embeds.hidden_states[-2]
-
-            lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
+            # drop some members of the batch: we do not call clip_l and clip_g for dropped members
+            batch_size, l_seq_len = l_tokens.shape
+            g_seq_len = g_tokens.shape[1]
+
+            non_drop_l_indices = []
+            non_drop_g_indices = []
+            for i in range(l_tokens.shape[0]):
+                drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
+                drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
+                if not drop_l:
+                    non_drop_l_indices.append(i)
+                if not drop_g:
+                    non_drop_g_indices.append(i)
+
+            # filter out dropped members
+            if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size:
+                l_tokens = l_tokens[non_drop_l_indices]
+                l_attn_mask = l_attn_mask[non_drop_l_indices]
+            if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size:
+                g_tokens = g_tokens[non_drop_g_indices]
+                g_attn_mask = g_attn_mask[non_drop_g_indices]
+
+            # call clip_l for non-dropped members
+            if len(non_drop_l_indices) > 0:
+                nd_l_attn_mask = l_attn_mask.to(clip_l.device)
+                prompt_embeds = clip_l(
+                    l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
+                )
+                nd_l_pooled = prompt_embeds[0]
+                nd_l_out = prompt_embeds.hidden_states[-2]
+            if len(non_drop_g_indices) > 0:
+                nd_g_attn_mask = g_attn_mask.to(clip_g.device)
+                prompt_embeds = clip_g(
+                    g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
+                )
+                nd_g_pooled = prompt_embeds[0]
+                nd_g_out = prompt_embeds.hidden_states[-2]
+
+            # fill in the dropped members
+            if len(non_drop_l_indices) == batch_size:
+                l_pooled = nd_l_pooled
+                l_out = nd_l_out
+            else:
+                # model output is always float32 because of the models are wrapped with Accelerator
+                l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32)
+                l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32)
+                l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype)
+                if len(non_drop_l_indices) > 0:
+                    l_pooled[non_drop_l_indices] = nd_l_pooled
+                    l_out[non_drop_l_indices] = nd_l_out
+                    l_attn_mask[non_drop_l_indices] = nd_l_attn_mask
+
+            if len(non_drop_g_indices) == batch_size:
+                g_pooled = nd_g_pooled
+                g_out = nd_g_out
+            else:
+                g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32)
+                g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32)
+                g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype)
+                if len(non_drop_g_indices) > 0:
+                    g_pooled[non_drop_g_indices] = nd_g_pooled
+                    g_out[non_drop_g_indices] = nd_g_out
+                    g_attn_mask[non_drop_g_indices] = nd_g_attn_mask
+
+            lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1)
             lg_out = torch.cat([l_out, g_out], dim=-1)

         if t5xxl is None or t5_tokens is None:
             t5_out = None
+            t5_attn_mask = None
         else:
-            drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
-            if drop_t5:
-                t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5xxl.device, dtype=t5xxl.dtype)
-                if t5_attn_mask is not None:
-                    t5_attn_mask = torch.zeros_like(t5_attn_mask, device=t5xxl.device)
+            # drop some members of the batch: we do not call t5xxl for dropped members
+            batch_size, t5_seq_len = t5_tokens.shape
+            non_drop_t5_indices = []
+            for i in range(t5_tokens.shape[0]):
+                drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
+                if not drop_t5:
+                    non_drop_t5_indices.append(i)
+
+            # filter out dropped members
+            if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size:
+                t5_tokens = t5_tokens[non_drop_t5_indices]
+                t5_attn_mask = t5_attn_mask[non_drop_t5_indices]
+
+            # call t5xxl for non-dropped members
+            if len(non_drop_t5_indices) > 0:
+                nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device)
+                nd_t5_out, _ = t5xxl(
+                    t5_tokens.to(t5xxl.device),
+                    nd_t5_attn_mask if apply_t5_attn_mask else None,
+                    return_dict=False,
+                    output_hidden_states=True,
+                )
+
+            # fill in the dropped members
+            if len(non_drop_t5_indices) == batch_size:
+                t5_out = nd_t5_out
             else:
-                t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None
-                t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True)
+                t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32)
+                t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype)
+                if len(non_drop_t5_indices) > 0:
+                    t5_out[non_drop_t5_indices] = nd_t5_out
+                    t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask

         # masks are used for attention masking in transformer
         return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
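The rewritten block above replaces whole-batch dropout (one random draw zeroing every member's embedding) with per-member dropout: each sample in the batch draws its own dropout decision, the text encoder is only called for the kept samples, and dropped samples are filled with zero embeddings in float32. A condensed, self-contained sketch of that pattern follows; the encoder callable, dimensions, and dropout_rate are illustrative assumptions, not the repository's API:

import random
import torch

def encode_with_member_dropout(encoder, tokens, dropout_rate, hidden_dim=768, enable_dropout=True):
    batch_size, seq_len = tokens.shape
    # per-member dropout decision: collect the indices that are kept
    keep = [
        i for i in range(batch_size)
        if not (enable_dropout and dropout_rate > 0.0 and random.random() < dropout_rate)
    ]
    # the encoder is skipped entirely for dropped members
    kept_out = encoder(tokens[keep]) if keep else None  # (n_kept, seq_len, hidden_dim)
    if len(keep) == batch_size:
        return kept_out
    # dropped members get zero embeddings; kept results are scattered back by index
    out = torch.zeros((batch_size, seq_len, hidden_dim), dtype=torch.float32)
    if keep:
        out[keep] = kept_out.to(torch.float32)
    return out

# usage with a stand-in encoder
dummy_encoder = lambda x: torch.randn(x.shape[0], x.shape[1], 768)
emb = encode_with_member_dropout(dummy_encoder, torch.randint(0, 1000, (4, 77)), dropout_rate=0.5)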
@@ -322,6 +377,7 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
                     apply_t5_attn_mask=apply_t5_attn_mask,
                 )
             else:
+                # it's fine that attn mask is not None. it's overwritten before calling the model if necessary
                 info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
@@ -300,7 +300,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
         if args.gradient_checkpointing:
             noisy_model_input.requires_grad_(True)
             for t in text_encoder_conds:
-                if t.dtype.is_floating_point:
+                if t is not None and t.dtype.is_floating_point:
                     t.requires_grad_(True)

             # Predict the noise residual
@@ -415,13 +415,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
         prepare_fp8(text_encoder, weight_dtype)

     def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
-        # # drop cached text encoder outputs
-        # text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
-        # if text_encoder_outputs_list is not None:
-        #     text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
-        #     text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
-        #     batch["text_encoder_outputs_list"] = text_encoder_outputs_list
-        pass
+        # drop cached text encoder outputs
+        text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+        if text_encoder_outputs_list is not None:
+            text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
+            text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
+            batch["text_encoder_outputs_list"] = text_encoder_outputs_list


 def setup_parser() -> argparse.ArgumentParser:
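With the previously commented-out code restored, `on_step_start` now routes cached text encoder outputs through the strategy's `drop_cached_text_encoder_outputs` before each step, so dropout also applies when outputs come from the cache rather than being computed on the fly. As a rough illustration only (this is not the repository's implementation; the function name, rate, and shapes below are assumptions), such a hook could zero the cached embeddings of randomly selected batch members:

import random
import torch

def drop_cached_outputs(cached, dropout_rate):
    # zero the cached embedding of members selected for dropout
    for i in range(cached.shape[0]):
        if dropout_rate > 0.0 and random.random() < dropout_rate:
            cached[i] = 0.0
    return cached

lg_out = torch.randn(4, 77, 2048)  # stand-in for a cached CLIP-L/G embedding batch
lg_out = drop_cached_outputs(lg_out, dropout_rate=0.1)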
@@ -1151,16 +1151,6 @@ class NetworkTrainer:
                 text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
                 if text_encoder_outputs_list is not None:
                     text_encoder_conds = text_encoder_outputs_list  # List of text encoder outputs

-                # if text_encoder_outputs_list is not None:
-                #     lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list
-                #     for i in range(len(lg_out)):
-                #         print(
-                #             f"[{i}] cached L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, cached G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, "
-                #             f"cached T5: {t5_out[i].max()}, "
-                #             f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0},"
-                #             f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}"
-                #         )
-
                 if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
                     # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
@@ -1193,15 +1183,6 @@ class NetworkTrainer:
                     if encoded_text_encoder_conds[i] is not None:
                         text_encoder_conds[i] = encoded_text_encoder_conds[i]

-                # lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds
-                # for i in range(len(lg_out)):
-                #     print(
-                #         f"[{i}] train L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, train G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, "
-                #         f"train T5: {t5_out[i].max()}, "
-                #         f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0},"
-                #         f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}"
-                #     )
-
                 # sample noise, call unet, get target
                 noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
                     args,