support sai model spec

This commit is contained in:
Kohya S
2023-08-06 21:50:05 +09:00
parent cd54af019a
commit c142dadb46
15 changed files with 746 additions and 64 deletions

View File

@@ -8,6 +8,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
@@ -134,7 +135,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# print("text encoder outputs verified")
return encoder_hidden_states1, encoder_hidden_states2, pool2
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):