mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
fix: fp8 casting not working
This commit is contained in:
@@ -284,7 +284,7 @@ def load_dit_model(
|
||||
|
||||
# if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy)
|
||||
move_to_device = args.blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU
|
||||
state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=args.fp8_fast)
|
||||
state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=False) # args.fp8_fast)
|
||||
|
||||
info = model.load_state_dict(state_dict, strict=True, assign=True)
|
||||
logger.info(f"Loaded FP8 optimized weights: {info}")
|
||||
@@ -689,15 +689,18 @@ def generate_body(
|
||||
# print(f"mask_byt5 shape: {mask_byt5.shape}, sum: {mask_byt5.sum()}")
|
||||
# print(f"negative_mask shape: {negative_mask.shape}, sum: {negative_mask.sum()}")
|
||||
# print(f"negative_mask_byt5 shape: {negative_mask_byt5.shape}, sum: {negative_mask_byt5.sum()}")
|
||||
|
||||
autocast_enabled = args.fp8
|
||||
|
||||
with tqdm(total=len(timesteps), desc="Denoising steps") as pbar:
|
||||
for i, t in enumerate(timesteps):
|
||||
t_expand = t.expand(latents.shape[0]).to(torch.int64)
|
||||
|
||||
with torch.no_grad():
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
|
||||
noise_pred = model(latents, t_expand, embed, mask, embed_byt5, mask_byt5)
|
||||
|
||||
if do_cfg:
|
||||
with torch.no_grad():
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
|
||||
uncond_noise_pred = model(
|
||||
latents, t_expand, negative_embed, negative_mask, negative_embed_byt5, negative_mask_byt5
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user