From 7367584e6749448cb9b012df0d3bcbe4f0531ea5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 17 Aug 2024 14:38:34 +0900 Subject: [PATCH] fix sd3 training to work without cachine TE outputs #1465 --- sd3_train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sd3_train.py b/sd3_train.py index 9c37cbce..3b6c8a11 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -759,8 +759,9 @@ def train(args): input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"] with torch.set_grad_enabled(args.train_text_encoder): # TODO support weighted captions - input_ids_clip_l = input_ids_clip_l.to(accelerator.device) - input_ids_clip_g = input_ids_clip_g.to(accelerator.device) + # text models in sd3_models require "cpu" for input_ids + input_ids_clip_l = input_ids_clip_l.to("cpu") + input_ids_clip_g = input_ids_clip_g.to("cpu") lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [clip_l, clip_g, None], @@ -770,7 +771,7 @@ def train(args): if t5_out is None: _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"] with torch.no_grad(): - input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None + input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None _, t5_out, _ = text_encoding_strategy.encode_tokens( sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask] )