mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
feat: Update direct loading fp8 ckpt for LoRA training
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
import einops
|
||||
import torch
|
||||
|
||||
@@ -20,7 +20,9 @@ MODEL_VERSION_FLUX_V1 = "flux1"
|
||||
|
||||
|
||||
# temporary copy from sd3_utils TODO refactor
|
||||
def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: torch.dtype = torch.float32):
|
||||
def load_safetensors(
|
||||
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
|
||||
):
|
||||
if disable_mmap:
|
||||
# return safetensors.torch.load(open(path, "rb").read())
|
||||
# use experimental loader
|
||||
@@ -38,11 +40,13 @@ def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap:
|
||||
|
||||
|
||||
def load_flow_model(
|
||||
name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
|
||||
name: str, ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
|
||||
) -> flux_models.Flux:
|
||||
logger.info(f"Building Flux model {name}")
|
||||
with torch.device("meta"):
|
||||
model = flux_models.Flux(flux_models.configs[name].params).to(dtype)
|
||||
model = flux_models.Flux(flux_models.configs[name].params)
|
||||
if dtype is not None:
|
||||
model = model.to(dtype)
|
||||
|
||||
# load_sft doesn't support torch.device
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
|
||||
Reference in New Issue
Block a user