mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Merge 115b1d0aa7 into cadcd3169b
This commit is contained in:
3365
accel_sdxl_gen_img.py
Normal file
3365
accel_sdxl_gen_img.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -404,14 +404,14 @@ def train(args):
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
example_tuple = (latents, batch["captions"])
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_util.sample_images(
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
@@ -474,7 +474,7 @@ def train(args):
|
||||
vae,
|
||||
)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
|
||||
73
gen_img.py
73
gen_img.py
@@ -1485,6 +1485,7 @@ class ListPrompter:
|
||||
|
||||
|
||||
def main(args):
|
||||
|
||||
if args.fp16:
|
||||
dtype = torch.float16
|
||||
elif args.bf16:
|
||||
@@ -1492,6 +1493,8 @@ def main(args):
|
||||
else:
|
||||
dtype = torch.float32
|
||||
|
||||
device = get_preferred_device()
|
||||
|
||||
highres_fix = args.highres_fix_scale is not None
|
||||
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
|
||||
|
||||
@@ -1521,9 +1524,10 @@ def main(args):
|
||||
if is_sdxl:
|
||||
if args.clip_skip is None:
|
||||
args.clip_skip = 2
|
||||
|
||||
|
||||
model_dtype = sdxl_train_util.match_mixed_precision(args, dtype)
|
||||
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
|
||||
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
|
||||
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device, model_dtype
|
||||
)
|
||||
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
|
||||
text_encoders = [text_encoder1, text_encoder2]
|
||||
@@ -1748,7 +1752,7 @@ def main(args):
|
||||
logger.info(f"network_merge: {network_merge}")
|
||||
|
||||
for i, network_module in enumerate(args.network_module):
|
||||
logger.info("import network module: {network_module}")
|
||||
logger.info(f"import network module: {network_module}")
|
||||
imported_module = importlib.import_module(network_module)
|
||||
|
||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||
@@ -2508,7 +2512,7 @@ def main(args):
|
||||
metadata.add_text("crop-left", str(crop_left))
|
||||
|
||||
if filename is not None:
|
||||
fln = filename
|
||||
fln = first_available_filename(args.outdir, filename) #Checks to make sure is not existing file, else returns first available sequential filename
|
||||
else:
|
||||
if args.use_original_file_name and init_images is not None:
|
||||
if type(init_images) is list:
|
||||
@@ -2586,7 +2590,8 @@ def main(args):
|
||||
negative_scale = args.negative_scale
|
||||
steps = args.steps
|
||||
seed = None
|
||||
seeds = None
|
||||
if pi == 0:
|
||||
seeds = None
|
||||
strength = 0.8 if args.strength is None else args.strength
|
||||
negative_prompt = ""
|
||||
clip_prompt = None
|
||||
@@ -2670,7 +2675,11 @@ def main(args):
|
||||
|
||||
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
||||
if m: # seed
|
||||
seeds = [int(d) for d in m.group(1).split(",")]
|
||||
if pi > 0 and len(raw_prompts) > 1: #Bypass od 2nd loop for dynamic prompts
|
||||
continue
|
||||
logger.info(f"{m}")
|
||||
seeds = m.group(1).split(",")
|
||||
seeds = [int(d.strip()) for d in seeds]
|
||||
logger.info(f"seeds: {seeds}")
|
||||
continue
|
||||
|
||||
@@ -2795,14 +2804,19 @@ def main(args):
|
||||
m = re.match(r"f (.+)", parg, re.IGNORECASE)
|
||||
if m: # filename
|
||||
filename = m.group(1)
|
||||
logger.info(f"filename: {filename}")
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
logger.error(f"Exception in parsing / 解析エラー: {parg}")
|
||||
logger.error(f"{ex}")
|
||||
|
||||
# override Deep Shrink
|
||||
# override filename to add index number if more than one image per prompt
|
||||
if filename is not None and (args.images_per_prompt > 1 or len(raw_prompts) > 1):
|
||||
fileext = os.path.splitext(filename)
|
||||
filename = fileext[0] + "_%s" % pi + fileext[1]
|
||||
logger.info(f"filename: {filename}")
|
||||
|
||||
# override Deep Shrink
|
||||
if ds_depth_1 is not None:
|
||||
if ds_depth_1 < 0:
|
||||
ds_depth_1 = args.ds_depth_1 or 3
|
||||
@@ -2835,8 +2849,16 @@ def main(args):
|
||||
# prepare seed
|
||||
if seeds is not None: # given in prompt
|
||||
# num_images_per_promptが多い場合は足りなくなるので、足りない分は前のを使う
|
||||
if len(seeds) > 0:
|
||||
# Previous implementation may result in unexpected behaviour when number of seeds is lesss than number of repeats. Last seed is taken for rest of repeated prompts. Add condition if last element is -1, to start randomizing seed.
|
||||
if len(seeds) > 1:
|
||||
seed = seeds.pop(0)
|
||||
elif len(seeds) == 1:
|
||||
if seeds[0] == -1:
|
||||
seeds = None
|
||||
else:
|
||||
seed = seeds.pop(0)
|
||||
|
||||
|
||||
else:
|
||||
if args.iter_same_seed:
|
||||
seed = iter_seed
|
||||
@@ -2847,6 +2869,7 @@ def main(args):
|
||||
seed = seed_random.randint(0, 2**32 - 1)
|
||||
if args.interactive:
|
||||
logger.info(f"seed: {seed}")
|
||||
# logger.info(f"seed: {seed}") #debugging logger. Uncomment to verify if expected seed is added correctly.
|
||||
|
||||
# prepare init image, guide image and mask
|
||||
init_image = mask_image = guide_image = None
|
||||
@@ -2935,7 +2958,35 @@ def main(args):
|
||||
|
||||
logger.info("done!")
|
||||
|
||||
def first_available_filename(path, filename):
|
||||
"""
|
||||
Checks if filename is in use.
|
||||
if filename is in use, appends a running number
|
||||
e.g. filename = 'file.png':
|
||||
|
||||
file.png
|
||||
file_1.png
|
||||
file_2.png
|
||||
|
||||
Runs in log(n) time where n is the number of existing files in sequence
|
||||
"""
|
||||
i = 1
|
||||
if not os.path.exists(os.path.join(path, filename)):
|
||||
return filename
|
||||
fileext = os.path.splitext(filename)
|
||||
filename = fileext[0] + "_%s" + fileext[1]
|
||||
# First do an exponential search
|
||||
while os.path.exists(os.path.join(path,filename % i)):
|
||||
i = i * 2
|
||||
|
||||
# Result lies somewhere in the interval (i/2..i]
|
||||
# We call this interval (a..b] and narrow it down until a + 1 = b
|
||||
a, b = (i // 2, i)
|
||||
while a + 1 < b:
|
||||
c = (a + b) // 2 # interval midpoint
|
||||
a, b = (c, b) if os.path.exists(os.path.join(path,filename % c)) else (a, c)
|
||||
|
||||
return filename % b
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -3343,6 +3394,10 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /"
|
||||
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨",
|
||||
)
|
||||
parser.add_argument("--full_fp16", action="store_true", help="Loading model in fp16")
|
||||
parser.add_argument(
|
||||
"--full_bf16", action="store_true", help="Loading model in bf16"
|
||||
)
|
||||
|
||||
# # parser.add_argument(
|
||||
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||
|
||||
@@ -91,10 +91,10 @@ def _load_target_model(
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None
|
||||
)
|
||||
except EnvironmentError as ex:
|
||||
except ValueError as ex:
|
||||
if variant is not None:
|
||||
logger.info("try to load fp32 model")
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
|
||||
logger.info("try to load default model")
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, torch_dtype=model_dtype, variant=None, tokenizer=None)
|
||||
else:
|
||||
raise ex
|
||||
except EnvironmentError as ex:
|
||||
|
||||
@@ -28,6 +28,7 @@ import random
|
||||
import hashlib
|
||||
import subprocess
|
||||
from io import BytesIO
|
||||
from accelerate.utils import gather_object
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
@@ -1724,7 +1725,10 @@ class DreamBoothDataset(BaseDataset):
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
if subset.is_reg:
|
||||
reg_infos.append((info, subset))
|
||||
if subset.num_repeats > 1:
|
||||
info.num_repeats = 1
|
||||
for i in range(subset.num_repeats):
|
||||
reg_infos.append((info, subset))
|
||||
else:
|
||||
self.register_image(info, subset)
|
||||
|
||||
@@ -1735,6 +1739,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.num_train_images = num_train_images
|
||||
|
||||
logger.info(f"{num_reg_images} reg images.")
|
||||
random.shuffle(reg_infos)
|
||||
if num_train_images < num_reg_images:
|
||||
logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
||||
|
||||
@@ -1748,13 +1753,16 @@ class DreamBoothDataset(BaseDataset):
|
||||
for info, subset in reg_infos:
|
||||
if first_loop:
|
||||
self.register_image(info, subset)
|
||||
logger.info(f"Registering image: {info.absolute_path}")
|
||||
n += info.num_repeats
|
||||
else:
|
||||
info.num_repeats += 1 # rewrite registered info
|
||||
logger.info(f"Registering image: {info.absolute_path}")
|
||||
n += 1
|
||||
if n >= num_train_images:
|
||||
break
|
||||
first_loop = False
|
||||
random.shuffle(reg_infos)
|
||||
|
||||
self.num_reg_images = num_reg_images
|
||||
|
||||
@@ -4720,7 +4728,7 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
|
||||
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||
logger.info(f"loading model for process {accelerator.state.local_process_index+1}/{accelerator.state.num_processes}")
|
||||
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
|
||||
args,
|
||||
@@ -5426,7 +5434,7 @@ def line_to_prompt_dict(line: str) -> dict:
|
||||
|
||||
return prompt_dict
|
||||
|
||||
|
||||
RE_CAPTION_PROMPT = re.compile(r"(?i)__caption((\|)(.+?)?)?__")
|
||||
def sample_images_common(
|
||||
pipe_class,
|
||||
accelerator: Accelerator,
|
||||
@@ -5438,6 +5446,7 @@ def sample_images_common(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
unet,
|
||||
example_tuple=None,
|
||||
prompt_replacement=None,
|
||||
controlnet=None,
|
||||
):
|
||||
@@ -5468,7 +5477,7 @@ def sample_images_common(
|
||||
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
|
||||
|
||||
org_vae_device = vae.device # CPUにいるはず
|
||||
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
|
||||
vae.to(device) # distributed_state.device is same as accelerator.device
|
||||
|
||||
# unwrap unet and text_encoder(s)
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
@@ -5478,18 +5487,23 @@ def sample_images_common(
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
# read prompts
|
||||
if args.sample_prompts.endswith(".txt"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
||||
elif args.sample_prompts.endswith(".toml"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
data = toml.load(f)
|
||||
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
|
||||
elif args.sample_prompts.endswith(".json"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
prompts = json.load(f)
|
||||
|
||||
|
||||
if distributed_state.is_main_process:
|
||||
# Load prompts into prompts list on main process only
|
||||
if args.sample_prompts.endswith(".txt"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
||||
elif args.sample_prompts.endswith(".toml"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
data = toml.load(f)
|
||||
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
|
||||
elif args.sample_prompts.endswith(".json"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
prompts = json.load(f)
|
||||
else:
|
||||
prompts = [] # Init empty prompts list for sub processes.
|
||||
|
||||
# schedulers: dict = {} cannot find where this is used
|
||||
default_scheduler = get_my_scheduler(
|
||||
sample_sampler=args.sample_sampler,
|
||||
@@ -5507,21 +5521,65 @@ def sample_images_common(
|
||||
requires_safety_checker=False,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
pipeline.to(distributed_state.device)
|
||||
save_dir = args.output_dir + "/sample"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
pipeline.to(device)
|
||||
|
||||
|
||||
# preprocess prompts
|
||||
for i in range(len(prompts)):
|
||||
prompt_dict = prompts[i]
|
||||
if isinstance(prompt_dict, str):
|
||||
prompt_dict = line_to_prompt_dict(prompt_dict)
|
||||
if example_tuple:
|
||||
latents_list = []
|
||||
for idx in range(len(example_tuple[1])):
|
||||
latent_dict = {}
|
||||
latent_dict["prompt"] = example_tuple[1][idx]
|
||||
latent_dict["height"] = example_tuple[0].shape[2] * 8
|
||||
latent_dict["width"] = example_tuple[0].shape[3] * 8
|
||||
latent_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0)
|
||||
latents_list.append(latent_dict)
|
||||
distributed_state.wait_for_everyone()
|
||||
latents_list = gather_object(latents_list)
|
||||
save_dir = args.output_dir + "/sample"
|
||||
if distributed_state.is_main_process:
|
||||
#Create output folder and preprocess prompts on main process only.
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
idx = 0
|
||||
|
||||
for i in range(len(prompts)):
|
||||
prompt_dict = prompts[i]
|
||||
if isinstance(prompt_dict, str):
|
||||
prompt_dict = line_to_prompt_dict(prompt_dict)
|
||||
prompts[i] = prompt_dict
|
||||
assert isinstance(prompt_dict, dict)
|
||||
selected = ""
|
||||
if '__caption' in prompt_dict.get("prompt"):
|
||||
match_caption = RE_CAPTION_PROMPT.search(prompt_dict.get("prompt"))
|
||||
if match_caption is not None:
|
||||
if not example_tuple:
|
||||
|
||||
if match_caption.group(3) is not None:
|
||||
caption_list = match_caption.group(3).split("|")
|
||||
selected = random.choice(caption_list)
|
||||
prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), selected if selected else f'Astronaut riding a horse on the moon')
|
||||
logger.info(f"Backup prompt: {prompt_dict.get('prompt')}")
|
||||
else:
|
||||
while latents_list[idx]["prompt"] == '':
|
||||
idx = (idx + 1) % len(latents_list)
|
||||
if idx == 0:
|
||||
break
|
||||
prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), f'{latents_list[idx]["prompt"]}')
|
||||
#logger.info(f"Replacement prompt: {prompt_dict.get('prompt')}")
|
||||
prompt_dict["height"] = latents_list[idx]["height"]
|
||||
#logger.info(f"Original Image Height: {prompt_dict['height']}")
|
||||
prompt_dict["width"] = latents_list[idx]["width"]
|
||||
#logger.info(f"Original Image Width: {prompt_dict['width']}")
|
||||
prompt_dict["original_lantent"] = latents_list[idx]["original_lantent"]
|
||||
idx = (idx + 1) % len(latents_list)
|
||||
|
||||
prompt_dict["enum"] = i
|
||||
prompt_dict.pop("subset", None)
|
||||
prompts[i] = prompt_dict
|
||||
assert isinstance(prompt_dict, dict)
|
||||
logger.info(f"Current prompt: {prompts[i].get('prompt')}")
|
||||
|
||||
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
|
||||
|
||||
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
|
||||
prompt_dict["enum"] = i
|
||||
prompt_dict.pop("subset", None)
|
||||
|
||||
# save random state to restore later
|
||||
rng_state = torch.get_rng_state()
|
||||
@@ -5531,26 +5589,25 @@ def sample_images_common(
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if distributed_state.num_processes <= 1:
|
||||
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
||||
with torch.no_grad():
|
||||
for prompt_dict in prompts:
|
||||
if distributed_state.num_processes > 1 and distributed_state.is_main_process:
|
||||
per_process_prompts = [] # list of lists
|
||||
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
||||
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
|
||||
for i in range(distributed_state.num_processes):
|
||||
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
|
||||
prompts = []
|
||||
# Flattening prompts for simplicity
|
||||
for prompt in per_process_prompts:
|
||||
prompts.extend(prompt)
|
||||
distributed_state.wait_for_everyone()
|
||||
prompts = gather_object(prompts)
|
||||
|
||||
with torch.no_grad():
|
||||
with distributed_state.split_between_processes(prompts) as prompt_dict_lists:
|
||||
for prompt_dict in prompt_dict_lists:
|
||||
sample_image_inference(
|
||||
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
|
||||
)
|
||||
else:
|
||||
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
||||
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
|
||||
per_process_prompts = [] # list of lists
|
||||
for i in range(distributed_state.num_processes):
|
||||
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
|
||||
|
||||
with torch.no_grad():
|
||||
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
|
||||
for prompt_dict in prompt_dict_lists[0]:
|
||||
sample_image_inference(
|
||||
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
|
||||
)
|
||||
|
||||
# clear pipeline and cache to reduce vram usage
|
||||
del pipeline
|
||||
@@ -5565,6 +5622,42 @@ def sample_images_common(
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
vae.to(org_vae_device)
|
||||
|
||||
def draw_text_on_image(text, max_width, text_color="black"):
|
||||
from PIL import ImageDraw, ImageFont, Image
|
||||
import textwrap
|
||||
|
||||
font = ImageFont.load_default()
|
||||
space_width = font.getbbox(' ')[2]
|
||||
font_size = 20
|
||||
|
||||
def wrap_text(text, font, max_width):
|
||||
words = text.split(' ')
|
||||
lines = []
|
||||
current_line = ""
|
||||
for word in words:
|
||||
test_line = current_line + word + " "
|
||||
if font.getbbox(test_line)[2] <= max_width:
|
||||
current_line = test_line
|
||||
else:
|
||||
lines.append(current_line)
|
||||
current_line = word + " "
|
||||
lines.append(current_line)
|
||||
return lines
|
||||
|
||||
lines = wrap_text(text, font, max_width - 10)
|
||||
text_height = sum([font.getbbox(line)[3] - font.getbbox(line)[1] for line in lines]) + 20
|
||||
text_image = Image.new('RGB', (max_width, text_height), 'white')
|
||||
text_draw = ImageDraw.Draw(text_image)
|
||||
|
||||
y_text = 10
|
||||
for line in lines:
|
||||
bbox = text_draw.textbbox((0, 0), line, font=font)
|
||||
height = bbox[3] - bbox[1]
|
||||
text_draw.text((10, y_text), line, font=font, fill=text_color)
|
||||
y_text += font_size
|
||||
|
||||
return text_image
|
||||
|
||||
|
||||
def sample_image_inference(
|
||||
accelerator: Accelerator,
|
||||
@@ -5635,13 +5728,23 @@ def sample_image_inference(
|
||||
controlnet=controlnet,
|
||||
controlnet_image=controlnet_image,
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
with torch.cuda.device(torch.cuda.current_device()):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
image = pipeline.latents_to_image(latents)[0]
|
||||
|
||||
if "original_lantent" in prompt_dict:
|
||||
#Prevent out of VRAM error
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
original_latent = prompt_dict.get("original_lantent").to(device=accelerator.device)
|
||||
original_image = pipeline.latents_to_image(original_latent)[0]
|
||||
text_image = draw_text_on_image(f"caption: {prompt}", image.width * 2)
|
||||
new_image = Image.new('RGB', (original_image.width + image.width, original_image.height + text_image.height))
|
||||
new_image.paste(original_image, (0, text_image.height))
|
||||
new_image.paste(image, (original_image.width, text_image.height))
|
||||
new_image.paste(text_image, (0, 0))
|
||||
image = new_image
|
||||
|
||||
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
|
||||
# but adding 'enum' to the filename should be enough
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ def setup_logging(args=None, log_level=None, reset=False):
|
||||
|
||||
if handler is None:
|
||||
handler = logging.StreamHandler(sys.stdout) # same as print
|
||||
handler.propagate = False
|
||||
handler.propagate = True
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(message)s",
|
||||
|
||||
@@ -13,6 +13,7 @@ import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import gc
|
||||
|
||||
import diffusers
|
||||
import numpy as np
|
||||
@@ -1489,12 +1490,16 @@ def main(args):
|
||||
files = glob.glob(args.ckpt)
|
||||
if len(files) == 1:
|
||||
args.ckpt = files[0]
|
||||
|
||||
device = get_preferred_device()
|
||||
logger.info(f"preferred device: {device}")
|
||||
model_dtype = sdxl_train_util.match_mixed_precision(args, dtype)
|
||||
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
|
||||
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
|
||||
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device, model_dtype
|
||||
)
|
||||
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
|
||||
|
||||
text_encoder1.to(dtype).to(device)
|
||||
text_encoder2.to(dtype).to(device)
|
||||
unet.to(dtype).to(device)
|
||||
# xformers、Hypernetwork対応
|
||||
if not args.diffusers_xformers:
|
||||
mem_eff = not (args.xformers or args.sdpa)
|
||||
@@ -1621,7 +1626,7 @@ def main(args):
|
||||
# scheduler.config.clip_sample = True
|
||||
|
||||
# deviceを決定する
|
||||
device = get_preferred_device()
|
||||
|
||||
|
||||
# custom pipelineをコピったやつを生成する
|
||||
if args.vae_slices:
|
||||
@@ -1651,13 +1656,9 @@ def main(args):
|
||||
vae.to(vae_dtype).to(device)
|
||||
vae.eval()
|
||||
|
||||
text_encoder1.to(dtype).to(device)
|
||||
text_encoder2.to(dtype).to(device)
|
||||
unet.to(dtype).to(device)
|
||||
text_encoder1.eval()
|
||||
text_encoder2.eval()
|
||||
unet.eval()
|
||||
|
||||
# networkを組み込む
|
||||
if args.network_module:
|
||||
networks = []
|
||||
@@ -2809,7 +2810,9 @@ def main(args):
|
||||
num_sub_prompts,
|
||||
),
|
||||
)
|
||||
|
||||
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
|
||||
logger.info("Does this run? When number of prompts less than batch?")
|
||||
process_batch(batch_data, highres_fix)
|
||||
batch_data.clear()
|
||||
|
||||
@@ -2817,8 +2820,9 @@ def main(args):
|
||||
if len(batch_data) == args.batch_size:
|
||||
prev_image = process_batch(batch_data, highres_fix)[0]
|
||||
batch_data.clear()
|
||||
|
||||
logger.info(f"Global Step: {global_step}")
|
||||
global_step += 1
|
||||
|
||||
|
||||
prompt_index += 1
|
||||
|
||||
@@ -3194,6 +3198,10 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /"
|
||||
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨",
|
||||
)
|
||||
parser.add_argument("--full_fp16", action="store_true", help="Loading model in fp16")
|
||||
parser.add_argument(
|
||||
"--full_bf16", action="store_true", help="Loading model in bf16"
|
||||
)
|
||||
|
||||
# # parser.add_argument(
|
||||
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||
|
||||
@@ -740,6 +740,7 @@ def train(args):
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
|
||||
if not (args.fused_backward_pass or args.fused_optimizer_groups):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = []
|
||||
@@ -757,6 +758,8 @@ def train(args):
|
||||
for i in range(1, len(optimizers)):
|
||||
lr_schedulers[i].step()
|
||||
|
||||
|
||||
example_tuple = (latents, batch["captions"])
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
@@ -772,6 +775,7 @@ def train(args):
|
||||
[tokenizer1, tokenizer2],
|
||||
[text_encoder1, text_encoder2],
|
||||
unet,
|
||||
example_tuple,
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
@@ -854,6 +858,7 @@ def train(args):
|
||||
[tokenizer1, tokenizer2],
|
||||
[text_encoder1, text_encoder2],
|
||||
unet,
|
||||
example_tuple,
|
||||
)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
@@ -163,8 +163,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||
sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None):
|
||||
sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
@@ -81,9 +81,9 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None, prompt_replacement):
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple, prompt_replacement
|
||||
)
|
||||
|
||||
def save_weights(self, file, updated_embs, save_dtype, metadata):
|
||||
|
||||
@@ -388,14 +388,14 @@ def train(args):
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
example_tuple = (latents, batch["captions"])
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_util.sample_images(
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
@@ -459,7 +459,7 @@ def train(args):
|
||||
vae,
|
||||
)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
|
||||
@@ -39,6 +39,7 @@ from library.custom_train_functions import (
|
||||
apply_masked_loss,
|
||||
)
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
from accelerate.utils import gather_object, gather
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -131,8 +132,8 @@ class NetworkTrainer:
|
||||
if param.grad is not None:
|
||||
param.grad = accelerator.reduce(param.grad, reduction="mean")
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None):
|
||||
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple)
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
@@ -1030,11 +1031,15 @@ class NetworkTrainer:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
example_tuple = (latents.detach().clone(), batch["captions"])
|
||||
if args.sample_every_n_steps is not None and global_step % args.sample_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
@@ -1090,7 +1095,9 @@ class NetworkTrainer:
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
if args.sample_every_n_epochs is not None and (epoch + 1)% args.sample_every_n_epochs == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)
|
||||
|
||||
# end of epoch
|
||||
|
||||
@@ -1233,6 +1240,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
|
||||
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
|
||||
)
|
||||
|
||||
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
|
||||
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
|
||||
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
|
||||
|
||||
@@ -122,9 +122,9 @@ class TextualInversionTrainer:
|
||||
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
|
||||
return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None, prompt_replacement):
|
||||
train_util.sample_images(
|
||||
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple, prompt_replacement
|
||||
)
|
||||
|
||||
def save_weights(self, file, updated_embs, save_dtype, metadata):
|
||||
@@ -627,6 +627,7 @@ class TextualInversionTrainer:
|
||||
index_no_updates
|
||||
]
|
||||
|
||||
example_tuple = (latents, captions)
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
@@ -642,6 +643,7 @@ class TextualInversionTrainer:
|
||||
tokenizer_or_list,
|
||||
text_encoder_or_list,
|
||||
unet,
|
||||
example_tuple,
|
||||
prompt_replacement,
|
||||
)
|
||||
|
||||
@@ -714,7 +716,6 @@ class TextualInversionTrainer:
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
self.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
@@ -725,6 +726,7 @@ class TextualInversionTrainer:
|
||||
tokenizer_or_list,
|
||||
text_encoder_or_list,
|
||||
unet,
|
||||
example_tuple,
|
||||
prompt_replacement,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user