use CLIPTextModelWithProjection

This commit is contained in:
Kohya S
2023-06-27 20:48:06 +09:00
parent 753c63e11b
commit a751dc25d6
6 changed files with 90 additions and 89 deletions

View File

@@ -30,13 +30,11 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
text_encoder2,
vae,
unet,
text_projection,
logit_scale,
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)
self.load_stable_diffusion_format = load_stable_diffusion_format
self.text_projection = text_projection.to(accelerator.device, dtype=weight_dtype)
self.logit_scale = logit_scale
self.ckpt_info = ckpt_info
@@ -116,7 +114,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
)
pool2 = pool2 @ self.text_projection.to(pool2.dtype)
else:
encoder_hidden_states1 = []
encoder_hidden_states2 = []
@@ -132,8 +129,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
pool2 = pool2 @ self.text_projection.to(weight_dtype)
return encoder_hidden_states1, encoder_hidden_states2, pool2
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):