mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
use original unet for HF models, don't download TE
This commit is contained in:
@@ -99,12 +99,6 @@ from library.original_unet import FlashAttentionFunction
|
||||
|
||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||
|
||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
|
||||
|
||||
DEFAULT_TOKEN_LENGTH = 75
|
||||
|
||||
# scheduler:
|
||||
SCHEDULER_LINEAR_START = 0.00085
|
||||
SCHEDULER_LINEAR_END = 0.0120
|
||||
@@ -2066,6 +2060,17 @@ def main(args):
|
||||
tokenizer = loading_pipe.tokenizer
|
||||
del loading_pipe
|
||||
|
||||
# Diffusers U-Net to original U-Net
|
||||
original_unet = UNet2DConditionModel(
|
||||
unet.config.sample_size,
|
||||
unet.config.attention_head_dim,
|
||||
unet.config.cross_attention_dim,
|
||||
unet.config.use_linear_projection,
|
||||
unet.config.upcast_attention,
|
||||
)
|
||||
original_unet.load_state_dict(unet.state_dict())
|
||||
unet = original_unet
|
||||
|
||||
# VAEを読み込む
|
||||
if args.vae is not None:
|
||||
vae = model_util.load_vae(args.vae, dtype)
|
||||
|
||||
Reference in New Issue
Block a user