mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
360 lines
13 KiB
Python
360 lines
13 KiB
Python
import argparse
|
|
import importlib
|
|
import math
|
|
import os
|
|
import random
|
|
import time
|
|
import numpy as np
|
|
|
|
from safetensors.torch import load_file, save_file
|
|
import torch
|
|
from tqdm import tqdm
|
|
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
|
|
from PIL import Image
|
|
from accelerate import init_empty_weights
|
|
|
|
import library.stable_cascade as sc
|
|
import library.stable_cascade_utils as sc_utils
|
|
import library.device_utils as device_utils
|
|
from library import train_util
|
|
from library.sdxl_model_util import _load_state_dict_on_device
|
|
|
|
|
|
def _parse_prompt_options(line):
    """Split one interactive prompt line into caption text and generation options.

    Recognised flags, each followed by exactly one whitespace-delimited value:

    * ``--w`` / ``--h``: width / height (int)
    * ``--l``: cfg (float)
    * ``--s``: timesteps (int)
    * ``--f``: shift (float)
    * ``--t``: t_start (float)
    * ``--n``: negative prompt (a single token only)
    * ``--d``: seed (int)

    Example: ``"a cat on a sofa --w 1024 --h 1024 --l 4 --s 20 --f 2 --d 42"``.

    Every other token is kept, in order, as caption text. A flag that is the
    very last token has no value to consume and is kept as caption text too.

    Returns:
        ``(caption, options)`` where ``options`` is a dict with the keys
        ``width``, ``height``, ``cfg``, ``timesteps``, ``shift``, ``t_start``,
        ``negative_prompt`` and ``seed`` (defaults are used for omitted flags).

    Raises:
        ValueError: if a numeric option value cannot be converted.
    """
    # flag -> (options key, converter applied to the following token)
    option_specs = {
        "--w": ("width", int),
        "--h": ("height", int),
        "--l": ("cfg", float),
        "--s": ("timesteps", int),
        "--f": ("shift", float),
        "--t": ("t_start", float),
        "--n": ("negative_prompt", str),
        "--d": ("seed", int),
    }
    options = {
        "width": 1024,
        "height": 1024,
        "cfg": 4,
        "timesteps": 20,
        "shift": 2,
        "t_start": 1.0,
        "negative_prompt": "",
        "seed": None,
    }

    tokens = line.split()
    last_index = len(tokens) - 1
    caption_tokens = []
    i = 0
    while i < len(tokens):
        token = tokens[i]
        spec = option_specs.get(token)
        if spec is None or i == last_index:
            caption_tokens.append(token)
            i += 1
            continue

        key, convert = spec
        options[key] = convert(tokens[i + 1])
        i += 2  # skip the flag and its value

    return " ".join(caption_tokens), options


def _encode_prompt(prompt, tokenizer, text_model, device, dtype):
    """Tokenize *prompt* and return ``(text, pooled)`` hidden states moved to *device* as *dtype*.

    The pooled output gets an extra dimension inserted at index 1 before it is returned.
    """
    input_ids = tokenizer(
        [prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
    )["input_ids"].to(text_model.device)
    text, pooled = train_util.get_hidden_states_stable_cascade(
        tokenizer.model_max_length, input_ids, tokenizer, text_model
    )
    return text.to(device, dtype=dtype), pooled.unsqueeze(1).to(device, dtype=dtype)


def _save_image(image, outdir, prefix):
    """Save *image* as ``<outdir>/<prefix>_<YYYYmmdd_HHMMSS>.png``, creating *outdir* if needed."""
    timestamp_str = time.strftime("%Y%m%d_%H%M%S")
    os.makedirs(outdir, exist_ok=True)
    image.save(os.path.join(outdir, f"{prefix}_{timestamp_str}.png"))


def main(args):
    """Load the Stable Cascade models and generate images from captions typed on stdin.

    Loads the EfficientNet encoder, the Stage C and Stage B generators, the
    CLIP text model, the Stage A model and (optionally) a previewer, merges any
    requested additional networks into the text model / Stage C generator, and
    then loops: read a caption line, sample with Stage C, sample with Stage B
    conditioned on the Stage C result, decode with Stage A and save a PNG to
    ``args.outdir``. The loop ends when ``input()`` raises ``EOFError``.

    Args:
        args: parsed command-line namespace. Attributes read here: ``lowvram``,
            ``bf16``, ``fp16``, ``outdir``, the ``*_checkpoint_path`` options,
            ``save_text_model`` and the ``network_*`` lists. ``args`` is also
            passed as a whole to ``sc_utils.load_tokenizer``.

    Returns:
        None. Returns early, before any generation, if a network module's
        ``create_network_from_weights`` yields ``None`` for the network.

    Raises:
        ValueError: if a network module is given without a matching entry in
            ``args.network_weights``.
    """
    device = device_utils.get_preferred_device()

    # In low-VRAM mode models are loaded on the CPU and moved to `device` only while in use.
    loading_device = device if not args.lowvram else "cpu"
    text_model_device = "cpu"

    dtype = torch.float32
    if args.bf16:
        dtype = torch.bfloat16
    elif args.fp16:
        dtype = torch.float16

    text_model_dtype = torch.float32

    # EfficientNet encoder
    effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
    effnet.eval().requires_grad_(False).to(loading_device)

    generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device)
    generator_c.eval().requires_grad_(False).to(loading_device)

    generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device)
    generator_b.eval().requires_grad_(False).to(loading_device)

    # CLIP encoders
    tokenizer = sc_utils.load_tokenizer(args)

    text_model = sc_utils.load_clip_text_model(
        args.text_model_checkpoint_path, text_model_dtype, text_model_device, args.save_text_model
    )
    text_model = text_model.requires_grad_(False).to(text_model_dtype).to(text_model_device)

    # vqGAN
    stage_a = sc_utils.load_stage_a_model(args.stage_a_checkpoint_path, dtype=dtype, device=loading_device)
    stage_a.eval().requires_grad_(False)

    # previewer
    if args.previewer_checkpoint_path is not None:
        previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=dtype, device=loading_device)
        previewer.eval().requires_grad_(False)
    else:
        previewer = None

    # LoRA
    if args.network_module:
        for i, network_module in enumerate(args.network_module):
            print("import network module:", network_module)
            imported_module = importlib.import_module(network_module)

            network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]

            net_kwargs = {}
            if args.network_args and i < len(args.network_args):
                # TODO escape special chars
                for net_arg in args.network_args[i].split(";"):
                    # Split on the first "=" only so that values may themselves contain "=".
                    key, value = net_arg.split("=", 1)
                    net_kwargs[key] = value

            if args.network_weights is None or len(args.network_weights) <= i:
                raise ValueError("No weight. Weight is required.")

            network_weight = args.network_weights[i]
            print("load network weights from:", network_weight)

            network, weights_sd = imported_module.create_network_from_weights(
                network_mul, network_weight, effnet, text_model, generator_c, for_inference=True, **net_kwargs
            )
            if network is None:
                return

            mergeable = network.is_mergeable()
            assert mergeable, "not-mergeable network is not supported yet."

            network.merge_to(text_model, generator_c, weights_sd, dtype, device)

    # One GDF object per stage; both use the same schedule / scaler / target / noise-conditioning setup.
    gdf_c = sc.GDF(
        schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
        input_scaler=sc.VPScaler(),
        target=sc.EpsilonTarget(),
        noise_cond=sc.CosineTNoiseCond(),
        loss_weight=None,
    )
    gdf_b = sc.GDF(
        schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
        input_scaler=sc.VPScaler(),
        target=sc.EpsilonTarget(),
        noise_cond=sc.CosineTNoiseCond(),
        loss_weight=None,
    )

    # Stage B sampling parameters are fixed; Stage C parameters come from the prompt options.
    b_cfg = 1.1
    b_shift = 1
    b_timesteps = 10
    b_t_start = 1.0

    while True:
        print("type caption:")
        # input() raises EOFError at end of input (e.g. Ctrl+Z); that ends the session.
        try:
            caption = input()
        except EOFError:
            break

        caption = caption.strip()
        if caption == "":
            continue

        # A mistyped option value must not end the session after the models were loaded:
        # report it and ask for the caption again.
        try:
            caption, options = _parse_prompt_options(caption)
        except ValueError as e:
            print(f"invalid option value: {e}")
            continue

        width = options["width"]
        height = options["height"]
        cfg = options["cfg"]
        timesteps = options["timesteps"]
        shift = options["shift"]
        t_start = options["t_start"]
        negative_prompt = options["negative_prompt"]
        seed = options["seed"]

        stage_c_latent_shape, stage_b_latent_shape = sc_utils.calculate_latent_sizes(height, width, batch_size=1)

        # PREPARE CONDITIONS
        cond_text, cond_pooled = _encode_prompt(caption, tokenizer, text_model, device, dtype)
        uncond_text, uncond_pooled = _encode_prompt(negative_prompt, tokenizer, text_model, device, dtype)

        # No image conditioning is used: an all-zero tensor is passed as "clip_img".
        zero_img_emb = torch.zeros(1, 768, device=device)

        # A dict is not the preferred interface, but changing everything downstream of GDF
        # would be tedious, so conditions are kept as dicts for now.
        conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
        unconditions = {
            "clip_text_pooled": uncond_pooled,
            "clip": uncond_pooled,
            "clip_text": uncond_text,
            "clip_img": zero_img_emb,
        }
        # Stage B gets its own copies because an extra "effnet" entry is added to them below.
        conditions_b = dict(conditions)
        unconditions_b = dict(unconditions)

        # seed everything
        if seed is not None:
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            random.seed(seed)
            np.random.seed(seed)

        if args.lowvram:
            generator_c = generator_c.to(device)

        with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
            sampling_c = gdf_c.sample(
                generator_c,
                conditions,
                stage_c_latent_shape,
                unconditions,
                device=device,
                cfg=cfg,
                shift=shift,
                timesteps=timesteps,
                t_start=t_start,
            )
            # Drain the sampler; the loop variable keeps the value yielded by the final step.
            for sampled_c, _, _ in tqdm(sampling_c, total=timesteps):
                pass

        conditions_b["effnet"] = sampled_c
        unconditions_b["effnet"] = torch.zeros_like(sampled_c)

        if previewer is not None:
            with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
                preview = previewer(sampled_c)
                preview = preview.clamp(0, 1)
                preview = preview.permute(0, 2, 3, 1).squeeze(0)
                preview = preview.detach().float().cpu().numpy()
                preview = Image.fromarray((preview * 255).astype(np.uint8))

            _save_image(preview, args.outdir, "preview")

        if args.lowvram:
            # Swap Stage C out and Stage B in.
            generator_c = generator_c.to(loading_device)
            device_utils.clean_memory_on_device(device)
            generator_b = generator_b.to(device)

        with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
            sampling_b = gdf_b.sample(
                generator_b,
                conditions_b,
                stage_b_latent_shape,
                unconditions_b,
                device=device,
                cfg=b_cfg,
                shift=b_shift,
                timesteps=b_timesteps,
                t_start=b_t_start,
            )
            # The progress bar length is the step count (b_timesteps), not the start time (b_t_start).
            for sampled_b, _, _ in tqdm(sampling_b, total=b_timesteps):
                pass

        if args.lowvram:
            # Swap Stage B out and Stage A in.
            generator_b = generator_b.to(loading_device)
            device_utils.clean_memory_on_device(device)
            stage_a = stage_a.to(device)

        with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
            sampled = stage_a.decode(sampled_b).float()

        if args.lowvram:
            stage_a = stage_a.to(loading_device)
            device_utils.clean_memory_on_device(device)

        # float 0-1 to PIL Image
        sampled = sampled.clamp(0, 1)
        sampled = sampled.mul(255).to(dtype=torch.uint8)
        sampled = sampled.permute(0, 2, 3, 1)
        sampled = sampled.cpu().numpy()
        sampled = Image.fromarray(sampled[0])

        _save_image(sampled, args.outdir, "sampled")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the command-line parser, parse the arguments and run main().
    parser = argparse.ArgumentParser()

    # Argument groups supplied by the shared library helpers, registered in this order.
    for register_arguments in (
        sc_utils.add_effnet_arguments,
        train_util.add_tokenizer_arguments,
        sc_utils.add_stage_a_arguments,
        sc_utils.add_stage_b_arguments,
        sc_utils.add_stage_c_arguments,
        sc_utils.add_previewer_arguments,
        sc_utils.add_text_model_arguments,
    ):
        register_arguments(parser)

    # Options specific to this script.
    for precision_flag in ("--bf16", "--fp16"):
        parser.add_argument(precision_flag, action="store_true")
    parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")
    parser.add_argument("--lowvram", action="store_true", help="if specified, use low VRAM mode")

    # Additional-network options: each accepts zero or more values and defaults to None.
    for option_name, value_type, help_text in (
        ("--network_module", str, "additional network module to use / 追加ネットワークを使う時そのモジュール名"),
        ("--network_weights", str, "additional network weights to load / 追加ネットワークの重み"),
        ("--network_mul", float, "additional network multiplier / 追加ネットワークの効果の倍率"),
        ("--network_args", str, "additional arguments for network (key=value) / ネットワークへの追加の引数"),
    ):
        parser.add_argument(option_name, type=value_type, default=None, nargs="*", help=help_text)

    args = parser.parse_args()

    main(args)
|