Fix device issue in load_file, reduce VRAM usage

This commit is contained in:
Kohya S
2023-03-31 09:05:51 +09:00
parent ea1cf4acee
commit 8cecc676cf
3 changed files with 21 additions and 11 deletions

View File

@@ -841,7 +841,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device):
if is_safetensors(ckpt_path):
checkpoint = None
state_dict = load_file(ckpt_path, device)
state_dict = load_file(ckpt_path) # , device) # may causes error
else:
checkpoint = torch.load(ckpt_path, map_location=device)
if "state_dict" in checkpoint: