Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 22:35:09 +00:00)
refactor SD3 CLIP to transformers etc.
@@ -12,6 +12,7 @@ import torch
from safetensors.torch import safe_open, load_file
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModelWithProjection, T5EncoderModel

from library.device_utils import init_ipex, get_preferred_device

@@ -25,11 +26,14 @@ import logging
logger = logging.getLogger(__name__)

from library import sd3_models, sd3_utils, strategy_sd3
from library.utils import load_safetensors


def get_noise(seed, latent):
    generator = torch.manual_seed(seed)
    return torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu").to(latent.dtype)
def get_noise(seed, latent, device="cpu"):
    # generator = torch.manual_seed(seed)
    generator = torch.Generator(device)
    generator.manual_seed(seed)
    return torch.randn(latent.size(), dtype=latent.dtype, layout=latent.layout, generator=generator, device=device)


def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
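A quick illustration of the reworked noise helper: binding a torch.Generator to the target device makes the noise reproducible per seed without touching global RNG state. This is a minimal sketch, not part of the commit; note that CPU and CUDA generators produce different sequences, so generating directly on the accelerator can change results versus the old CPU-only path.

import torch

def get_noise(seed, latent, device="cpu"):
    # seeded, device-local generator as in the new helper above
    generator = torch.Generator(device)
    generator.manual_seed(seed)
    return torch.randn(latent.size(), dtype=latent.dtype, layout=latent.layout, generator=generator, device=device)

latent = torch.zeros(1, 16, 128, 128, dtype=torch.float32)  # illustrative latent shape
assert torch.equal(get_noise(42, latent), get_noise(42, latent))  # same seed, same device -> identical noise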
@@ -59,7 +63,7 @@ def do_sample(
    neg_cond: Tuple[torch.Tensor, torch.Tensor],
    mmdit: sd3_models.MMDiT,
    steps: int,
    guidance_scale: float,
    cfg_scale: float,
    dtype: torch.dtype,
    device: str,
):
@@ -71,7 +75,7 @@ def do_sample(

    latent = latent.to(dtype).to(device)

    noise = get_noise(seed, latent).to(device)
    noise = get_noise(seed, latent, device)

    model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3

@@ -105,7 +109,7 @@ def do_sample(
        batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)

        pos_out, neg_out = batched.chunk(2)
        denoised = neg_out + (pos_out - neg_out) * guidance_scale
        denoised = neg_out + (pos_out - neg_out) * cfg_scale
        # print(denoised.shape)

        # d = to_d(x, sigma_hat, denoised)
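The guidance blend itself is unchanged here, only the parameter is renamed from guidance_scale to cfg_scale. A toy classifier-free-guidance example with made-up values, for reference:

import torch

cfg_scale = 5.0
pos_out = torch.tensor([1.0, 2.0])  # denoised prediction for the prompt
neg_out = torch.tensor([0.0, 1.0])  # denoised prediction for the negative prompt
denoised = neg_out + (pos_out - neg_out) * cfg_scale
print(denoised)  # tensor([5., 6.]); cfg_scale=1.0 would reproduce pos_out exactly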
@@ -122,230 +126,68 @@ def do_sample(
    x = x.to(dtype)

    latent = x
    scale_factor = 1.5305
    shift_factor = 0.0609
    # def process_out(self, latent):
    #     return (latent / self.scale_factor) + self.shift_factor
    latent = (latent / scale_factor) + shift_factor
    latent = vae.process_out(latent)
    return latent


if __name__ == "__main__":
    target_height = 1024
    target_width = 1024

    # steps = 50 # 28 # 50
    guidance_scale = 5
    # seed = 1 # None # 1

    device = get_preferred_device()

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--clip_g", type=str, required=False)
    parser.add_argument("--clip_l", type=str, required=False)
    parser.add_argument("--t5xxl", type=str, required=False)
    parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77")
    parser.add_argument("--apply_lg_attn_mask", action="store_true")
    parser.add_argument("--apply_t5_attn_mask", action="store_true")
    parser.add_argument("--prompt", type=str, default="A photo of a cat")
    # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
    parser.add_argument("--negative_prompt", type=str, default="")
    parser.add_argument("--output_dir", type=str, default=".")
    parser.add_argument("--do_not_use_t5xxl", action="store_true")
    parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--steps", type=int, default=50)
    # parser.add_argument(
    #     "--lora_weights",
    #     type=str,
    #     nargs="*",
    #     default=[],
    #     help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
    # )
    # parser.add_argument("--interactive", action="store_true")
    args = parser.parse_args()

    seed = args.seed
    steps = args.steps

    sd3_dtype = torch.float32
    if args.fp16:
        sd3_dtype = torch.float16
    elif args.bf16:
        sd3_dtype = torch.bfloat16

    # TODO test with separated safetensors files for each model

    # load state dict
    logger.info(f"Loading SD3 models from {args.ckpt_path}...")
    state_dict = load_file(args.ckpt_path)

    if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
        # found clip_g: remove prefix "text_encoders.clip_g."
        logger.info("clip_g is included in the checkpoint")
        clip_g_sd = {}
        prefix = "text_encoders.clip_g."
        for k, v in list(state_dict.items()):
            if k.startswith(prefix):
                clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
    else:
        logger.info(f"Loading clip_g from {args.clip_g}...")
        clip_g_sd = load_file(args.clip_g)
        for key in list(clip_g_sd.keys()):
            clip_g_sd["transformer." + key] = clip_g_sd.pop(key)

    if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
        # found clip_l: remove prefix "text_encoders.clip_l."
        logger.info("clip_l is included in the checkpoint")
        clip_l_sd = {}
        prefix = "text_encoders.clip_l."
        for k, v in list(state_dict.items()):
            if k.startswith(prefix):
                clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
    else:
        logger.info(f"Loading clip_l from {args.clip_l}...")
        clip_l_sd = load_file(args.clip_l)
        for key in list(clip_l_sd.keys()):
            clip_l_sd["transformer." + key] = clip_l_sd.pop(key)

    if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
        # found t5xxl: remove prefix "text_encoders.t5xxl."
        logger.info("t5xxl is included in the checkpoint")
        if not args.do_not_use_t5xxl:
            t5xxl_sd = {}
            prefix = "text_encoders.t5xxl."
            for k, v in list(state_dict.items()):
                if k.startswith(prefix):
                    t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
        else:
            logger.info("but not used")
            for key in list(state_dict.keys()):
                if key.startswith("text_encoders.t5xxl."):
                    state_dict.pop(key)
            t5xxl_sd = None
    elif args.t5xxl:
        assert not args.do_not_use_t5xxl, "t5xxl is not used but specified"
        logger.info(f"Loading t5xxl from {args.t5xxl}...")
        t5xxl_sd = load_file(args.t5xxl)
        for key in list(t5xxl_sd.keys()):
            t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key)
    else:
        logger.info("t5xxl is not used")
        t5xxl_sd = None

    use_t5xxl = t5xxl_sd is not None

    # MMDiT and VAE
    vae_sd = {}
    vae_prefix = "first_stage_model."
    mmdit_prefix = "model.diffusion_model."
    for k, v in list(state_dict.items()):
        if k.startswith(vae_prefix):
            vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
        elif k.startswith(mmdit_prefix):
            state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k)

    # load tokenizers
    logger.info("Loading tokenizers...")
    tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)

    # load models
    # logger.info("Create MMDiT from SD3 checkpoint...")
    # mmdit = sd3_utils.create_mmdit_from_sd3_checkpoint(state_dict)
    logger.info("Create MMDiT")
    mmdit = sd3_models.create_mmdit_sd3_medium_configs(args.attn_mode)

    logger.info("Loading state dict...")
    info = mmdit.load_state_dict(state_dict)
    logger.info(f"Loaded MMDiT: {info}")

    logger.info(f"Move MMDiT to {device} and {sd3_dtype}...")
    mmdit.to(device, dtype=sd3_dtype)
    mmdit.eval()

    # load VAE
    logger.info("Create VAE")
    vae = sd3_models.SDVAE()
    logger.info("Loading state dict...")
    info = vae.load_state_dict(vae_sd)
    logger.info(f"Loaded VAE: {info}")

    logger.info(f"Move VAE to {device} and {sd3_dtype}...")
    vae.to(device, dtype=sd3_dtype)
    vae.eval()

    # load text encoders
    logger.info("Create clip_l")
    clip_l = sd3_models.create_clip_l(device, sd3_dtype, clip_l_sd)

    logger.info("Loading state dict...")
    info = clip_l.load_state_dict(clip_l_sd)
    logger.info(f"Loaded clip_l: {info}")

    logger.info(f"Move clip_l to {device} and {sd3_dtype}...")
    clip_l.to(device, dtype=sd3_dtype)
    clip_l.eval()
    logger.info(f"Set attn_mode to {args.attn_mode}...")
    clip_l.set_attn_mode(args.attn_mode)

    logger.info("Create clip_g")
    clip_g = sd3_models.create_clip_g(device, sd3_dtype, clip_g_sd)

    logger.info("Loading state dict...")
    info = clip_g.load_state_dict(clip_g_sd)
    logger.info(f"Loaded clip_g: {info}")

    logger.info(f"Move clip_g to {device} and {sd3_dtype}...")
    clip_g.to(device, dtype=sd3_dtype)
    clip_g.eval()
    logger.info(f"Set attn_mode to {args.attn_mode}...")
    clip_g.set_attn_mode(args.attn_mode)

    if use_t5xxl:
        logger.info("Create t5xxl")
        t5xxl = sd3_models.create_t5xxl(device, sd3_dtype, t5xxl_sd)

        logger.info("Loading state dict...")
        info = t5xxl.load_state_dict(t5xxl_sd)
        logger.info(f"Loaded t5xxl: {info}")

        logger.info(f"Move t5xxl to {device} and {sd3_dtype}...")
        t5xxl.to(device, dtype=sd3_dtype)
        # t5xxl.to("cpu", dtype=torch.float32) # run on CPU
        t5xxl.eval()
        logger.info(f"Set attn_mode to {args.attn_mode}...")
        t5xxl.set_attn_mode(args.attn_mode)
    else:
        t5xxl = None

def generate_image(
    mmdit: sd3_models.MMDiT,
    vae: sd3_models.SDVAE,
    clip_l: CLIPTextModelWithProjection,
    clip_g: CLIPTextModelWithProjection,
    t5xxl: T5EncoderModel,
    steps: int,
    prompt: str,
    seed: int,
    target_width: int,
    target_height: int,
    device: str,
    negative_prompt: str,
    cfg_scale: float,
):
    # prepare embeddings
    logger.info("Encoding prompts...")
    encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()

    tokens_and_masks = tokenize_strategy.tokenize(args.prompt)
    lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
        tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
    )
    cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
    # TODO support one-by-one offloading
    clip_l.to(device)
    clip_g.to(device)
    t5xxl.to(device)

    tokens_and_masks = tokenize_strategy.tokenize(args.negative_prompt)
    lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
        tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
    )
    neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
    with torch.no_grad():
        tokens_and_masks = tokenize_strategy.tokenize(prompt)
        lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens(
            tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
        )
        cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)

        tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
        lg_out, t5_out, pooled, neg_l_attn_mask, neg_g_attn_mask, neg_t5_attn_mask = encoding_strategy.encode_tokens(
            tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
        )
        neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)

    # attn masks are not used currently

    if args.offload:
        clip_l.to("cpu")
        clip_g.to("cpu")
        t5xxl.to("cpu")

    # generate image
    logger.info("Generating image...")
    latent_sampled = do_sample(
        target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, guidance_scale, sd3_dtype, device
    )
    mmdit.to(device)
    latent_sampled = do_sample(target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, cfg_scale, sd3_dtype, device)
    if args.offload:
        mmdit.to("cpu")

    # latent to image
    vae.to(device)
    with torch.no_grad():
        image = vae.decode(latent_sampled)

    if args.offload:
        vae.to("cpu")

    image = image.float()
    image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
    decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
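To go from the decoded VAE output to a saveable image, the values in [-1, 1] are rescaled to [0, 255] and the channel axis is moved last, as in the lines above. The final PIL conversion falls outside this hunk, so the Image.fromarray call in this sketch is an assumption rather than the script's exact code.

import numpy as np
import torch
from PIL import Image

image = torch.rand(1, 3, 64, 64) * 2 - 1  # stand-in for vae.decode(latent_sampled), values in [-1, 1]
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]  # (3, H, W) in [0, 1]
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)    # (H, W, 3) in [0, 255]
out_image = Image.fromarray(decoded_np.astype(np.uint8))       # assumed final conversion step
out_image.save("sample.png")                                   # illustrative output path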
@@ -359,3 +201,179 @@ if __name__ == "__main__":
    out_image.save(output_path)

    logger.info(f"Saved image to {output_path}")


if __name__ == "__main__":
    target_height = 1024
    target_width = 1024

    # steps = 50 # 28 # 50
    # cfg_scale = 5
    # seed = 1 # None # 1

    device = get_preferred_device()

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--clip_g", type=str, required=False)
    parser.add_argument("--clip_l", type=str, required=False)
    parser.add_argument("--t5xxl", type=str, required=False)
    parser.add_argument("--t5xxl_token_length", type=int, default=256, help="t5xxl token length, default: 256")
    parser.add_argument("--apply_lg_attn_mask", action="store_true")
    parser.add_argument("--apply_t5_attn_mask", action="store_true")
    parser.add_argument("--prompt", type=str, default="A photo of a cat")
    # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
    parser.add_argument("--negative_prompt", type=str, default="")
    parser.add_argument("--cfg_scale", type=float, default=5.0)
    parser.add_argument("--offload", action="store_true", help="Offload to CPU")
    parser.add_argument("--output_dir", type=str, default=".")
    # parser.add_argument("--do_not_use_t5xxl", action="store_true")
    # parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--steps", type=int, default=50)
    # parser.add_argument(
    #     "--lora_weights",
    #     type=str,
    #     nargs="*",
    #     default=[],
    #     help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
    # )
    parser.add_argument("--width", type=int, default=target_width)
    parser.add_argument("--height", type=int, default=target_height)
    parser.add_argument("--interactive", action="store_true")
    args = parser.parse_args()

    seed = args.seed
    steps = args.steps

    sd3_dtype = torch.float32
    if args.fp16:
        sd3_dtype = torch.float16
    elif args.bf16:
        sd3_dtype = torch.bfloat16

    loading_device = "cpu" if args.offload else device

    # load state dict
    logger.info(f"Loading SD3 models from {args.ckpt_path}...")
    # state_dict = load_file(args.ckpt_path)
    state_dict = load_safetensors(args.ckpt_path, loading_device, disable_mmap=True, dtype=sd3_dtype)

    # load text encoders
    clip_l = sd3_utils.load_clip_l(args.clip_l, sd3_dtype, loading_device, state_dict=state_dict)
    clip_g = sd3_utils.load_clip_g(args.clip_g, sd3_dtype, loading_device, state_dict=state_dict)
    t5xxl = sd3_utils.load_t5xxl(args.t5xxl, sd3_dtype, loading_device, state_dict=state_dict)

    # MMDiT and VAE
    vae = sd3_utils.load_vae(None, sd3_dtype, loading_device, state_dict=state_dict)
    mmdit = sd3_utils.load_mmdit(state_dict, sd3_dtype, loading_device)

    clip_l.to(sd3_dtype)
    clip_g.to(sd3_dtype)
    t5xxl.to(sd3_dtype)
    vae.to(sd3_dtype)
    mmdit.to(sd3_dtype)
    if not args.offload:
        # make sure to move to the device: some tensors are created in the constructor on the CPU
        clip_l.to(device)
        clip_g.to(device)
        t5xxl.to(device)
        vae.to(device)
        mmdit.to(device)

    clip_l.eval()
    clip_g.eval()
    t5xxl.eval()
    mmdit.eval()
    vae.eval()

    # load tokenizers
    logger.info("Loading tokenizers...")
    tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
    encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()

    if not args.interactive:
        generate_image(
            mmdit,
            vae,
            clip_l,
            clip_g,
            t5xxl,
            args.steps,
            args.prompt,
            args.seed,
            args.width,
            args.height,
            device,
            args.negative_prompt,
            args.cfg_scale,
        )
    else:
        # loop for interactive
        width = args.width
        height = args.height
        steps = None
        cfg_scale = args.cfg_scale

        while True:
            print(
                "Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed>"
                " --n <negative prompt>, `--n -` for empty negative prompt"
                "Options are kept for the next prompt. Current options:"
                f" width={width}, height={height}, steps={steps}, seed={seed}, cfg_scale={cfg_scale}"
            )
            prompt = input()
            if prompt == "":
                break

            # parse options
            options = prompt.split("--")
            prompt = options[0].strip()
            seed = None
            negative_prompt = None
            for opt in options[1:]:
                try:
                    opt = opt.strip()
                    if opt.startswith("w"):
                        width = int(opt[1:].strip())
                    elif opt.startswith("h"):
                        height = int(opt[1:].strip())
                    elif opt.startswith("s"):
                        steps = int(opt[1:].strip())
                    elif opt.startswith("d"):
                        seed = int(opt[1:].strip())
                    # elif opt.startswith("m"):
                    #     multipliers = opt[1:].strip().split(",")
                    #     if len(multipliers) != len(lora_models):
                    #         logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
                    #         continue
                    #     for i, lora_model in enumerate(lora_models):
                    #         lora_model.set_multiplier(float(multipliers[i]))
                    elif opt.startswith("n"):
                        negative_prompt = opt[1:].strip()
                        if negative_prompt == "-":
                            negative_prompt = ""
                    elif opt.startswith("c"):
                        cfg_scale = float(opt[1:].strip())
                except ValueError as e:
                    logger.error(f"Invalid option: {opt}, {e}")

            generate_image(
                mmdit,
                vae,
                clip_l,
                clip_g,
                t5xxl,
                steps if steps is not None else args.steps,
                prompt,
                seed if seed is not None else args.seed,
                width,
                height,
                device,
                negative_prompt if negative_prompt is not None else args.negative_prompt,
                cfg_scale,
            )

    logger.info("Done!")

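For reference, interactive mode parses per-image overrides by splitting the input on "--". A simplified sketch of how a line like the one below maps to the options handled above (the script keeps the full if/elif chain; the input string here is illustrative):

line = "a photo of a cat --w 768 --h 768 --s 30 --d 42 --c 4.5 --n -"
options = line.split("--")
prompt = options[0].strip()  # "a photo of a cat"
overrides = {opt.strip()[0]: opt.strip()[1:].strip() for opt in options[1:] if opt.strip()}
print(prompt, overrides)
# a photo of a cat {'w': '768', 'h': '768', 's': '30', 'd': '42', 'c': '4.5', 'n': '-'}
# "--n -" clears the negative prompt for that image, matching the loop above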