Fix SD3 LoRA training to work (WIP)

Author: kohya-ss
Date: 2024-10-27 17:03:36 +09:00
Parent: db2b4d41b9
Commit: a1255d637f
3 changed files with 38 additions and 17 deletions

library/strategy_sd3.py

@@ -114,10 +114,10 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
             drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
             if drop_l:
-                l_pooled = torch.zeros((l_tokens.shape[0], 768), device=l_tokens.device, dtype=l_tokens.dtype)
-                l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=l_tokens.device, dtype=l_tokens.dtype)
+                l_pooled = torch.zeros((l_tokens.shape[0], 768), device=clip_l.device, dtype=clip_l.dtype)
+                l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=clip_l.device, dtype=clip_l.dtype)
                 if l_attn_mask is not None:
-                    l_attn_mask = torch.zeros_like(l_attn_mask)
+                    l_attn_mask = torch.zeros_like(l_attn_mask, device=clip_l.device)
             else:
                 l_attn_mask = l_attn_mask.to(clip_l.device) if l_attn_mask is not None else None
                 prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True)

@@ -126,10 +126,10 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
             drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
             if drop_g:
-                g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=g_tokens.device, dtype=g_tokens.dtype)
-                g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=g_tokens.device, dtype=g_tokens.dtype)
+                g_pooled = torch.zeros((g_tokens.shape[0], 1280), device=clip_g.device, dtype=clip_g.dtype)
+                g_out = torch.zeros((g_tokens.shape[0], g_tokens.shape[1], 1280), device=clip_g.device, dtype=clip_g.dtype)
                 if g_attn_mask is not None:
-                    g_attn_mask = torch.zeros_like(g_attn_mask)
+                    g_attn_mask = torch.zeros_like(g_attn_mask, device=clip_g.device)
             else:
                 g_attn_mask = g_attn_mask.to(clip_g.device) if g_attn_mask is not None else None
                 prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True)

@@ -144,9 +144,9 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
         else:
             drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
             if drop_t5:
-                t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5_tokens.device, dtype=t5_tokens.dtype)
+                t5_out = torch.zeros((t5_tokens.shape[0], t5_tokens.shape[1], 4096), device=t5xxl.device, dtype=t5xxl.dtype)
                 if t5_attn_mask is not None:
-                    t5_attn_mask = torch.zeros_like(t5_attn_mask)
+                    t5_attn_mask = torch.zeros_like(t5_attn_mask, device=t5xxl.device)
             else:
                 t5_attn_mask = t5_attn_mask.to(t5xxl.device) if t5_attn_mask is not None else None
                 t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True)

@@ -187,7 +187,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
             if t5_attn_mask is not None:
                 t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
-        return lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask
+        return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
     def concat_encodings(
         self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor

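The three dropout hunks above all make the same fix: when caption dropout zeroes out an encoder's output, the zero tensors (and zeroed attention masks) are now created with the encoder's device and dtype instead of the token tensor's. Token ids are integer tensors that usually still live on the CPU, so zeros derived from them would not match the device or dtype of the real encoder outputs further down the pipeline. Below is a minimal sketch of the corrected drop branch, assuming nothing beyond the names in the hunks; the helper and the Linear stand-in are illustrative, not repository code.

import torch

def dropped_clip_l_outputs(l_tokens: torch.Tensor, clip_l: torch.nn.Module):
    # Use one of the encoder's parameters as a stand-in for clip_l.device / clip_l.dtype,
    # so the zeroed outputs match what the encoder would actually have produced.
    ref = next(clip_l.parameters())
    l_pooled = torch.zeros((l_tokens.shape[0], 768), device=ref.device, dtype=ref.dtype)
    l_out = torch.zeros((l_tokens.shape[0], l_tokens.shape[1], 768), device=ref.device, dtype=ref.dtype)
    return l_out, l_pooled

tokens = torch.zeros((2, 77), dtype=torch.long)   # int64 token ids, typically on the CPU
encoder = torch.nn.Linear(77, 768)                # stand-in for CLIP-L
l_out, l_pooled = dropped_clip_l_outputs(tokens, encoder)
print(l_out.dtype, l_out.shape, l_pooled.shape)   # torch.float32 torch.Size([2, 77, 768]) torch.Size([2, 768])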
sd3_train_network.py

@@ -125,7 +125,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
             args.apply_t5_attn_mask,
             args.clip_l_dropout_rate,
             args.clip_g_dropout_rate,
-            args.t5xxl_dropout_rate,
+            args.t5_dropout_rate,
         )
     def post_process_network(self, args, accelerator, network, text_encoders, unet):

@@ -415,12 +415,13 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
             prepare_fp8(text_encoder, weight_dtype)
     def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
-        # drop cached text encoder outputs
-        text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
-        if text_encoder_outputs_list is not None:
-            text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
-            text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(text_encoder_outputs_list)
-            batch["text_encoder_outputs_list"] = text_encoder_outputs_list
+        # # drop cached text encoder outputs
+        # text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+        # if text_encoder_outputs_list is not None:
+        #     text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
+        #     text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
+        #     batch["text_encoder_outputs_list"] = text_encoder_outputs_list
+        pass
 def setup_parser() -> argparse.ArgumentParser:

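Two things change in this file: the dropout rate passed to the text encoding strategy now comes from args.t5_dropout_rate rather than args.t5xxl_dropout_rate, and the on_step_start hook that dropped cached text encoder outputs at step time is commented out, leaving the hook as a no-op (pass). Note that the commented-out call now unpacks the outputs, drop_cached_text_encoder_outputs(*text_encoder_outputs_list), which lines up with the strategy returning a list in the strategy_sd3.py hunk above. As a rough, hypothetical illustration of what per-sample dropping of cached outputs looks like (this simplified function is not the strategy's actual method):

import random
import torch

def drop_cached_t5_outputs(t5_out: torch.Tensor, t5_attn_mask, dropout_rate: float):
    # Zero one sample's cached embedding and attention mask whenever caption
    # dropout fires for that sample, mirroring the zeroing in the strategy_sd3.py hunks.
    for i in range(t5_out.shape[0]):
        if dropout_rate > 0.0 and random.random() < dropout_rate:
            t5_out[i] = torch.zeros_like(t5_out[i])
            if t5_attn_mask is not None:
                t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
    return t5_out, t5_attn_mask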
train_network.py

@@ -1151,6 +1151,17 @@ class NetworkTrainer:
                 text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
                 if text_encoder_outputs_list is not None:
                     text_encoder_conds = text_encoder_outputs_list  # List of text encoder outputs
+                    # if text_encoder_outputs_list is not None:
+                    #     lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_outputs_list
+                    #     for i in range(len(lg_out)):
+                    #         print(
+                    #             f"[{i}] cached L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, cached G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, "
+                    #             f"cached T5: {t5_out[i].max()}, "
+                    #             f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0},"
+                    #             f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}"
+                    #         )
                 if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
                     # TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
                     with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():

@@ -1182,6 +1193,15 @@ class NetworkTrainer:
                         if encoded_text_encoder_conds[i] is not None:
                             text_encoder_conds[i] = encoded_text_encoder_conds[i]
+                    # lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds
+                    # for i in range(len(lg_out)):
+                    #     print(
+                    #         f"[{i}] train L: {lg_out[i,:,:768].max()}, {lg_pooled[i][:768].max()}, train G: {lg_out[i,:,768:].max()}, {lg_pooled[i][768:].max()}, "
+                    #         f"train T5: {t5_out[i].max()}, "
+                    #         f"attn mask: {l_attn_mask[i].max() if l_attn_mask is not None else 0},"
+                    #         f" {g_attn_mask[i].max() if g_attn_mask is not None else 0}, {t5_attn_mask[i].max() if t5_attn_mask is not None else 0}"
+                    #     )
                 # sample noise, call unet, get target
                 noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(
                     args,
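Both hunks in this file only add commented-out debug prints around the existing cache/re-encode path: the first inspects the cached L/G/T5 outputs taken from batch["text_encoder_outputs_list"], the second inspects the conditions after freshly encoded outputs have overwritten the cached ones. The overwrite itself is the small loop shown in the second hunk; here is a toy example of that merge step, with placeholder strings standing in for the actual tensors:

cached_conds = ["lg_out_cached", "t5_out_cached", "lg_pooled_cached"]
encoded_conds = [None, "t5_out_new", None]  # e.g. only T5 was re-encoded this step
text_encoder_conds = list(cached_conds)
for i in range(len(encoded_conds)):
    if encoded_conds[i] is not None:
        text_encoder_conds[i] = encoded_conds[i]
print(text_encoder_conds)  # ['lg_out_cached', 't5_out_new', 'lg_pooled_cached']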