mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
200 lines
7.4 KiB
Python
200 lines
7.4 KiB
Python
import argparse
|
|
import math
|
|
import os
|
|
import time
|
|
|
|
from safetensors.torch import load_file, save_file
|
|
import torch
|
|
from tqdm import tqdm
|
|
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
|
|
from PIL import Image
|
|
from accelerate import init_empty_weights
|
|
|
|
import library.stable_cascade as sc
|
|
import library.stable_cascade_utils as sc_utils
|
|
import library.device_utils as device_utils
|
|
from library import train_util
|
|
from library.sdxl_model_util import _load_state_dict_on_device
|
|
|
|
|
|
# Sampling settings for each stage, kept in one place so that the tqdm
# progress total always matches the number of timesteps actually requested.
_STAGE_C_SAMPLING = {"cfg": 4, "shift": 2, "timesteps": 20, "t_start": 1.0}
_STAGE_B_SAMPLING = {"cfg": 1.1, "shift": 1, "timesteps": 10, "t_start": 1.0}


def _build_gdf():
    """Create a fresh GDF object with the configuration used for both stages."""
    return sc.GDF(
        schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
        input_scaler=sc.VPScaler(),
        target=sc.EpsilonTarget(),
        noise_cond=sc.CosineTNoiseCond(),
        loss_weight=None,
    )


def _encode_prompt(prompt, tokenizer, text_model, device, dtype):
    """Tokenize ``prompt`` and return ``(hidden_states, pooled)`` moved to ``device``/``dtype``."""
    input_ids = tokenizer(
        [prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
    )["input_ids"].to(text_model.device)
    text, pooled = train_util.get_hidden_states_stable_cascade(
        tokenizer.model_max_length, input_ids, tokenizer, text_model
    )
    # The pooled embedding gets an extra axis at position 1 before being used as a condition.
    return text.to(device, dtype=dtype), pooled.unsqueeze(1).to(device, dtype=dtype)


def _run_sampler(gdf, model, conditions, latent_shape, unconditions, device, dtype, sampling):
    """Iterate ``gdf.sample`` to completion and return the last sample it yields.

    Raises:
        RuntimeError: if the sampler finishes without yielding anything.
    """
    latest = None
    # NOTE(review): torch.cuda.amp.autocast is the CUDA-specific autocast entry point;
    # confirm the intended behaviour if get_preferred_device() can return a non-CUDA device.
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
        steps = gdf.sample(model, conditions, latent_shape, unconditions, device=device, **sampling)
        # Only the final yielded sample is needed; intermediate ones are discarded.
        for latest, _, _ in tqdm(steps, total=sampling["timesteps"]):
            pass
    if latest is None:
        raise RuntimeError("sampler finished without yielding any sample")
    return latest


def _to_pil(images):
    """Convert a float image batch in the 0-1 range to a PIL image of its first item."""
    pixels = images.clamp(0, 1).mul(255).to(dtype=torch.uint8)
    # Move axis 1 to the end before handing the array to PIL
    # (presumably channels-first -> channels-last; confirm against stage_a.decode).
    pixels = pixels.permute(0, 2, 3, 1).cpu().numpy()
    return Image.fromarray(pixels[0])


def main(args):
    """Generate one image with the Stable Cascade pipeline and save it as a PNG.

    The flow is: load all models, encode a fixed caption (and an empty prompt
    for the unconditional branch), sample Stage C, feed its result to Stage B
    as the ``effnet`` condition, decode the Stage B result with Stage A, and
    write ``sampled_<timestamp>.png`` into ``args.outdir``.

    Args:
        args: parsed command-line namespace. Read here: ``lowvram``, ``bf16``,
            ``fp16``, ``outdir``, ``effnet_checkpoint_path``,
            ``stage_a_checkpoint_path``, ``stage_b_checkpoint_path``,
            ``stage_c_checkpoint_path``, ``text_model_checkpoint_path`` and
            ``save_text_model``; the whole namespace is also passed to
            ``sc_utils.load_tokenizer``.

    Raises:
        RuntimeError: if either sampler yields no sample.
    """
    device = device_utils.get_preferred_device()

    # In low-VRAM mode every model is loaded on the CPU and moved to `device`
    # only for the stage that needs it.
    loading_device = "cpu" if args.lowvram else device
    text_model_device = "cpu"

    if args.bf16:
        dtype = torch.bfloat16
    elif args.fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32

    text_model_dtype = torch.float32

    # EfficientNet encoder (loaded, but not referenced again in this function).
    effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
    effnet.eval().requires_grad_(False).to(loading_device)

    generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device)
    generator_c.eval().requires_grad_(False).to(loading_device)

    generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device)
    generator_b.eval().requires_grad_(False).to(loading_device)

    # CLIP text encoder and tokenizer
    tokenizer = sc_utils.load_tokenizer(args)
    text_model = sc_utils.load_clip_text_model(
        args.text_model_checkpoint_path, text_model_dtype, text_model_device, args.save_text_model
    )
    text_model = text_model.requires_grad_(False).to(text_model_dtype).to(text_model_device)

    # vqGAN
    stage_a = sc_utils.load_stage_a_model(args.stage_a_checkpoint_path, dtype=dtype, device=loading_device)
    stage_a.eval().requires_grad_(False)

    caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee"
    height, width = 1024, 1024
    stage_c_latent_shape, stage_b_latent_shape = sc_utils.calculate_latent_sizes(height, width, batch_size=1)

    # One GDF object per stage (the original author's comment calls GDF a "mystery class").
    gdf_c = _build_gdf()
    gdf_b = _build_gdf()

    # Prepare conditions: the caption for the conditional branch, "" for the unconditional one.
    cond_text, cond_pooled = _encode_prompt(caption, tokenizer, text_model, device, dtype)
    uncond_text, uncond_pooled = _encode_prompt("", tokenizer, text_model, device, dtype)

    zero_img_emb = torch.zeros(1, 768, device=device)

    # A dict is not the preferred container here, but changing everything
    # downstream of GDF is a hassle, so dicts are kept for now.
    conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
    unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": zero_img_emb}
    # Stage B gets its own shallow copies because it receives an extra "effnet" entry below.
    conditions_b = dict(conditions)
    unconditions_b = dict(unconditions)

    # Stage C
    if args.lowvram:
        generator_c = generator_c.to(device)

    sampled_c = _run_sampler(
        gdf_c, generator_c, conditions, stage_c_latent_shape, unconditions, device, dtype, _STAGE_C_SAMPLING
    )

    conditions_b["effnet"] = sampled_c
    unconditions_b["effnet"] = torch.zeros_like(sampled_c)

    # Stage B
    if args.lowvram:
        generator_c = generator_c.to(loading_device)
        device_utils.clean_memory_on_device(device)
        generator_b = generator_b.to(device)

    sampled_b = _run_sampler(
        gdf_b, generator_b, conditions_b, stage_b_latent_shape, unconditions_b, device, dtype, _STAGE_B_SAMPLING
    )

    # Stage A decode
    if args.lowvram:
        generator_b = generator_b.to(loading_device)
        device_utils.clean_memory_on_device(device)
        stage_a = stage_a.to(device)

    with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
        sampled = stage_a.decode(sampled_b).float()
        print(sampled.shape, sampled.min(), sampled.max())

    if args.lowvram:
        stage_a = stage_a.to(loading_device)
        device_utils.clean_memory_on_device(device)

    # float 0-1 to PIL Image
    image = _to_pil(sampled)

    timestamp_str = time.strftime("%Y%m%d_%H%M%S")
    os.makedirs(args.outdir, exist_ok=True)
    image.save(os.path.join(args.outdir, f"sampled_{timestamp_str}.png"))
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: build the parser, parse, and run the pipeline.
    parser = argparse.ArgumentParser()

    # Option groups contributed by the library helpers, registered in this order.
    for register_arguments in (
        sc_utils.add_effnet_arguments,
        train_util.add_tokenizer_arguments,
        sc_utils.add_stage_a_arguments,
        sc_utils.add_stage_b_arguments,
        sc_utils.add_stage_c_arguments,
        sc_utils.add_text_model_arguments,
    ):
        register_arguments(parser)

    # Options specific to this script: precision switches, output directory, low-VRAM mode.
    for precision_flag in ("--bf16", "--fp16"):
        parser.add_argument(precision_flag, action="store_true")
    parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")
    parser.add_argument("--lowvram", action="store_true", help="if specified, use low VRAM mode")

    args = parser.parse_args()
    main(args)
|