refactor to make safer

This commit is contained in:
Kohya S
2024-09-13 19:45:00 +09:00
parent 1d7118a622
commit 57ae44eb61

View File

@@ -44,11 +44,11 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
def detect_method_from_training_model(models, dtype):
for model in models:
lora_sd, _ = load_state_dict(model, dtype)
for key in tqdm(lora_sd.keys()):
if 'lora_up' in key or 'lora_down' in key:
return 'LoRA'
elif "oft_blocks" in key:
return 'OFT'
for key in tqdm(lora_sd.keys()):
if 'lora_up' in key or 'lora_down' in key:
return 'LoRA'
elif "oft_blocks" in key:
return 'OFT'
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
text_encoder1.to(merge_dtype)
@@ -76,6 +76,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
)
elif method == 'OFT':
prefix = oft.OFTNetwork.OFT_PREFIX_UNET
# ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY
target_replace_modules = (
oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
)
@@ -83,16 +84,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if method == 'LoRA':
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
name_to_module[lora_name] = child_module
elif method == 'OFT':
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
oft_name = prefix + "." + name + "." + child_name
oft_name = oft_name.replace(".", "_")
name_to_module[oft_name] = child_module
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
@@ -168,6 +163,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
# find original module for this OFT
module_name = ".".join(key.split(".")[:-1])
if module_name not in name_to_module:
logger.info(f"no module found for OFT weight: {key}")
return
module = name_to_module[module_name]
@@ -208,7 +204,9 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
module.weight = torch.nn.Parameter(weight)
with concurrent.futures.ThreadPoolExecutor() as executor:
# TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough
max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))