feat: enhance LoRA weight handling in model loading and add text encoder loading function

This commit is contained in:
Kohya S
2026-02-12 22:24:48 +09:00
parent 56e660dfde
commit 326c425a5b
3 changed files with 60 additions and 14 deletions

View File

@@ -244,6 +244,7 @@ def load_dit_model(
logger.info(f"Loading LoRA weight from: {lora_weight}")
lora_sd = load_file(lora_weight) # load on CPU, dtype is as is
# lora_sd = filter_lora_state_dict(lora_sd, args.include_patterns, args.exclude_patterns)
lora_sd = {k: v for k, v in lora_sd.items() if k.startswith("lora_unet_")} # only keep unet lora weights
lora_weights_list.append(lora_sd)
else:
lora_weights_list = None
@@ -284,6 +285,28 @@ def load_dit_model(
return model
def load_text_encoder(
    args: argparse.Namespace, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
) -> torch.nn.Module:
    """Load the Qwen3 text encoder, optionally merging Text-Encoder LoRA weights.

    Collects any LoRA state dicts named in ``args.lora_weight``, keeps only the
    entries carrying the ``lora_te_`` prefix, and renames them to the ``model_``
    key space expected by the merge helper before delegating the actual load to
    ``anima_utils.load_qwen3_text_encoder``.

    Args:
        args: Parsed CLI arguments; reads ``lora_weight``, ``lora_multiplier``
            and ``text_encoder``.
        dtype: Target dtype for the encoder weights (default: bfloat16).
        device: Device the encoder is loaded onto (default: CPU).

    Returns:
        The text encoder module, set to eval mode.
    """
    te_lora_sds = None
    if args.lora_weight:
        te_lora_sds = []
        for lora_weight in args.lora_weight:
            logger.info(f"Loading LoRA weight from: {lora_weight}")
            raw_sd = load_file(lora_weight)  # load on CPU, dtype is as is
            # lora_sd = filter_lora_state_dict(lora_sd, args.include_patterns, args.exclude_patterns)
            prefix_len = len("lora_te_")
            # only keep Text Encoder lora weights, remove prefix "lora_te_" and add "model_" prefix
            te_sd = {"model_" + key[prefix_len:]: tensor for key, tensor in raw_sd.items() if key.startswith("lora_te_")}
            te_lora_sds.append(te_sd)

    text_encoder, _ = anima_utils.load_qwen3_text_encoder(
        args.text_encoder, dtype=dtype, device=device, lora_weights=te_lora_sds, lora_multipliers=args.lora_multiplier
    )
    text_encoder.eval()
    return text_encoder
# endregion
@@ -305,6 +328,7 @@ def decode_latent(
logger.info(f"Decoded. Pixel shape {pixels.shape}")
return pixels[0] # remove batch dimension
def process_escape(text: str) -> str:
"""Process escape sequences in text
@@ -316,6 +340,7 @@ def process_escape(text: str) -> str:
"""
return text.encode("utf-8").decode("unicode_escape")
def prepare_text_inputs(
args: argparse.Namespace, device: torch.device, anima: anima_models.Anima, shared_models: Optional[Dict] = None
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
@@ -333,9 +358,7 @@ def prepare_text_inputs(
# text_encoder is on device (batched inference) or CPU (interactive inference)
else: # Load if not in shared_models
text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
text_encoder, _ = anima_utils.load_qwen3_text_encoder(
args.text_encoder, dtype=text_encoder_dtype, device=text_encoder_device
)
text_encoder = load_text_encoder(args, dtype=text_encoder_dtype, device=text_encoder_device)
text_encoder.eval()
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
# Store references so load_target_model can reuse them
@@ -721,7 +744,7 @@ def load_shared_models(args: argparse.Namespace) -> Dict:
shared_models = {}
# Load text encoders to CPU
text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.text_encoder, dtype=text_encoder_dtype, device="cpu")
text_encoder = load_text_encoder(args, dtype=text_encoder_dtype, device=torch.device("cpu"))
shared_models["text_encoder"] = text_encoder
return shared_models
@@ -766,7 +789,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
# Text Encoder loaded to CPU by load_text_encoder
text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
text_encoder_batch, _ = anima_utils.load_qwen3_text_encoder(args.text_encoder, dtype=text_encoder_dtype, device="cpu")
text_encoder_batch = load_text_encoder(args, dtype=text_encoder_dtype, device=torch.device("cpu"))
# Text Encoder to device for this phase
text_encoder_device = torch.device("cpu") if args.text_encoder_cpu else device

View File

@@ -178,7 +178,13 @@ def load_qwen3_tokenizer(qwen3_path: str):
return tokenizer
def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16, device: str = "cpu"):
def load_qwen3_text_encoder(
qwen3_path: str,
dtype: torch.dtype = torch.bfloat16,
device: str = "cpu",
lora_weights: Optional[List[Dict[str, torch.Tensor]]] = None,
lora_multipliers: Optional[List[float]] = None,
):
"""Load Qwen3-0.6B text encoder.
Args:
@@ -214,8 +220,20 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16
# Load weights
if qwen3_path.endswith(".safetensors"):
state_dict = load_file(qwen3_path, device="cpu")
if lora_weights is None:
state_dict = load_file(qwen3_path, device="cpu")
else:
state_dict = load_safetensors_with_lora_and_fp8(
model_files=qwen3_path,
lora_weights_list=lora_weights,
lora_multipliers=lora_multipliers,
fp8_optimization=False,
calc_device=device,
move_to_device=True,
dit_weight_dtype=None,
)
else:
assert lora_weights is None, "LoRA weights merging is only supported for safetensors checkpoints"
state_dict = torch.load(qwen3_path, map_location="cpu", weights_only=True)
# Remove 'model.' prefix if present

View File

@@ -120,13 +120,18 @@ def load_safetensors_with_lora_and_fp8(
for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
# check if this weight has LoRA weights
lora_name = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
lora_name = "lora_unet_" + lora_name.replace(".", "_")
down_key = lora_name + ".lora_down.weight"
up_key = lora_name + ".lora_up.weight"
alpha_key = lora_name + ".alpha"
if down_key not in lora_weight_keys or up_key not in lora_weight_keys:
continue
lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
found = False
for prefix in ["lora_unet_", ""]:
lora_name = prefix + lora_name_without_prefix.replace(".", "_")
down_key = lora_name + ".lora_down.weight"
up_key = lora_name + ".lora_up.weight"
alpha_key = lora_name + ".alpha"
if down_key in lora_weight_keys and up_key in lora_weight_keys:
found = True
break
if not found:
continue # no LoRA weights for this model weight
# get LoRA weights
down_weight = lora_sd[down_key]