mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 09:18:00 +00:00
feat: enhance LoRA weight handling in model loading and add text encoder loading function
This commit is contained in:
@@ -244,6 +244,7 @@ def load_dit_model(
|
||||
logger.info(f"Loading LoRA weight from: {lora_weight}")
|
||||
lora_sd = load_file(lora_weight) # load on CPU, dtype is as is
|
||||
# lora_sd = filter_lora_state_dict(lora_sd, args.include_patterns, args.exclude_patterns)
|
||||
lora_sd = {k: v for k, v in lora_sd.items() if k.startswith("lora_unet_")} # only keep unet lora weights
|
||||
lora_weights_list.append(lora_sd)
|
||||
else:
|
||||
lora_weights_list = None
|
||||
@@ -284,6 +285,28 @@ def load_dit_model(
|
||||
return model
|
||||
|
||||
|
||||
def load_text_encoder(
    args: argparse.Namespace, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
) -> torch.nn.Module:
    """Load the Qwen3 text encoder, optionally merging Text-Encoder LoRA weights.

    Collects any `lora_te_`-prefixed tensors from each file in ``args.lora_weight``,
    renames them to the ``model_`` prefix expected by the text encoder, and hands
    them to ``anima_utils.load_qwen3_text_encoder`` for merging.

    Args:
        args: Parsed CLI arguments; reads ``text_encoder``, ``lora_weight`` and
            ``lora_multiplier``.
        dtype: Target dtype for the encoder weights (default: bfloat16).
        device: Device to load the encoder onto (default: CPU).

    Returns:
        The text encoder module, already switched to eval mode.
    """
    te_lora_sds = None
    if args.lora_weight is not None and len(args.lora_weight) > 0:
        te_lora_sds = []
        for weight_path in args.lora_weight:
            logger.info(f"Loading LoRA weight from: {weight_path}")
            raw_sd = load_file(weight_path)  # load on CPU, dtype is as is
            # raw_sd = filter_lora_state_dict(raw_sd, args.include_patterns, args.exclude_patterns)
            # Keep only Text Encoder LoRA weights; swap the "lora_te_" prefix for "model_".
            prefix_len = len("lora_te_")
            filtered = {
                "model_" + key[prefix_len:]: tensor
                for key, tensor in raw_sd.items()
                if key.startswith("lora_te_")
            }
            te_lora_sds.append(filtered)

    text_encoder, _ = anima_utils.load_qwen3_text_encoder(
        args.text_encoder, dtype=dtype, device=device, lora_weights=te_lora_sds, lora_multipliers=args.lora_multiplier
    )
    text_encoder.eval()
    return text_encoder
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -305,6 +328,7 @@ def decode_latent(
|
||||
logger.info(f"Decoded. Pixel shape {pixels.shape}")
|
||||
return pixels[0] # remove batch dimension
|
||||
|
||||
|
||||
def process_escape(text: str) -> str:
|
||||
"""Process escape sequences in text
|
||||
|
||||
@@ -316,6 +340,7 @@ def process_escape(text: str) -> str:
|
||||
"""
|
||||
return text.encode("utf-8").decode("unicode_escape")
|
||||
|
||||
|
||||
def prepare_text_inputs(
|
||||
args: argparse.Namespace, device: torch.device, anima: anima_models.Anima, shared_models: Optional[Dict] = None
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
@@ -333,9 +358,7 @@ def prepare_text_inputs(
|
||||
# text_encoder is on device (batched inference) or CPU (interactive inference)
|
||||
else: # Load if not in shared_models
|
||||
text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
|
||||
text_encoder, _ = anima_utils.load_qwen3_text_encoder(
|
||||
args.text_encoder, dtype=text_encoder_dtype, device=text_encoder_device
|
||||
)
|
||||
text_encoder = load_text_encoder(args, dtype=text_encoder_dtype, device=text_encoder_device)
|
||||
text_encoder.eval()
|
||||
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
# Store references so load_target_model can reuse them
|
||||
@@ -721,7 +744,7 @@ def load_shared_models(args: argparse.Namespace) -> Dict:
|
||||
shared_models = {}
|
||||
# Load text encoders to CPU
|
||||
text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
|
||||
text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.text_encoder, dtype=text_encoder_dtype, device="cpu")
|
||||
text_encoder = load_text_encoder(args, dtype=text_encoder_dtype, device=torch.device("cpu"))
|
||||
shared_models["text_encoder"] = text_encoder
|
||||
return shared_models
|
||||
|
||||
@@ -766,7 +789,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
|
||||
|
||||
# Text Encoder loaded to CPU by load_text_encoder
|
||||
text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
|
||||
text_encoder_batch, _ = anima_utils.load_qwen3_text_encoder(args.text_encoder, dtype=text_encoder_dtype, device="cpu")
|
||||
text_encoder_batch = load_text_encoder(args, dtype=text_encoder_dtype, device=torch.device("cpu"))
|
||||
|
||||
# Text Encoder to device for this phase
|
||||
text_encoder_device = torch.device("cpu") if args.text_encoder_cpu else device
|
||||
|
||||
@@ -178,7 +178,13 @@ def load_qwen3_tokenizer(qwen3_path: str):
|
||||
return tokenizer
|
||||
|
||||
|
||||
def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16, device: str = "cpu"):
|
||||
def load_qwen3_text_encoder(
|
||||
qwen3_path: str,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
device: str = "cpu",
|
||||
lora_weights: Optional[List[Dict[str, torch.Tensor]]] = None,
|
||||
lora_multipliers: Optional[List[float]] = None,
|
||||
):
|
||||
"""Load Qwen3-0.6B text encoder.
|
||||
|
||||
Args:
|
||||
@@ -214,8 +220,20 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16
|
||||
|
||||
# Load weights
|
||||
if qwen3_path.endswith(".safetensors"):
|
||||
state_dict = load_file(qwen3_path, device="cpu")
|
||||
if lora_weights is None:
|
||||
state_dict = load_file(qwen3_path, device="cpu")
|
||||
else:
|
||||
state_dict = load_safetensors_with_lora_and_fp8(
|
||||
model_files=qwen3_path,
|
||||
lora_weights_list=lora_weights,
|
||||
lora_multipliers=lora_multipliers,
|
||||
fp8_optimization=False,
|
||||
calc_device=device,
|
||||
move_to_device=True,
|
||||
dit_weight_dtype=None,
|
||||
)
|
||||
else:
|
||||
assert lora_weights is None, "LoRA weights merging is only supported for safetensors checkpoints"
|
||||
state_dict = torch.load(qwen3_path, map_location="cpu", weights_only=True)
|
||||
|
||||
# Remove 'model.' prefix if present
|
||||
|
||||
@@ -120,13 +120,18 @@ def load_safetensors_with_lora_and_fp8(
|
||||
|
||||
for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
|
||||
# check if this weight has LoRA weights
|
||||
lora_name = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
|
||||
lora_name = "lora_unet_" + lora_name.replace(".", "_")
|
||||
down_key = lora_name + ".lora_down.weight"
|
||||
up_key = lora_name + ".lora_up.weight"
|
||||
alpha_key = lora_name + ".alpha"
|
||||
if down_key not in lora_weight_keys or up_key not in lora_weight_keys:
|
||||
continue
|
||||
lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
|
||||
found = False
|
||||
for prefix in ["lora_unet_", ""]:
|
||||
lora_name = prefix + lora_name_without_prefix.replace(".", "_")
|
||||
down_key = lora_name + ".lora_down.weight"
|
||||
up_key = lora_name + ".lora_up.weight"
|
||||
alpha_key = lora_name + ".alpha"
|
||||
if down_key in lora_weight_keys and up_key in lora_weight_keys:
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
continue # no LoRA weights for this model weight
|
||||
|
||||
# get LoRA weights
|
||||
down_weight = lora_sd[down_key]
|
||||
|
||||
Reference in New Issue
Block a user