mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 14:45:19 +00:00
refactor to make safer
This commit is contained in:
@@ -44,11 +44,11 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
|
|||||||
def detect_method_from_training_model(models, dtype):
|
def detect_method_from_training_model(models, dtype):
|
||||||
for model in models:
|
for model in models:
|
||||||
lora_sd, _ = load_state_dict(model, dtype)
|
lora_sd, _ = load_state_dict(model, dtype)
|
||||||
for key in tqdm(lora_sd.keys()):
|
for key in tqdm(lora_sd.keys()):
|
||||||
if 'lora_up' in key or 'lora_down' in key:
|
if 'lora_up' in key or 'lora_down' in key:
|
||||||
return 'LoRA'
|
return 'LoRA'
|
||||||
elif "oft_blocks" in key:
|
elif "oft_blocks" in key:
|
||||||
return 'OFT'
|
return 'OFT'
|
||||||
|
|
||||||
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
|
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
|
||||||
text_encoder1.to(merge_dtype)
|
text_encoder1.to(merge_dtype)
|
||||||
@@ -76,6 +76,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
|||||||
)
|
)
|
||||||
elif method == 'OFT':
|
elif method == 'OFT':
|
||||||
prefix = oft.OFTNetwork.OFT_PREFIX_UNET
|
prefix = oft.OFTNetwork.OFT_PREFIX_UNET
|
||||||
|
# ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY
|
||||||
target_replace_modules = (
|
target_replace_modules = (
|
||||||
oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||||
)
|
)
|
||||||
@@ -83,16 +84,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
|||||||
for name, module in root_module.named_modules():
|
for name, module in root_module.named_modules():
|
||||||
if module.__class__.__name__ in target_replace_modules:
|
if module.__class__.__name__ in target_replace_modules:
|
||||||
for child_name, child_module in module.named_modules():
|
for child_name, child_module in module.named_modules():
|
||||||
if method == 'LoRA':
|
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
|
||||||
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
|
lora_name = prefix + "." + name + "." + child_name
|
||||||
lora_name = prefix + "." + name + "." + child_name
|
lora_name = lora_name.replace(".", "_")
|
||||||
lora_name = lora_name.replace(".", "_")
|
name_to_module[lora_name] = child_module
|
||||||
name_to_module[lora_name] = child_module
|
|
||||||
elif method == 'OFT':
|
|
||||||
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
|
|
||||||
oft_name = prefix + "." + name + "." + child_name
|
|
||||||
oft_name = oft_name.replace(".", "_")
|
|
||||||
name_to_module[oft_name] = child_module
|
|
||||||
|
|
||||||
|
|
||||||
for model, ratio in zip(models, ratios):
|
for model, ratio in zip(models, ratios):
|
||||||
@@ -168,6 +163,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
|||||||
# find original module for this OFT
|
# find original module for this OFT
|
||||||
module_name = ".".join(key.split(".")[:-1])
|
module_name = ".".join(key.split(".")[:-1])
|
||||||
if module_name not in name_to_module:
|
if module_name not in name_to_module:
|
||||||
|
logger.info(f"no module found for OFT weight: {key}")
|
||||||
return
|
return
|
||||||
module = name_to_module[module_name]
|
module = name_to_module[module_name]
|
||||||
|
|
||||||
@@ -208,7 +204,9 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
|||||||
|
|
||||||
module.weight = torch.nn.Parameter(weight)
|
module.weight = torch.nn.Parameter(weight)
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
# TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough
|
||||||
|
max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
|
list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user