From 1065dd1b56b4b18e211d3827fe22b459c81dd12c Mon Sep 17 00:00:00 2001
From: kohya-ss
Date: Sun, 27 Oct 2024 19:36:36 +0900
Subject: [PATCH] Fix dropout_rate to work for TEs

---
 flux_train_network.py    |   2 +-
 library/strategy_flux.py |   1 +
 library/strategy_sd3.py  | 142 +++++++++++++++++++++++++++------------
 sd3_train_network.py     |  15 ++---
 train_network.py         |  19 ------
 5 files changed, 108 insertions(+), 71 deletions(-)

diff --git a/flux_train_network.py b/flux_train_network.py
index cffeb3b1..2b71a897 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -363,7 +363,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         if args.gradient_checkpointing:
             noisy_model_input.requires_grad_(True)
             for t in text_encoder_conds:
-                if t.dtype.is_floating_point:
+                if t is not None and t.dtype.is_floating_point:
                     t.requires_grad_(True)
             img_ids.requires_grad_(True)
             guidance_vec.requires_grad_(True)
diff --git a/library/strategy_flux.py b/library/strategy_flux.py
index 0b0c34af..f662b62e 100644
--- a/library/strategy_flux.py
+++ b/library/strategy_flux.py
@@ -190,6 +190,7 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
                     apply_t5_attn_mask=apply_t5_attn_mask_i,
                 )
             else:
+                # it's fine that attn mask is not None. it's overwritten before calling the model if necessary
                 info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py
index e57bb337..413169ec 100644
--- a/library/strategy_sd3.py
+++ b/library/strategy_sd3.py
@@ -89,19 +89,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
         if apply_t5_attn_mask is None:
             apply_t5_attn_mask = self.apply_t5_attn_mask

-        l_tokens, g_tokens, t5_tokens = tokens[:3]
-
-        if len(tokens) > 3:
-            l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:]
-            if not apply_lg_attn_mask:
-                l_attn_mask = None
-                g_attn_mask = None
-            if not apply_t5_attn_mask:
-                t5_attn_mask = None
-        else:
-            l_attn_mask = None
-            g_attn_mask = None
-            t5_attn_mask = None
+        l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens

         # dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings
@@ -109,47 +97,114 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
             assert g_tokens is None, "g_tokens must be None if l_tokens is None"
             lg_out = None
             lg_pooled = None
+            l_attn_mask = None
+            g_attn_mask = None
         else:
             assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
-            drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
-            if drop_l:
-                l_pooled = torch.zeros((l_tokens.shape[0], 768), device=clip_l.device, dtype=clip_l.dtype)
-                l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=clip_l.device, dtype=clip_l.dtype)
-                if l_attn_mask is not None:
-                    l_attn_mask = torch.zeros_like(l_attn_mask, device=clip_l.device)
-            else:
-                l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None
-                prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True)
-                l_pooled = prompt_embeds[0]
-                l_out = prompt_embeds.hidden_states[-2]
+            # drop some members of the batch: we do not call clip_l and clip_g for dropped members
+            batch_size, l_seq_len = l_tokens.shape
+            g_seq_len = g_tokens.shape[1]

-            drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
-            if drop_g:
-                g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=clip_g.device, dtype=clip_g.dtype)
-                g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=clip_g.device, dtype=clip_g.dtype)
-                if g_attn_mask is not None:
-                    g_attn_mask = torch.zeros_like(g_attn_mask, device=clip_g.device)
-            else:
-                g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None
-                prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True)
-                g_pooled = prompt_embeds[0]
-                g_out = prompt_embeds.hidden_states[-2]
+            non_drop_l_indices = []
+            non_drop_g_indices = []
+            for i in range(l_tokens.shape[0]):
+                drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
+                drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
+                if not drop_l:
+                    non_drop_l_indices.append(i)
+                if not drop_g:
+                    non_drop_g_indices.append(i)

-            lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
+            # filter out dropped members
+            if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size:
+                l_tokens = l_tokens[non_drop_l_indices]
+                l_attn_mask = l_attn_mask[non_drop_l_indices]
+            if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size:
+                g_tokens = g_tokens[non_drop_g_indices]
+                g_attn_mask = g_attn_mask[non_drop_g_indices]
+
+            # call clip_l for non-dropped members
+            if len(non_drop_l_indices) > 0:
+                nd_l_attn_mask = l_attn_mask.to(clip_l.device)
+                prompt_embeds = clip_l(
+                    l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
+                )
+                nd_l_pooled = prompt_embeds[0]
+                nd_l_out = prompt_embeds.hidden_states[-2]
+            if len(non_drop_g_indices) > 0:
+                nd_g_attn_mask = g_attn_mask.to(clip_g.device)
+                prompt_embeds = clip_g(
+                    g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
+                )
+                nd_g_pooled = prompt_embeds[0]
+                nd_g_out = prompt_embeds.hidden_states[-2]
+
+            # fill in the dropped members
+            if len(non_drop_l_indices) == batch_size:
+                l_pooled = nd_l_pooled
+                l_out = nd_l_out
+            else:
+                # model output is always float32 because the models are wrapped with Accelerator
+                l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32)
+                l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32)
+                l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype)
+                if len(non_drop_l_indices) > 0:
+                    l_pooled[non_drop_l_indices] = nd_l_pooled
+                    l_out[non_drop_l_indices] = nd_l_out
+                    l_attn_mask[non_drop_l_indices] = nd_l_attn_mask
+
+            if len(non_drop_g_indices) == batch_size:
+                g_pooled = nd_g_pooled
+                g_out = nd_g_out
+            else:
+                g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32)
+                g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32)
+                g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype)
+                if len(non_drop_g_indices) > 0:
+                    g_pooled[non_drop_g_indices] = nd_g_pooled
+                    g_out[non_drop_g_indices] = nd_g_out
+                    g_attn_mask[non_drop_g_indices] = nd_g_attn_mask
+
+            lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1)
             lg_out = torch.cat([l_out, g_out], dim=-1)

         if t5xxl is None or t5_tokens is None:
             t5_out = None
+            t5_attn_mask = None
         else:
-            drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
-            if drop_t5:
-                t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5xxl.device, dtype=t5xxl.dtype)
-                if t5_attn_mask is not None:
-                    t5_attn_mask = torch.zeros_like(t5_attn_mask, device=t5xxl.device)
+            # drop some members of the batch: we do not call t5xxl for dropped members
+            batch_size, t5_seq_len = t5_tokens.shape
+            non_drop_t5_indices = []
+            for i in range(t5_tokens.shape[0]):
+                drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
+                if not drop_t5:
+                    non_drop_t5_indices.append(i)
+
+            # filter out dropped members
+            if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size:
+                t5_tokens = t5_tokens[non_drop_t5_indices]
+                t5_attn_mask = t5_attn_mask[non_drop_t5_indices]
+
+            # call t5xxl for non-dropped members
+            if len(non_drop_t5_indices) > 0:
+                nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device)
+                nd_t5_out, _ = t5xxl(
+                    t5_tokens.to(t5xxl.device),
+                    nd_t5_attn_mask if apply_t5_attn_mask else None,
+                    return_dict=False,
+                    output_hidden_states=True,
+                )
+
+            # fill in the dropped members
+            if len(non_drop_t5_indices) == batch_size:
+                t5_out = nd_t5_out
             else:
-                t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None
-                t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True)
+                t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32)
+                t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype)
+                if len(non_drop_t5_indices) > 0:
+                    t5_out[non_drop_t5_indices] = nd_t5_out
+                    t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask

         # masks are used for attention masking in transformer
         return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
@@ -322,6 +377,7 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
                     apply_t5_attn_mask=apply_t5_attn_mask,
                 )
             else:
+                # it's fine that attn mask is not None. it's overwritten before calling the model if necessary
                 info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
diff --git a/sd3_train_network.py b/sd3_train_network.py
index 620a336f..3506404a 100644
--- a/sd3_train_network.py
+++ b/sd3_train_network.py
@@ -300,7 +300,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
         if args.gradient_checkpointing:
             noisy_model_input.requires_grad_(True)
             for t in text_encoder_conds:
-                if t.dtype.is_floating_point:
+                if t is not None and t.dtype.is_floating_point:
                     t.requires_grad_(True)

         # Predict the noise residual
@@ -415,13 +415,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
         prepare_fp8(text_encoder, weight_dtype)

     def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
-        # # drop cached text encoder outputs
-        # text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
-        # if text_encoder_outputs_list is not None:
-        #     text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
-        #     text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
-        #     batch["text_encoder_outputs_list"] = text_encoder_outputs_list
-        pass
+        # drop cached text encoder outputs
+        text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+        if text_encoder_outputs_list is not None:
+            text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
+            text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
+            batch["text_encoder_outputs_list"] = text_encoder_outputs_list


 def setup_parser() -> argparse.ArgumentParser:
diff --git a/train_network.py b/train_network.py
index 76936b2e..b90aa420 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1151,16 +1151,6 @@ class NetworkTrainer:
                 text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
                 if text_encoder_outputs_list is not None:
                     text_encoder_conds = text_encoder_outputs_list  # List of text encoder outputs
-
-                # if text_encoder_outputs_list is not None:
-                #     lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list
-                #     for i in range(len(lg_out)):
-                #         print(
-                #             f"[{i}] cached L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, cached G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, "
-                #             f"cached T5: {t5_out[i].max()}, "
-                #             f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0},"
-                #             f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}"
-                #         )

                 if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
                     # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
@@ -1193,15 +1183,6 @@ class NetworkTrainer:
                             if encoded_text_encoder_conds[i] is not None:
                                 text_encoder_conds[i] = encoded_text_encoder_conds[i]

-                    # lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds
-                    # for i in range(len(lg_out)):
-                    #     print(
-                    #         f"[{i}] train L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, train G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, "
-                    #         f"train T5: {t5_out[i].max()}, "
-                    #         f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0},"
-                    #         f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}"
-                    #     )
-
                 # sample noise, call unet, get target
                 noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
                     args,
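
The following is a minimal, self-contained sketch of the per-batch-member dropout technique that the new Sd3TextEncodingStrategy.encode_tokens above uses for CLIP-L, CLIP-G and T5-XXL: dropout is decided separately for each member of the batch, the encoder is only called on the members that are kept, and the dropped members are filled in afterwards with zero embeddings and zeroed attention masks. The function name encode_with_member_dropout, the generic encoder callable and the fixed hidden size are illustrative stand-ins, not identifiers from this repository.

import random

import torch


def encode_with_member_dropout(encoder, tokens, attn_mask, dropout_rate, hidden_dim=768):
    # Decide per batch member whether its embedding is dropped (zeroed out).
    batch_size, seq_len = tokens.shape
    kept = [i for i in range(batch_size) if not (dropout_rate > 0.0 and random.random() < dropout_rate)]

    if len(kept) == batch_size:
        # Nothing dropped: encode the whole batch as usual.
        return encoder(tokens, attn_mask), attn_mask

    # Allocate zeroed outputs for the full batch; dropped members stay all-zero,
    # and their attention masks are zeroed as well.
    out = torch.zeros((batch_size, seq_len, hidden_dim), dtype=torch.float32)
    new_mask = torch.zeros_like(attn_mask)

    if len(kept) > 0:
        # Run the encoder only on the kept members and scatter the results back.
        out[kept] = encoder(tokens[kept], attn_mask[kept])
        new_mask[kept] = attn_mask[kept]

    return out, new_mask


# toy usage with a dummy encoder standing in for CLIP/T5
if __name__ == "__main__":
    def dummy_encoder(tok, mask):
        return torch.randn(tok.shape[0], tok.shape[1], 768)

    tokens = torch.randint(0, 1000, (4, 77))
    mask = torch.ones(4, 77, dtype=torch.long)
    out, mask = encode_with_member_dropout(dummy_encoder, tokens, mask, dropout_rate=0.5)
    print(out.shape, mask.sum(dim=1))  # dropped rows have an all-zero mask

Zero-filling the dropped rows keeps the shape and dtype of the returned tensors identical to the no-dropout case, which mirrors why the cached-output format and the rest of the training loop in the patch can stay unchanged.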