mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Compare commits
6 Commits
c3e3f17cc0
...
f81078c682
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f81078c682 | ||
|
|
145fed65ee | ||
|
|
a21b6a917e | ||
|
|
4625b34f4e | ||
|
|
3b25de1f17 | ||
|
|
f0b07c52ab |
3
.github/FUNDING.yml
vendored
Normal file
3
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: kohya-ss
|
||||
@@ -957,8 +957,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
|
||||
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
|
||||
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
||||
if len(img_ar_errors) == 0:
|
||||
mean_img_ar_error = 0 # avoid NaN
|
||||
else:
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
||||
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
||||
logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")
|
||||
|
||||
|
||||
@@ -240,7 +240,7 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
for key, value in tqdm(lora_sd.items()):
|
||||
weight_name = None
|
||||
if LORAFMT[0] in key:
|
||||
block_down_name = key.rsplit(f".LORAFMT[0]", 1)[0]
|
||||
block_down_name = key.rsplit(f".{LORAFMT[0]}", 1)[0]
|
||||
weight_name = key.rsplit(".", 1)[-1]
|
||||
lora_down_weight = value
|
||||
else:
|
||||
@@ -248,7 +248,7 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
|
||||
# find corresponding lora_up and alpha
|
||||
block_up_name = block_down_name
|
||||
lora_up_weight = lora_sd.get(block_up_name + f".LORAFMT[1]." + weight_name, None)
|
||||
lora_up_weight = lora_sd.get(block_up_name + f".{LORAFMT[1]}." + weight_name, None)
|
||||
lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
|
||||
|
||||
weights_loaded = lora_down_weight is not None and lora_up_weight is not None
|
||||
@@ -286,8 +286,8 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
verbose_str += "\n"
|
||||
|
||||
new_alpha = param_dict["new_alpha"]
|
||||
o_lora_sd[block_down_name + f".LORAFMT[0].weight"] = param_dict[LORAFMT[0]].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + f".LORAFMT[1].weight"] = param_dict[LORAFMT[1]].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_down_name + f".{LORAFMT[0]}.weight"] = param_dict[LORAFMT[0]].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + f".{LORAFMT[1]}.weight"] = param_dict[LORAFMT[1]].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
|
||||
|
||||
block_down_name = None
|
||||
|
||||
Reference in New Issue
Block a user