mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
fix: --fp8_vl to work
This commit is contained in:
@@ -15,7 +15,7 @@ from transformers.models.t5.modeling_t5 import T5Stack
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from library.safetensors_utils import load_safetensors
|
||||
from library.utils import setup_logging
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -542,7 +542,6 @@ def get_qwen_prompt_embeds_from_tokens(
|
||||
attention_mask = attention_mask.to(device=device)
|
||||
|
||||
if dtype.itemsize == 1: # fp8
|
||||
# TODO dtype should be vlm.dtype?
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):
|
||||
encoder_hidden_states = vlm(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
||||
else:
|
||||
@@ -564,7 +563,7 @@ def get_qwen_prompt_embeds_from_tokens(
|
||||
|
||||
prompt_embeds = hidden_states[:, drop_idx:, :]
|
||||
encoder_attention_mask = attention_mask[:, drop_idx:]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = prompt_embeds.to(device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
|
||||
Reference in New Issue
Block a user