mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 01:12:41 +00:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
54500b861d | ||
|
|
f2491ee0ac | ||
|
|
1f169ee7fb | ||
|
|
66817992c1 | ||
|
|
8052bcd5cd | ||
|
|
55886a0116 | ||
|
|
33e90cc6a0 | ||
|
|
a0e05fa291 | ||
|
|
e33c007cd0 | ||
|
|
9d678a6f41 |
@@ -219,8 +219,8 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|||||||
for key, value in tqdm(lora_sd.items()):
|
for key, value in tqdm(lora_sd.items()):
|
||||||
weight_name = None
|
weight_name = None
|
||||||
if 'lora_down' in key:
|
if 'lora_down' in key:
|
||||||
block_down_name = key.split(".")[0]
|
block_down_name = key.rsplit('.lora_down', 1)[0]
|
||||||
weight_name = key.split(".")[-1]
|
weight_name = key.rsplit(".", 1)[-1]
|
||||||
lora_down_weight = value
|
lora_down_weight = value
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
@@ -283,7 +283,10 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
|||||||
|
|
||||||
|
|
||||||
def resize(args):
|
def resize(args):
|
||||||
|
if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
|
||||||
|
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
|
||||||
|
|
||||||
|
|
||||||
def str_to_dtype(p):
|
def str_to_dtype(p):
|
||||||
if p == 'float':
|
if p == 'float':
|
||||||
return torch.float
|
return torch.float
|
||||||
|
|||||||
Reference in New Issue
Block a user