Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-18 01:30:02 +00:00
search block-wise application weights
@@ -511,7 +511,9 @@ def get_block_dims_and_alphas(
             len(block_dims) == num_total_blocks
         ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
     else:
-        print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
+        print(
+            f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
+        )
         block_dims = [network_dim] * num_total_blocks
 
     if block_alphas is not None:
@@ -1223,3 +1225,40 @@ class LoRANetwork(torch.nn.Module):
             norms.append(scalednorm.item())
 
         return keys_scaled, sum(norms) / len(norms), max(norms)
+
+    # region application weight
+
+    def get_number_of_blocks(self):
+        # only for SDXL
+        return 20
+
+    def has_text_encoder_block(self):
+        return self.text_encoder_loras is not None and len(self.text_encoder_loras) > 0
+
+    def set_block_wise_weights(self, weights):
+        if self.text_encoder_loras:
+            for lora in self.text_encoder_loras:
+                lora.multiplier = weights[0]
+
+        for lora in self.unet_loras:
+            # determine block index
+            key = lora.lora_name[10:]  # remove "lora_unet_"
+            if key.startswith("input_blocks"):
+                block_index = int(key.split("_")[2]) + 1  # 1-9
+            elif key.startswith("middle_block"):
+                block_index = 10  # int(key.split("_")[2]) + 10
+            elif key.startswith("output_blocks"):
+                block_index = int(key.split("_")[2]) + 11  # 11-19
+            else:
+                print(f"unknown block: {key}")
+                block_index = 0
+
+            lora.multiplier = weights[block_index]
+
+            # print(f"{lora.lora_name} block index: {block_index}, weight: {lora.multiplier}")
+        # print(f"set block-wise weights to {weights}")
+
+    # TODO precompute the LoRA weights up front so that applying them is just a multiplication by the multiplier; that should be faster
+
+
+    # endregion
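
For context, the added methods let a caller scale each LoRA module's multiplier per block: weights[0] is applied to the text encoder LoRAs, indices 1-9 to input_blocks, 10 to middle_block, and 11-19 to output_blocks. A minimal usage sketch follows, assuming network is an already constructed LoRANetwork for an SDXL model; the variable name and the chosen weight values are illustrative, not part of the commit.

num_blocks = network.get_number_of_blocks()  # 20 for SDXL
weights = [1.0] * num_blocks                 # index 0: text encoder, 1-9: input_blocks, 10: middle_block, 11-19: output_blocks
weights[10] = 0.5                            # e.g. halve the middle block's contribution
if not network.has_text_encoder_block():
    weights[0] = 0.0                         # no text encoder LoRAs, so index 0 has no effect anyway
network.set_block_wise_weights(weights)      # writes the per-block value into each lora.multiplier

Since set_block_wise_weights only updates the multiplier attribute on already-applied LoRA modules, it can be called repeatedly between generations, which is presumably what the block-wise weight search named in the commit title relies on.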