use CLIPTextModelWithProjection

Kohya S
2023-06-27 20:48:06 +09:00
parent 753c63e11b
commit a751dc25d6
6 changed files with 90 additions and 89 deletions
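
Note: this commit switches text encoder 2 to transformers' CLIPTextModelWithProjection, which applies the CLIP text projection inside the model itself, so the manually loaded projection weight and the pool2 @ text_projection matmul removed below become unnecessary. A minimal sketch of the relevant transformers behavior (illustrative, not this repo's code; the model id is only an example):

    import torch
    from transformers import CLIPTokenizer, CLIPTextModelWithProjection

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

    tokens = tokenizer(["a photo of a cat"], return_tensors="pt")
    with torch.no_grad():
        out = text_encoder(**tokens, output_hidden_states=True)

    pooled = out.text_embeds        # already projected: (batch, projection_dim)
    hidden = out.hidden_states[-2]  # penultimate hidden states; no extra matmul needed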


@@ -118,11 +118,9 @@ def train(args):
         text_encoder2,
         vae,
         unet,
-        text_projection,
         logit_scale,
         ckpt_info,
     ) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
-    text_projection = text_projection.to(accelerator.device, dtype=weight_dtype)
     logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
 
     # verify load/save model formats
@@ -379,7 +377,6 @@ def train(args):
                     text_encoder2,
                     None if not args.full_fp16 else weight_dtype,
                 )
-                pool2 = pool2 @ text_projection.to(pool2.dtype)
             else:
                 encoder_hidden_states1 = []
                 encoder_hidden_states2 = []
@@ -395,8 +392,6 @@ def train(args):
                 encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
                 pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
-
-                pool2 = pool2 @ text_projection.to(pool2.dtype)
 
             # get size embeddings
             orig_size = batch["original_sizes_hw"]
             crop_size = batch["crop_top_lefts"]
@@ -492,7 +487,6 @@ def train(args):
                         accelerator.unwrap_model(text_encoder2),
                         accelerator.unwrap_model(unet),
                         vae,
-                        text_projection,
                         logit_scale,
                         ckpt_info,
                     )
@@ -541,7 +535,6 @@ def train(args):
                         accelerator.unwrap_model(text_encoder2),
                         accelerator.unwrap_model(unet),
                         vae,
-                        text_projection,
                         logit_scale,
                         ckpt_info,
                     )
@@ -575,7 +568,6 @@ def train(args):
             text_encoder2,
             unet,
             vae,
-            text_projection,
             logit_scale,
             ckpt_info,
         )
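
Since the projection now lives inside text_encoder2 as the model's own text_projection linear layer, its weight travels with the encoder's state dict, which is why the separate text_projection argument is dropped from load_target_model and from every save call above. A quick illustrative check (assumed variable name, standard transformers layout):

    proj = text_encoder.text_projection  # nn.Linear(hidden_size, projection_dim, bias=False)
    print(proj.weight.shape)             # torch.Size([projection_dim, hidden_size])
    assert "text_projection.weight" in text_encoder.state_dict()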