diff --git a/library/strategy_sd3.py b/library/strategy_sd3.py
index d87ad7d1..e57bb337 100644
--- a/library/strategy_sd3.py
+++ b/library/strategy_sd3.py
@@ -111,13 +111,13 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
             lg_pooled = None
         else:
             assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
-
+
             drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
             if drop_l:
-                l_pooled = torch.zeros((l_tokens.shape[0], 768), device=l_tokens.device, dtype=l_tokens.dtype)
-                l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=l_tokens.device, dtype=l_tokens.dtype)
+                l_pooled = torch.zeros((l_tokens.shape[0], 768), device=clip_l.device, dtype=clip_l.dtype)
+                l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=clip_l.device, dtype=clip_l.dtype)
                 if l_attn_mask is not None:
-                    l_attn_mask = torch.zeros_like(l_attn_mask)
+                    l_attn_mask = torch.zeros_like(l_attn_mask, device=clip_l.device)
             else:
                 l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None
                 prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True)
@@ -126,10 +126,10 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
 
             drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
             if drop_g:
-                g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=g_tokens.device, dtype=g_tokens.dtype)
-                g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=g_tokens.device, dtype=g_tokens.dtype)
+                g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=clip_g.device, dtype=clip_g.dtype)
+                g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=clip_g.device, dtype=clip_g.dtype)
                 if g_attn_mask is not None:
-                    g_attn_mask = torch.zeros_like(g_attn_mask)
+                    g_attn_mask = torch.zeros_like(g_attn_mask, device=clip_g.device)
             else:
                 g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None
                 prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True)
@@ -144,9 +144,9 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
         else:
             drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
             if drop_t5:
-                t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5_tokens.device, dtype=t5_tokens.dtype)
+                t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5xxl.device, dtype=t5xxl.dtype)
                 if t5_attn_mask is not None:
-                    t5_attn_mask = torch.zeros_like(t5_attn_mask)
+                    t5_attn_mask = torch.zeros_like(t5_attn_mask, device=t5xxl.device)
             else:
                 t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None
                 t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True)
@@ -187,7 +187,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
                 if t5_attn_mask is not None:
                     t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
 
-        return lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask
+        return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
 
     def concat_encodings(
         self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
diff --git a/sd3_train_network.py b/sd3_train_network.py
index 7b547127..620a336f 100644
--- a/sd3_train_network.py
+++ b/sd3_train_network.py
@@ -125,7 +125,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
             args.apply_t5_attn_mask,
             args.clip_l_dropout_rate,
             args.clip_g_dropout_rate,
-            args.t5xxl_dropout_rate,
+            args.t5_dropout_rate,
         )
 
     def post_process_network(self, args, accelerator, network, text_encoders, unet):
@@ -415,12 +415,13 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
             prepare_fp8(text_encoder, weight_dtype)
 
     def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
-        # drop cached text encoder outputs
-        text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
-        if text_encoder_outputs_list is not None:
-            text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
-            text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(text_encoder_outputs_list)
-            batch["text_encoder_outputs_list"] = text_encoder_outputs_list
+        # # drop cached text encoder outputs
+        # text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+        # if text_encoder_outputs_list is not None:
+        #     text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
+        #     text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
+        #     batch["text_encoder_outputs_list"] = text_encoder_outputs_list
+        pass
 
 
 def setup_parser() -> argparse.ArgumentParser:
diff --git a/train_network.py b/train_network.py
index 9d78a4ef..76936b2e 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1151,6 +1151,17 @@ class NetworkTrainer:
                     text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
                     if text_encoder_outputs_list is not None:
                         text_encoder_conds = text_encoder_outputs_list  # List of text encoder outputs
+
+                    # if text_encoder_outputs_list is not None:
+                    #     lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list
+                    #     for i in range(len(lg_out)):
+                    #         print(
+                    #             f"[{i}] cached L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, cached G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, "
+                    #             f"cached T5: {t5_out[i].max()}, "
+                    #             f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0},"
+                    #             f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}"
+                    #         )
+
                     if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
                         # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
                         with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
@@ -1182,6 +1193,15 @@ class NetworkTrainer:
                             if encoded_text_encoder_conds[i] is not None:
                                 text_encoder_conds[i] = encoded_text_encoder_conds[i]
 
+                    # lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds
+                    # for i in range(len(lg_out)):
+                    #     print(
+                    #         f"[{i}] train L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, train G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, "
+                    #         f"train T5: {t5_out[i].max()}, "
+                    #         f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0},"
+                    #         f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}"
+                    #     )
+
                     # sample noise, call unet, get target
                     noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
                         args,