mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
fix: --fp8_vl to work
This commit is contained in:
@@ -250,7 +250,7 @@ def sample_image_inference(
|
|||||||
arg_c_null = None
|
arg_c_null = None
|
||||||
|
|
||||||
gen_args = SimpleNamespace(
|
gen_args = SimpleNamespace(
|
||||||
image_size=(height, width), infer_steps=sample_steps, flow_shift=flow_shift, guidance_scale=cfg_scale
|
image_size=(height, width), infer_steps=sample_steps, flow_shift=flow_shift, guidance_scale=cfg_scale, fp8=args.fp8_scaled
|
||||||
)
|
)
|
||||||
|
|
||||||
from hunyuan_image_minimal_inference import generate_body # import here to avoid circular import
|
from hunyuan_image_minimal_inference import generate_body # import here to avoid circular import
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from transformers.models.t5.modeling_t5 import T5Stack
|
|||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
|
|
||||||
from library.safetensors_utils import load_safetensors
|
from library.safetensors_utils import load_safetensors
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
@@ -542,7 +542,6 @@ def get_qwen_prompt_embeds_from_tokens(
|
|||||||
attention_mask = attention_mask.to(device=device)
|
attention_mask = attention_mask.to(device=device)
|
||||||
|
|
||||||
if dtype.itemsize == 1: # fp8
|
if dtype.itemsize == 1: # fp8
|
||||||
# TODO dtype should be vlm.dtype?
|
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):
|
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):
|
||||||
encoder_hidden_states = vlm(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
encoder_hidden_states = vlm(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
||||||
else:
|
else:
|
||||||
@@ -564,7 +563,7 @@ def get_qwen_prompt_embeds_from_tokens(
|
|||||||
|
|
||||||
prompt_embeds = hidden_states[:, drop_idx:, :]
|
prompt_embeds = hidden_states[:, drop_idx:, :]
|
||||||
encoder_attention_mask = attention_mask[:, drop_idx:]
|
encoder_attention_mask = attention_mask[:, drop_idx:]
|
||||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
prompt_embeds = prompt_embeds.to(device=device)
|
||||||
|
|
||||||
return prompt_embeds, encoder_attention_mask
|
return prompt_embeds, encoder_attention_mask
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user