fix: fp8 casting not working

Kohya S
2025-09-18 21:20:08 +09:00
parent f5b004009e
commit 2ce506e187

@@ -284,7 +284,7 @@ def load_dit_model(
         # if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy)
         move_to_device = args.blocks_to_swap == 0  # if blocks_to_swap > 0, we will keep the model on CPU
-        state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=args.fp8_fast)
+        state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=False)  # args.fp8_fast)
         info = model.load_state_dict(state_dict, strict=True, assign=True)
         logger.info(f"Loaded FP8 optimized weights: {info}")
@@ -689,15 +689,18 @@ def generate_body(
     # print(f"mask_byt5 shape: {mask_byt5.shape}, sum: {mask_byt5.sum()}")
     # print(f"negative_mask shape: {negative_mask.shape}, sum: {negative_mask.sum()}")
     # print(f"negative_mask_byt5 shape: {negative_mask_byt5.shape}, sum: {negative_mask_byt5.sum()}")
+    autocast_enabled = args.fp8
     with tqdm(total=len(timesteps), desc="Denoising steps") as pbar:
         for i, t in enumerate(timesteps):
             t_expand = t.expand(latents.shape[0]).to(torch.int64)
-            with torch.no_grad():
+            with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
                 noise_pred = model(latents, t_expand, embed, mask, embed_byt5, mask_byt5)
             if do_cfg:
-                with torch.no_grad():
+                with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
                     uncond_noise_pred = model(
                         latents, t_expand, negative_embed, negative_mask, negative_embed_byt5, negative_mask_byt5
                     )
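The second hunk wraps both DiT forward passes in a bfloat16 autocast context that is active only when args.fp8 is set. A minimal, self-contained sketch of the same pattern; the linear layer and tensors below are placeholders, not the script's model:

# Hedged sketch of the autocast pattern added above; the layer and inputs are
# placeholders, not the script's DiT model.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
autocast_enabled = True  # in the script this comes from args.fp8

layer = torch.nn.Linear(16, 16).to(device)
x = torch.randn(2, 16, device=device)

with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
    y = layer(x)

print(y.dtype)  # torch.bfloat16 while autocast is enabled; the layer's weights are unchanged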