feat: Update direct loading fp8 ckpt for LoRA training

This commit is contained in:
Kohya S
2024-08-27 21:40:02 +09:00
parent 0087a46e14
commit 3be712e3e0
6 changed files with 151 additions and 55 deletions

View File

@@ -1,5 +1,5 @@
import json
-from typing import Union
+from typing import Optional, Union
import einops
import torch
@@ -20,7 +20,9 @@ MODEL_VERSION_FLUX_V1 = "flux1"
# temporary copy from sd3_utils TODO refactor
-def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32):
+def load_safetensors(
+    path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
+):
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
# use experimental loader
@@ -38,11 +40,13 @@ def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap:
def load_flow_model(
-    name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
+    name: str, ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
) -> flux_models.Flux:
logger.info(f"Building Flux model {name}")
with torch.device("meta"):
-        model = flux_models.Flux(flux_models.configs[name].params).to(dtype)
+        model = flux_models.Flux(flux_models.configs[name].params)
+        if dtype is not None:
+            model = model.to(dtype)
# load_sft doesn't support torch.device
logger.info(f"Loading state dict from {ckpt_path}")