mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
refactor to make safer
This commit is contained in:
@@ -76,6 +76,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
||||
)
|
||||
elif method == 'OFT':
|
||||
prefix = oft.OFTNetwork.OFT_PREFIX_UNET
|
||||
# ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY
|
||||
target_replace_modules = (
|
||||
oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
)
|
||||
@@ -83,16 +84,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
if method == 'LoRA':
|
||||
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
name_to_module[lora_name] = child_module
|
||||
elif method == 'OFT':
|
||||
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
|
||||
oft_name = prefix + "." + name + "." + child_name
|
||||
oft_name = oft_name.replace(".", "_")
|
||||
name_to_module[oft_name] = child_module
|
||||
|
||||
|
||||
for model, ratio in zip(models, ratios):
|
||||
@@ -168,6 +163,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
||||
# find original module for this OFT
|
||||
module_name = ".".join(key.split(".")[:-1])
|
||||
if module_name not in name_to_module:
|
||||
logger.info(f"no module found for OFT weight: {key}")
|
||||
return
|
||||
module = name_to_module[module_name]
|
||||
|
||||
@@ -208,7 +204,9 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
||||
|
||||
module.weight = torch.nn.Parameter(weight)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
# TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough
|
||||
max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user