use original unet for HF models, don't download TE

This commit is contained in:
Kohya S
2023-06-14 22:26:05 +09:00
parent 44404fcd6d
commit 449ad7502c
3 changed files with 52 additions and 11 deletions

View File

@@ -99,12 +99,6 @@ from library.original_unet import FlashAttentionFunction
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
DEFAULT_TOKEN_LENGTH = 75
# scheduler:
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
@@ -2066,6 +2060,17 @@ def main(args):
tokenizer = loading_pipe.tokenizer
del loading_pipe
# Diffusers U-Net to original U-Net
original_unet = UNet2DConditionModel(
unet.config.sample_size,
unet.config.attention_head_dim,
unet.config.cross_attention_dim,
unet.config.use_linear_projection,
unet.config.upcast_attention,
)
original_unet.load_state_dict(unet.state_dict())
unet = original_unet
# VAEを読み込む
if args.vae is not None:
vae = model_util.load_vae(args.vae, dtype)