mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add sd3 models and inference script
This commit is contained in:
1796
library/sd3_models.py
Normal file
1796
library/sd3_models.py
Normal file
File diff suppressed because it is too large
Load Diff
113
library/sd3_utils.py
Normal file
113
library/sd3_utils.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
import math
|
||||||
|
from typing import Dict
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from library import sd3_models
|
||||||
|
|
||||||
|
|
||||||
|
def get_cond(
|
||||||
|
prompt: str,
|
||||||
|
tokenizer: sd3_models.SD3Tokenizer,
|
||||||
|
clip_l: sd3_models.SDClipModel,
|
||||||
|
clip_g: sd3_models.SDXLClipG,
|
||||||
|
t5xxl: sd3_models.T5XXLModel,
|
||||||
|
):
|
||||||
|
l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt)
|
||||||
|
l_out, l_pooled = clip_l.encode_token_weights(l_tokens)
|
||||||
|
g_out, g_pooled = clip_g.encode_token_weights(g_tokens)
|
||||||
|
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||||
|
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||||
|
|
||||||
|
if t5_tokens is None:
|
||||||
|
t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device)
|
||||||
|
else:
|
||||||
|
t5_out, t5_pooled = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None
|
||||||
|
t5_out = t5_out.to(lg_out.dtype)
|
||||||
|
|
||||||
|
return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
# used if other sd3 models is available
|
||||||
|
r"""
|
||||||
|
def get_sd3_configs(state_dict: Dict):
|
||||||
|
# Important configuration values can be quickly determined by checking shapes in the source file
|
||||||
|
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
|
||||||
|
# prefix = "model.diffusion_model."
|
||||||
|
prefix = ""
|
||||||
|
|
||||||
|
patch_size = state_dict[prefix + "x_embedder.proj.weight"].shape[2]
|
||||||
|
depth = state_dict[prefix + "x_embedder.proj.weight"].shape[0] // 64
|
||||||
|
num_patches = state_dict[prefix + "pos_embed"].shape[1]
|
||||||
|
pos_embed_max_size = round(math.sqrt(num_patches))
|
||||||
|
adm_in_channels = state_dict[prefix + "y_embedder.mlp.0.weight"].shape[1]
|
||||||
|
context_shape = state_dict[prefix + "context_embedder.weight"].shape
|
||||||
|
context_embedder_config = {
|
||||||
|
"target": "torch.nn.Linear",
|
||||||
|
"params": {"in_features": context_shape[1], "out_features": context_shape[0]},
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"patch_size": patch_size,
|
||||||
|
"depth": depth,
|
||||||
|
"num_patches": num_patches,
|
||||||
|
"pos_embed_max_size": pos_embed_max_size,
|
||||||
|
"adm_in_channels": adm_in_channels,
|
||||||
|
"context_embedder": context_embedder_config,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_mmdit_from_sd3_checkpoint(state_dict: Dict, attn_mode: str = "xformers"):
|
||||||
|
""
|
||||||
|
Doesn't load state dict.
|
||||||
|
""
|
||||||
|
sd3_configs = get_sd3_configs(state_dict)
|
||||||
|
|
||||||
|
mmdit = sd3_models.MMDiT(
|
||||||
|
input_size=None,
|
||||||
|
pos_embed_max_size=sd3_configs["pos_embed_max_size"],
|
||||||
|
patch_size=sd3_configs["patch_size"],
|
||||||
|
in_channels=16,
|
||||||
|
adm_in_channels=sd3_configs["adm_in_channels"],
|
||||||
|
depth=sd3_configs["depth"],
|
||||||
|
mlp_ratio=4,
|
||||||
|
qk_norm=None,
|
||||||
|
num_patches=sd3_configs["num_patches"],
|
||||||
|
context_size=4096,
|
||||||
|
attn_mode=attn_mode,
|
||||||
|
)
|
||||||
|
return mmdit
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSamplingDiscreteFlow:
|
||||||
|
"""Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
|
||||||
|
|
||||||
|
def __init__(self, shift=1.0):
|
||||||
|
self.shift = shift
|
||||||
|
timesteps = 1000
|
||||||
|
self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sigma_min(self):
|
||||||
|
return self.sigmas[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sigma_max(self):
|
||||||
|
return self.sigmas[-1]
|
||||||
|
|
||||||
|
def timestep(self, sigma):
|
||||||
|
return sigma * 1000
|
||||||
|
|
||||||
|
def sigma(self, timestep: torch.Tensor):
|
||||||
|
timestep = timestep / 1000.0
|
||||||
|
if self.shift == 1.0:
|
||||||
|
return timestep
|
||||||
|
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
|
||||||
|
|
||||||
|
def calculate_denoised(self, sigma, model_output, model_input):
|
||||||
|
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||||
|
return model_input - model_output * sigma
|
||||||
|
|
||||||
|
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||||
|
# assert max_denoise is False, "max_denoise not implemented"
|
||||||
|
# max_denoise is always True, I'm not sure why it's there
|
||||||
|
return sigma * noise + (1.0 - sigma) * latent_image
|
||||||
347
sd3_minimal_inference.py
Normal file
347
sd3_minimal_inference.py
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
# Minimum Inference Code for SD3
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import safe_open, load_file
|
||||||
|
from tqdm import tqdm
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from library.device_utils import init_ipex, get_preferred_device
|
||||||
|
|
||||||
|
init_ipex()
|
||||||
|
|
||||||
|
from library.utils import setup_logging
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from library import sd3_models, sd3_utils
|
||||||
|
|
||||||
|
|
||||||
|
def get_noise(seed, latent):
|
||||||
|
generator = torch.manual_seed(seed)
|
||||||
|
return torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu").to(latent.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
|
||||||
|
start = sampling.timestep(sampling.sigma_max)
|
||||||
|
end = sampling.timestep(sampling.sigma_min)
|
||||||
|
timesteps = torch.linspace(start, end, steps)
|
||||||
|
sigs = []
|
||||||
|
for x in range(len(timesteps)):
|
||||||
|
ts = timesteps[x]
|
||||||
|
sigs.append(sampling.sigma(ts))
|
||||||
|
sigs += [0.0]
|
||||||
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
|
|
||||||
|
def max_denoise(model_sampling, sigmas):
|
||||||
|
max_sigma = float(model_sampling.sigma_max)
|
||||||
|
sigma = float(sigmas[0])
|
||||||
|
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||||
|
|
||||||
|
|
||||||
|
def do_sample(
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
initial_latent: Optional[torch.Tensor],
|
||||||
|
seed: int,
|
||||||
|
cond: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
neg_cond: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
mmdit: sd3_models.MMDiT,
|
||||||
|
steps: int,
|
||||||
|
guidance_scale: float,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
|
):
|
||||||
|
if initial_latent is None:
|
||||||
|
latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609
|
||||||
|
else:
|
||||||
|
latent = initial_latent
|
||||||
|
|
||||||
|
latent = latent.to(dtype).to(device)
|
||||||
|
|
||||||
|
noise = get_noise(seed, latent).to(device)
|
||||||
|
|
||||||
|
model_sampling = sd3_utils.ModelSamplingDiscreteFlow()
|
||||||
|
|
||||||
|
sigmas = get_sigmas(model_sampling, steps).to(device)
|
||||||
|
# sigmas = sigmas[int(steps * (1 - denoise)) :] # do not support i2i
|
||||||
|
|
||||||
|
# conditioning = fix_cond(conditioning)
|
||||||
|
# neg_cond = fix_cond(neg_cond)
|
||||||
|
# extra_args = {"cond": cond, "uncond": neg_cond, "cond_scale": guidance_scale}
|
||||||
|
|
||||||
|
noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas))
|
||||||
|
|
||||||
|
c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype)
|
||||||
|
y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype)
|
||||||
|
|
||||||
|
x = noise_scaled.to(device).to(dtype)
|
||||||
|
# print(x.shape)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for i in tqdm(range(len(sigmas) - 1)):
|
||||||
|
sigma_hat = sigmas[i]
|
||||||
|
|
||||||
|
timestep = model_sampling.timestep(sigma_hat).float()
|
||||||
|
timestep = torch.FloatTensor([timestep, timestep]).to(device)
|
||||||
|
|
||||||
|
x_c_nc = torch.cat([x, x], dim=0)
|
||||||
|
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
|
||||||
|
|
||||||
|
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
|
||||||
|
model_output = model_output.float()
|
||||||
|
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
|
||||||
|
|
||||||
|
pos_out, neg_out = batched.chunk(2)
|
||||||
|
denoised = neg_out + (pos_out - neg_out) * guidance_scale
|
||||||
|
# print(denoised.shape)
|
||||||
|
|
||||||
|
# d = to_d(x, sigma_hat, denoised)
|
||||||
|
dims_to_append = x.ndim - sigma_hat.ndim
|
||||||
|
sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append]
|
||||||
|
# print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape)
|
||||||
|
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||||
|
d = (x - denoised) / sigma_hat_dims
|
||||||
|
|
||||||
|
dt = sigmas[i + 1] - sigma_hat
|
||||||
|
|
||||||
|
# Euler method
|
||||||
|
x = x + d * dt
|
||||||
|
x = x.to(dtype)
|
||||||
|
|
||||||
|
latent = x
|
||||||
|
scale_factor = 1.5305
|
||||||
|
shift_factor = 0.0609
|
||||||
|
# def process_out(self, latent):
|
||||||
|
# return (latent / self.scale_factor) + self.shift_factor
|
||||||
|
latent = (latent / scale_factor) + shift_factor
|
||||||
|
return latent
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
target_height = 1024
|
||||||
|
target_width = 1024
|
||||||
|
|
||||||
|
# steps = 50 # 28 # 50
|
||||||
|
guidance_scale = 5
|
||||||
|
# seed = 1 # None # 1
|
||||||
|
|
||||||
|
device = get_preferred_device()
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||||
|
parser.add_argument("--clip_g", type=str, required=False)
|
||||||
|
parser.add_argument("--clip_l", type=str, required=False)
|
||||||
|
parser.add_argument("--t5xxl", type=str, required=False)
|
||||||
|
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
||||||
|
# parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
|
||||||
|
parser.add_argument("--negative_prompt", type=str, default="")
|
||||||
|
parser.add_argument("--output_dir", type=str, default=".")
|
||||||
|
parser.add_argument("--do_not_use_t5xxl", action="store_true")
|
||||||
|
parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch")
|
||||||
|
parser.add_argument("--fp16", action="store_true")
|
||||||
|
parser.add_argument("--bf16", action="store_true")
|
||||||
|
parser.add_argument("--seed", type=int, default=1)
|
||||||
|
parser.add_argument("--steps", type=int, default=50)
|
||||||
|
# parser.add_argument(
|
||||||
|
# "--lora_weights",
|
||||||
|
# type=str,
|
||||||
|
# nargs="*",
|
||||||
|
# default=[],
|
||||||
|
# help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
|
||||||
|
# )
|
||||||
|
# parser.add_argument("--interactive", action="store_true")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
seed = args.seed
|
||||||
|
steps = args.steps
|
||||||
|
|
||||||
|
sd3_dtype = torch.float32
|
||||||
|
if args.fp16:
|
||||||
|
sd3_dtype = torch.float16
|
||||||
|
elif args.bf16:
|
||||||
|
sd3_dtype = torch.bfloat16
|
||||||
|
|
||||||
|
# TODO test with separated safetenors files for each model
|
||||||
|
|
||||||
|
# load state dict
|
||||||
|
logger.info(f"Loading SD3 models from {args.ckpt_path}...")
|
||||||
|
state_dict = load_file(args.ckpt_path)
|
||||||
|
|
||||||
|
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
|
||||||
|
# found clip_g: remove prefix "text_encoders.clip_g."
|
||||||
|
logger.info("clip_g is included in the checkpoint")
|
||||||
|
clip_g_sd = {}
|
||||||
|
prefix = "text_encoders.clip_g."
|
||||||
|
for k, v in list(state_dict.items()):
|
||||||
|
if k.startswith(prefix):
|
||||||
|
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||||
|
else:
|
||||||
|
logger.info(f"Lodaing clip_g from {args.clip_g}...")
|
||||||
|
clip_g_sd = load_file(args.clip_g)
|
||||||
|
for key in list(clip_g_sd.keys()):
|
||||||
|
clip_g_sd["transformer." + key] = clip_g_sd.pop(key)
|
||||||
|
|
||||||
|
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
|
||||||
|
# found clip_l: remove prefix "text_encoders.clip_l."
|
||||||
|
logger.info("clip_l is included in the checkpoint")
|
||||||
|
clip_l_sd = {}
|
||||||
|
prefix = "text_encoders.clip_l."
|
||||||
|
for k, v in list(state_dict.items()):
|
||||||
|
if k.startswith(prefix):
|
||||||
|
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||||
|
else:
|
||||||
|
logger.info(f"Lodaing clip_l from {args.clip_l}...")
|
||||||
|
clip_l_sd = load_file(args.clip_l)
|
||||||
|
for key in list(clip_l_sd.keys()):
|
||||||
|
clip_l_sd["transformer." + key] = clip_l_sd.pop(key)
|
||||||
|
|
||||||
|
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
|
||||||
|
# found t5xxl: remove prefix "text_encoders.t5xxl."
|
||||||
|
logger.info("t5xxl is included in the checkpoint")
|
||||||
|
if not args.do_not_use_t5xxl:
|
||||||
|
t5xxl_sd = {}
|
||||||
|
prefix = "text_encoders.t5xxl."
|
||||||
|
for k, v in list(state_dict.items()):
|
||||||
|
if k.startswith(prefix):
|
||||||
|
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||||
|
else:
|
||||||
|
logger.info("but not used")
|
||||||
|
for key in list(state_dict.keys()):
|
||||||
|
if key.startswith("text_encoders.t5xxl."):
|
||||||
|
state_dict.pop(key)
|
||||||
|
t5xxl_sd = None
|
||||||
|
elif args.t5xxl:
|
||||||
|
assert not args.do_not_use_t5xxl, "t5xxl is not used but specified"
|
||||||
|
logger.info(f"Lodaing t5xxl from {args.t5xxl}...")
|
||||||
|
t5xxl_sd = load_file(args.t5xxl)
|
||||||
|
for key in list(t5xxl_sd.keys()):
|
||||||
|
t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key)
|
||||||
|
else:
|
||||||
|
logger.info("t5xxl is not used")
|
||||||
|
t5xxl_sd = None
|
||||||
|
|
||||||
|
use_t5xxl = t5xxl_sd is not None
|
||||||
|
|
||||||
|
# MMDiT and VAE
|
||||||
|
vae_sd = {}
|
||||||
|
vae_prefix = "first_stage_model."
|
||||||
|
mmdit_prefix = "model.diffusion_model."
|
||||||
|
for k, v in list(state_dict.items()):
|
||||||
|
if k.startswith(vae_prefix):
|
||||||
|
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
|
||||||
|
elif k.startswith(mmdit_prefix):
|
||||||
|
state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k)
|
||||||
|
|
||||||
|
# load tokenizers
|
||||||
|
logger.info("Loading tokenizers...")
|
||||||
|
tokenizer = sd3_models.SD3Tokenizer(use_t5xxl) # combined tokenizer
|
||||||
|
|
||||||
|
# load models
|
||||||
|
# logger.info("Create MMDiT from SD3 checkpoint...")
|
||||||
|
# mmdit = sd3_utils.create_mmdit_from_sd3_checkpoint(state_dict)
|
||||||
|
logger.info("Create MMDiT")
|
||||||
|
mmdit = sd3_models.create_mmdit_sd3_medium_configs(args.attn_mode)
|
||||||
|
|
||||||
|
logger.info("Loading state dict...")
|
||||||
|
info = mmdit.load_state_dict(state_dict)
|
||||||
|
logger.info(f"Loaded MMDiT: {info}")
|
||||||
|
|
||||||
|
logger.info(f"Move MMDiT to {device} and {sd3_dtype}...")
|
||||||
|
mmdit.to(device, dtype=sd3_dtype)
|
||||||
|
mmdit.eval()
|
||||||
|
|
||||||
|
# load VAE
|
||||||
|
logger.info("Create VAE")
|
||||||
|
vae = sd3_models.SDVAE()
|
||||||
|
logger.info("Loading state dict...")
|
||||||
|
info = vae.load_state_dict(vae_sd)
|
||||||
|
logger.info(f"Loaded VAE: {info}")
|
||||||
|
|
||||||
|
logger.info(f"Move VAE to {device} and {sd3_dtype}...")
|
||||||
|
vae.to(device, dtype=sd3_dtype)
|
||||||
|
vae.eval()
|
||||||
|
|
||||||
|
# load text encoders
|
||||||
|
logger.info("Create clip_l")
|
||||||
|
clip_l = sd3_models.create_clip_l(device, sd3_dtype, clip_l_sd)
|
||||||
|
|
||||||
|
logger.info("Loading state dict...")
|
||||||
|
info = clip_l.load_state_dict(clip_l_sd)
|
||||||
|
logger.info(f"Loaded clip_l: {info}")
|
||||||
|
|
||||||
|
logger.info(f"Move clip_l to {device} and {sd3_dtype}...")
|
||||||
|
clip_l.to(device, dtype=sd3_dtype)
|
||||||
|
clip_l.eval()
|
||||||
|
logger.info(f"Set attn_mode to {args.attn_mode}...")
|
||||||
|
clip_l.set_attn_mode(args.attn_mode)
|
||||||
|
|
||||||
|
logger.info("Create clip_g")
|
||||||
|
clip_g = sd3_models.create_clip_g(device, sd3_dtype, clip_g_sd)
|
||||||
|
|
||||||
|
logger.info("Loading state dict...")
|
||||||
|
info = clip_g.load_state_dict(clip_g_sd)
|
||||||
|
logger.info(f"Loaded clip_g: {info}")
|
||||||
|
|
||||||
|
logger.info(f"Move clip_g to {device} and {sd3_dtype}...")
|
||||||
|
clip_g.to(device, dtype=sd3_dtype)
|
||||||
|
clip_g.eval()
|
||||||
|
logger.info(f"Set attn_mode to {args.attn_mode}...")
|
||||||
|
clip_g.set_attn_mode(args.attn_mode)
|
||||||
|
|
||||||
|
if use_t5xxl:
|
||||||
|
logger.info("Create t5xxl")
|
||||||
|
t5xxl = sd3_models.create_t5xxl(device, sd3_dtype, t5xxl_sd)
|
||||||
|
|
||||||
|
logger.info("Loading state dict...")
|
||||||
|
info = t5xxl.load_state_dict(t5xxl_sd)
|
||||||
|
logger.info(f"Loaded t5xxl: {info}")
|
||||||
|
|
||||||
|
logger.info(f"Move t5xxl to {device} and {sd3_dtype}...")
|
||||||
|
t5xxl.to(device, dtype=sd3_dtype)
|
||||||
|
# t5xxl.to("cpu", dtype=torch.float32) # run on CPU
|
||||||
|
t5xxl.eval()
|
||||||
|
logger.info(f"Set attn_mode to {args.attn_mode}...")
|
||||||
|
t5xxl.set_attn_mode(args.attn_mode)
|
||||||
|
else:
|
||||||
|
t5xxl = None
|
||||||
|
|
||||||
|
# prepare embeddings
|
||||||
|
logger.info("Encoding prompts...")
|
||||||
|
# embeds, pooled_embed
|
||||||
|
cond = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl)
|
||||||
|
neg_cond = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl)
|
||||||
|
|
||||||
|
# generate image
|
||||||
|
logger.info("Generating image...")
|
||||||
|
latent_sampled = do_sample(
|
||||||
|
target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, guidance_scale, sd3_dtype, device
|
||||||
|
)
|
||||||
|
|
||||||
|
# latent to image
|
||||||
|
with torch.no_grad():
|
||||||
|
image = vae.decode(latent_sampled)
|
||||||
|
image = image.float()
|
||||||
|
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
|
||||||
|
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
|
||||||
|
decoded_np = decoded_np.astype(np.uint8)
|
||||||
|
out_image = Image.fromarray(decoded_np)
|
||||||
|
|
||||||
|
# save image
|
||||||
|
output_dir = args.output_dir
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
|
||||||
|
out_image.save(output_path)
|
||||||
|
|
||||||
|
logger.info(f"Saved image to {output_path}")
|
||||||
Reference in New Issue
Block a user