fix: refactor LoHa and LoKr imports for weight merging in load_safetensors_with_lora_and_fp8 function

Author: Kohya S
Date:   2026-02-23 21:53:07 +09:00
Parent: 62db7eb205
Commit: c11d94727d


```diff
@@ -6,6 +6,9 @@ from tqdm import tqdm
 from library.device_utils import synchronize_device
 from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
 from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
+from networks.loha import merge_weights_to_tensor as loha_merge
+from networks.lokr import merge_weights_to_tensor as lokr_merge
 from library.utils import setup_logging
 setup_logging()
```
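
One effect of this hunk: the LoHa/LoKr merge helpers are now bound once at module import time instead of via a local `import` inside the per-tensor merge loop (removed in the next hunk). A minimal, generic sketch of that pattern, using only standard-library names rather than the repository's modules:

```python
import timeit

# Before: a local import inside the hot loop. After the first execution the module
# is cached in sys.modules, but the lookup and name binding still run every iteration.
def merge_loop_local(n):
    for _ in range(n):
        from math import sqrt  # stand-in for `from networks.loha import ... as loha_merge`
        sqrt(2.0)

# After: the name is resolved once, when this module is imported.
from math import sqrt

def merge_loop_hoisted(n):
    for _ in range(n):
        sqrt(2.0)

if __name__ == "__main__":
    print("local  :", timeit.timeit(lambda: merge_loop_local(100_000), number=10))
    print("hoisted:", timeit.timeit(lambda: merge_loop_hoisted(100_000), number=10))
```
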
```diff
@@ -191,14 +194,10 @@ def load_safetensors_with_lora_and_fp8(
         if hada_key in lora_weight_keys:
             # LoHa merge
-            from networks.loha import merge_weights_to_tensor as loha_merge
             model_weight = loha_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
             break
         elif lokr_key in lora_weight_keys:
             # LoKr merge
-            from networks.lokr import merge_weights_to_tensor as lokr_merge
             model_weight = lokr_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
             break
```
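
For orientation, a self-contained sketch of the dispatch that the hunk above simplifies. The `loha_merge` / `lokr_merge` stubs, their signatures, and the `hada_w1_a` / `lokr_w1` key suffixes are assumptions made for illustration (the diff only shows `hada_key` and `lokr_key` being tested), not the repository's actual implementation:

```python
import torch

# Hypothetical stand-ins for the networks.loha / networks.lokr merge helpers; the
# signatures are assumed from the call sites in the hunk above.
def loha_merge(weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device):
    return weight  # the real helper reconstructs the LoHa delta and adds it to `weight`

def lokr_merge(weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device):
    return weight  # the real helper reconstructs the LoKr delta and adds it to `weight`

def merge_one(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device):
    # Key suffixes are assumptions (common LyCORIS naming); the diff does not show
    # how hada_key / lokr_key are actually built.
    hada_key = f"{lora_name}.hada_w1_a"
    lokr_key = f"{lora_name}.lokr_w1"
    if hada_key in lora_weight_keys:
        model_weight = loha_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
    elif lokr_key in lora_weight_keys:
        model_weight = lokr_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
    return model_weight

# Example: a no-op merge on a dummy weight when no matching LoRA keys are present.
w = torch.zeros(4, 4)
w = merge_one(w, "lora_unet_example", {}, set(), 1.0, torch.device("cpu"))
```

In the actual function these branches sit inside a loop over candidate LoRA names, and each branch exits that loop with `break` once a match has been merged, as shown in the diff.
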