mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
use CLIPTextModelWithProjection
This commit is contained in:
@@ -30,13 +30,11 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
text_encoder2,
|
||||
vae,
|
||||
unet,
|
||||
text_projection,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)
|
||||
|
||||
self.load_stable_diffusion_format = load_stable_diffusion_format
|
||||
self.text_projection = text_projection.to(accelerator.device, dtype=weight_dtype)
|
||||
self.logit_scale = logit_scale
|
||||
self.ckpt_info = ckpt_info
|
||||
|
||||
@@ -116,7 +114,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
text_encoders[1],
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
)
|
||||
pool2 = pool2 @ self.text_projection.to(pool2.dtype)
|
||||
else:
|
||||
encoder_hidden_states1 = []
|
||||
encoder_hidden_states2 = []
|
||||
@@ -132,8 +129,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
encoder_hidden_states2 = torch.stack(encoder_hidden_states2).to(accelerator.device).to(weight_dtype)
|
||||
pool2 = torch.stack(pool2).to(accelerator.device).to(weight_dtype)
|
||||
|
||||
pool2 = pool2 @ self.text_projection.to(weight_dtype)
|
||||
|
||||
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
||||
|
||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
||||
|
||||
Reference in New Issue
Block a user