This commit is contained in:
DKnight54
2025-09-30 09:54:28 +05:30
committed by GitHub
13 changed files with 3636 additions and 90 deletions

3365
accel_sdxl_gen_img.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -404,14 +404,14 @@ def train(args):
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
example_tuple = (latents, batch["captions"])
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple
)
# 指定ステップごとにモデルを保存
@@ -474,7 +474,7 @@ def train(args):
vae,
)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)
is_main_process = accelerator.is_main_process
if is_main_process:

View File

@@ -1485,6 +1485,7 @@ class ListPrompter:
def main(args):
if args.fp16:
dtype = torch.float16
elif args.bf16:
@@ -1492,6 +1493,8 @@ def main(args):
else:
dtype = torch.float32
device = get_preferred_device()
highres_fix = args.highres_fix_scale is not None
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
@@ -1521,9 +1524,10 @@ def main(args):
if is_sdxl:
if args.clip_skip is None:
args.clip_skip = 2
model_dtype = sdxl_train_util.match_mixed_precision(args, dtype)
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device, model_dtype
)
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
text_encoders = [text_encoder1, text_encoder2]
@@ -1748,7 +1752,7 @@ def main(args):
logger.info(f"network_merge: {network_merge}")
for i, network_module in enumerate(args.network_module):
logger.info("import network module: {network_module}")
logger.info(f"import network module: {network_module}")
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
@@ -2508,7 +2512,7 @@ def main(args):
metadata.add_text("crop-left", str(crop_left))
if filename is not None:
fln = filename
fln = first_available_filename(args.outdir, filename) # Check that the file does not already exist; otherwise use the first available sequential filename
else:
if args.use_original_file_name and init_images is not None:
if type(init_images) is list:
@@ -2586,7 +2590,8 @@ def main(args):
negative_scale = args.negative_scale
steps = args.steps
seed = None
seeds = None
if pi == 0:
seeds = None
strength = 0.8 if args.strength is None else args.strength
negative_prompt = ""
clip_prompt = None
@@ -2670,7 +2675,11 @@ def main(args):
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
if m: # seed
seeds = [int(d) for d in m.group(1).split(",")]
if pi > 0 and len(raw_prompts) > 1: # Bypass on 2nd and later loops for dynamic prompts
continue
logger.info(f"{m}")
seeds = m.group(1).split(",")
seeds = [int(d.strip()) for d in seeds]
logger.info(f"seeds: {seeds}")
continue
@@ -2795,14 +2804,19 @@ def main(args):
m = re.match(r"f (.+)", parg, re.IGNORECASE)
if m: # filename
filename = m.group(1)
logger.info(f"filename: {filename}")
continue
except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(f"{ex}")
# override Deep Shrink
# override filename to add index number if more than one image per prompt
if filename is not None and (args.images_per_prompt > 1 or len(raw_prompts) > 1):
fileext = os.path.splitext(filename)
filename = fileext[0] + "_%s" % pi + fileext[1]
logger.info(f"filename: {filename}")
# override Deep Shrink
if ds_depth_1 is not None:
if ds_depth_1 < 0:
ds_depth_1 = args.ds_depth_1 or 3
@@ -2835,8 +2849,16 @@ def main(args):
# prepare seed
if seeds is not None: # given in prompt
# num_images_per_promptが多い場合は足りなくなるので、足りない分は前のを使う
if len(seeds) > 0:
# Previous implementation may result in unexpected behaviour when the number of seeds is less than the number of repeats: the last seed is reused for the rest of the repeated prompts. Add a condition: if the last element is -1, start randomizing the seed.
if len(seeds) > 1:
seed = seeds.pop(0)
elif len(seeds) == 1:
if seeds[0] == -1:
seeds = None
else:
seed = seeds.pop(0)
else:
if args.iter_same_seed:
seed = iter_seed
@@ -2847,6 +2869,7 @@ def main(args):
seed = seed_random.randint(0, 2**32 - 1)
if args.interactive:
logger.info(f"seed: {seed}")
# logger.info(f"seed: {seed}") #debugging logger. Uncomment to verify if expected seed is added correctly.
# prepare init image, guide image and mask
init_image = mask_image = guide_image = None
@@ -2935,7 +2958,35 @@ def main(args):
logger.info("done!")
def first_available_filename(path, filename):
    """Return *filename* if unused in *path*, else the first free sequential name.

    If ``file.png`` already exists, tries ``file_1.png``, ``file_2.png``, ...
    and returns the first name that does not exist yet. Uses an exponential
    search followed by a binary search, so only O(log n) existence checks are
    needed, where n is the number of existing files in the sequence.

    Args:
        path: directory to check for existing files.
        filename: desired file name (with extension).

    Returns:
        A file name (not a full path) that does not yet exist in *path*.
    """
    if not os.path.exists(os.path.join(path, filename)):
        return filename
    stem, ext = os.path.splitext(filename)

    def candidate(i):
        # Build "stem_i.ext" with an f-string; the previous "%s" template
        # raised ValueError when the filename itself contained a '%'.
        return f"{stem}_{i}{ext}"

    # Exponential search: double i until a non-existing candidate is found.
    i = 1
    while os.path.exists(os.path.join(path, candidate(i))):
        i *= 2
    # The first free index lies in the interval (i//2 .. i]; binary-search it
    # down until the interval is a single index.
    a, b = i // 2, i
    while a + 1 < b:
        c = (a + b) // 2  # interval midpoint
        if os.path.exists(os.path.join(path, candidate(c))):
            a = c
        else:
            b = c
    return candidate(b)
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
@@ -3343,6 +3394,10 @@ def setup_parser() -> argparse.ArgumentParser:
help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /"
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨",
)
parser.add_argument("--full_fp16", action="store_true", help="Loading model in fp16")
parser.add_argument(
"--full_bf16", action="store_true", help="Loading model in bf16"
)
# # parser.add_argument(
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"

View File

@@ -91,10 +91,10 @@ def _load_target_model(
pipe = StableDiffusionXLPipeline.from_pretrained(
name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None
)
except EnvironmentError as ex:
except ValueError as ex:
if variant is not None:
logger.info("try to load fp32 model")
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
logger.info("try to load default model")
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, torch_dtype=model_dtype, variant=None, tokenizer=None)
else:
raise ex
except EnvironmentError as ex:

View File

@@ -28,6 +28,7 @@ import random
import hashlib
import subprocess
from io import BytesIO
from accelerate.utils import gather_object
import toml
from tqdm import tqdm
@@ -1724,7 +1725,10 @@ class DreamBoothDataset(BaseDataset):
if size is not None:
info.image_size = size
if subset.is_reg:
reg_infos.append((info, subset))
if subset.num_repeats > 1:
info.num_repeats = 1
for i in range(subset.num_repeats):
reg_infos.append((info, subset))
else:
self.register_image(info, subset)
@@ -1735,6 +1739,7 @@ class DreamBoothDataset(BaseDataset):
self.num_train_images = num_train_images
logger.info(f"{num_reg_images} reg images.")
random.shuffle(reg_infos)
if num_train_images < num_reg_images:
logger.warning("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
@@ -1748,13 +1753,16 @@ class DreamBoothDataset(BaseDataset):
for info, subset in reg_infos:
if first_loop:
self.register_image(info, subset)
logger.info(f"Registering image: {info.absolute_path}")
n += info.num_repeats
else:
info.num_repeats += 1 # rewrite registered info
logger.info(f"Registering image: {info.absolute_path}")
n += 1
if n >= num_train_images:
break
first_loop = False
random.shuffle(reg_infos)
self.num_reg_images = num_reg_images
@@ -4720,7 +4728,7 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
logger.info(f"loading model for process {accelerator.state.local_process_index+1}/{accelerator.state.num_processes}")
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
args,
@@ -5426,7 +5434,7 @@ def line_to_prompt_dict(line: str) -> dict:
return prompt_dict
RE_CAPTION_PROMPT = re.compile(r"(?i)__caption((\|)(.+?)?)?__")
def sample_images_common(
pipe_class,
accelerator: Accelerator,
@@ -5438,6 +5446,7 @@ def sample_images_common(
tokenizer,
text_encoder,
unet,
example_tuple=None,
prompt_replacement=None,
controlnet=None,
):
@@ -5468,7 +5477,7 @@ def sample_images_common(
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
org_vae_device = vae.device # CPUにいるはず
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
vae.to(device) # distributed_state.device is same as accelerator.device
# unwrap unet and text_encoder(s)
unet = accelerator.unwrap_model(unet)
@@ -5478,18 +5487,23 @@ def sample_images_common(
text_encoder = accelerator.unwrap_model(text_encoder)
# read prompts
if args.sample_prompts.endswith(".txt"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif args.sample_prompts.endswith(".toml"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
data = toml.load(f)
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
elif args.sample_prompts.endswith(".json"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)
if distributed_state.is_main_process:
# Load prompts into prompts list on main process only
if args.sample_prompts.endswith(".txt"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif args.sample_prompts.endswith(".toml"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
data = toml.load(f)
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
elif args.sample_prompts.endswith(".json"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)
else:
prompts = [] # Init empty prompts list for sub processes.
# schedulers: dict = {} cannot find where this is used
default_scheduler = get_my_scheduler(
sample_sampler=args.sample_sampler,
@@ -5507,21 +5521,65 @@ def sample_images_common(
requires_safety_checker=False,
clip_skip=args.clip_skip,
)
pipeline.to(distributed_state.device)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
pipeline.to(device)
# preprocess prompts
for i in range(len(prompts)):
prompt_dict = prompts[i]
if isinstance(prompt_dict, str):
prompt_dict = line_to_prompt_dict(prompt_dict)
if example_tuple:
latents_list = []
for idx in range(len(example_tuple[1])):
latent_dict = {}
latent_dict["prompt"] = example_tuple[1][idx]
latent_dict["height"] = example_tuple[0].shape[2] * 8
latent_dict["width"] = example_tuple[0].shape[3] * 8
latent_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0)
latents_list.append(latent_dict)
distributed_state.wait_for_everyone()
latents_list = gather_object(latents_list)
save_dir = args.output_dir + "/sample"
if distributed_state.is_main_process:
#Create output folder and preprocess prompts on main process only.
os.makedirs(save_dir, exist_ok=True)
idx = 0
for i in range(len(prompts)):
prompt_dict = prompts[i]
if isinstance(prompt_dict, str):
prompt_dict = line_to_prompt_dict(prompt_dict)
prompts[i] = prompt_dict
assert isinstance(prompt_dict, dict)
selected = ""
if '__caption' in prompt_dict.get("prompt"):
match_caption = RE_CAPTION_PROMPT.search(prompt_dict.get("prompt"))
if match_caption is not None:
if not example_tuple:
if match_caption.group(3) is not None:
caption_list = match_caption.group(3).split("|")
selected = random.choice(caption_list)
prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), selected if selected else f'Astronaut riding a horse on the moon')
logger.info(f"Backup prompt: {prompt_dict.get('prompt')}")
else:
while latents_list[idx]["prompt"] == '':
idx = (idx + 1) % len(latents_list)
if idx == 0:
break
prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), f'{latents_list[idx]["prompt"]}')
#logger.info(f"Replacement prompt: {prompt_dict.get('prompt')}")
prompt_dict["height"] = latents_list[idx]["height"]
#logger.info(f"Original Image Height: {prompt_dict['height']}")
prompt_dict["width"] = latents_list[idx]["width"]
#logger.info(f"Original Image Width: {prompt_dict['width']}")
prompt_dict["original_lantent"] = latents_list[idx]["original_lantent"]
idx = (idx + 1) % len(latents_list)
prompt_dict["enum"] = i
prompt_dict.pop("subset", None)
prompts[i] = prompt_dict
assert isinstance(prompt_dict, dict)
logger.info(f"Current prompt: {prompts[i].get('prompt')}")
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
prompt_dict["enum"] = i
prompt_dict.pop("subset", None)
# save random state to restore later
rng_state = torch.get_rng_state()
@@ -5531,26 +5589,25 @@ def sample_images_common(
except Exception:
pass
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad():
for prompt_dict in prompts:
if distributed_state.num_processes > 1 and distributed_state.is_main_process:
per_process_prompts = [] # list of lists
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
prompts = []
# Flattening prompts for simplicity
for prompt in per_process_prompts:
prompts.extend(prompt)
distributed_state.wait_for_everyone()
prompts = gather_object(prompts)
with torch.no_grad():
with distributed_state.split_between_processes(prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists:
sample_image_inference(
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [] # list of lists
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
)
# clear pipeline and cache to reduce vram usage
del pipeline
@@ -5565,6 +5622,42 @@ def sample_images_common(
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)
def draw_text_on_image(text, max_width, text_color="black"):
    """Render *text* onto a new white image of width *max_width*.

    The text is greedily word-wrapped to fit the width and drawn with PIL's
    default bitmap font, with a 10px margin on the left, top and bottom.

    Args:
        text: the string to render.
        max_width: width in pixels of the produced image.
        text_color: fill color passed to ``ImageDraw.text``.

    Returns:
        A new PIL ``Image`` tall enough to hold every wrapped line.
    """
    from PIL import ImageDraw, ImageFont, Image

    font = ImageFont.load_default()
    line_height = 20  # vertical advance per drawn line, in pixels

    def wrap_text(text, font, max_width):
        # Greedy word wrap: accumulate words until the next one would
        # overflow max_width, then start a new line.
        # NOTE(review): a single word wider than max_width is still emitted
        # on one overlong line — confirm this is acceptable for captions.
        words = text.split(' ')
        lines = []
        current_line = ""
        for word in words:
            test_line = current_line + word + " "
            if font.getbbox(test_line)[2] <= max_width:
                current_line = test_line
            else:
                lines.append(current_line)
                current_line = word + " "
        lines.append(current_line)
        return lines

    lines = wrap_text(text, font, max_width - 10)
    # Size the canvas from the same per-line advance used when drawing.
    # The previous version summed font bbox heights (smaller than the 20px
    # advance), so the last lines were drawn past the bottom and clipped.
    text_height = line_height * len(lines) + 20
    text_image = Image.new('RGB', (max_width, text_height), 'white')
    text_draw = ImageDraw.Draw(text_image)
    y_text = 10
    for line in lines:
        text_draw.text((10, y_text), line, font=font, fill=text_color)
        y_text += line_height
    return text_image
def sample_image_inference(
accelerator: Accelerator,
@@ -5635,13 +5728,23 @@ def sample_image_inference(
controlnet=controlnet,
controlnet_image=controlnet_image,
)
if torch.cuda.is_available():
with torch.cuda.device(torch.cuda.current_device()):
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
image = pipeline.latents_to_image(latents)[0]
if "original_lantent" in prompt_dict:
#Prevent out of VRAM error
clean_memory_on_device(accelerator.device)
original_latent = prompt_dict.get("original_lantent").to(device=accelerator.device)
original_image = pipeline.latents_to_image(original_latent)[0]
text_image = draw_text_on_image(f"caption: {prompt}", image.width * 2)
new_image = Image.new('RGB', (original_image.width + image.width, original_image.height + text_image.height))
new_image.paste(original_image, (0, text_image.height))
new_image.paste(image, (original_image.width, text_image.height))
new_image.paste(text_image, (0, 0))
image = new_image
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
# but adding 'enum' to the filename should be enough

View File

@@ -67,7 +67,7 @@ def setup_logging(args=None, log_level=None, reset=False):
if handler is None:
handler = logging.StreamHandler(sys.stdout) # same as print
handler.propagate = False
handler.propagate = True
formatter = logging.Formatter(
fmt="%(message)s",

View File

@@ -13,6 +13,7 @@ import math
import os
import random
import re
import gc
import diffusers
import numpy as np
@@ -1489,12 +1490,16 @@ def main(args):
files = glob.glob(args.ckpt)
if len(files) == 1:
args.ckpt = files[0]
device = get_preferred_device()
logger.info(f"preferred device: {device}")
model_dtype = sdxl_train_util.match_mixed_precision(args, dtype)
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device, model_dtype
)
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
text_encoder1.to(dtype).to(device)
text_encoder2.to(dtype).to(device)
unet.to(dtype).to(device)
# xformers、Hypernetwork対応
if not args.diffusers_xformers:
mem_eff = not (args.xformers or args.sdpa)
@@ -1621,7 +1626,7 @@ def main(args):
# scheduler.config.clip_sample = True
# deviceを決定する
device = get_preferred_device()
# custom pipelineをコピったやつを生成する
if args.vae_slices:
@@ -1651,13 +1656,9 @@ def main(args):
vae.to(vae_dtype).to(device)
vae.eval()
text_encoder1.to(dtype).to(device)
text_encoder2.to(dtype).to(device)
unet.to(dtype).to(device)
text_encoder1.eval()
text_encoder2.eval()
unet.eval()
# networkを組み込む
if args.network_module:
networks = []
@@ -2809,7 +2810,9 @@ def main(args):
num_sub_prompts,
),
)
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
logger.info("Does this run? When number of prompts less than batch?")
process_batch(batch_data, highres_fix)
batch_data.clear()
@@ -2817,8 +2820,9 @@ def main(args):
if len(batch_data) == args.batch_size:
prev_image = process_batch(batch_data, highres_fix)[0]
batch_data.clear()
logger.info(f"Global Step: {global_step}")
global_step += 1
prompt_index += 1
@@ -3194,6 +3198,10 @@ def setup_parser() -> argparse.ArgumentParser:
help="unsharp mask parameters for Gradual Latent: ksize, sigma, strength, target-x (1 means True). `3,0.5,0.5,1` or `3,1.0,1.0,0` is recommended /"
+ " Gradual Latentのunsharp maskのパラメータ: ksize, sigma, strength, target-x. `3,0.5,0.5,1` または `3,1.0,1.0,0` が推奨",
)
parser.add_argument("--full_fp16", action="store_true", help="Loading model in fp16")
parser.add_argument(
"--full_bf16", action="store_true", help="Loading model in bf16"
)
# # parser.add_argument(
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"

View File

@@ -740,6 +740,7 @@ def train(args):
accelerator.backward(loss)
if not (args.fused_backward_pass or args.fused_optimizer_groups):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
@@ -757,6 +758,8 @@ def train(args):
for i in range(1, len(optimizers)):
lr_schedulers[i].step()
example_tuple = (latents, batch["captions"])
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
@@ -772,6 +775,7 @@ def train(args):
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
example_tuple,
)
# 指定ステップごとにモデルを保存
@@ -854,6 +858,7 @@ def train(args):
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
example_tuple,
)
is_main_process = accelerator.is_main_process

View File

@@ -163,8 +163,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None):
    # Thin delegation to the SDXL sampling helper, forwarding the optional
    # (latents, captions) example tuple used for caption-based sample prompts.
    sdxl_train_util.sample_images(
        accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple
    )
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -82,9 +82,9 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None, prompt_replacement=None):
    # Delegate to the SDXL sampling helper.
    # Fix: `example_tuple=None, prompt_replacement` was a SyntaxError
    # (non-default parameter after a default one); prompt_replacement now
    # defaults to None, which existing positional callers are unaffected by.
    sdxl_train_util.sample_images(
        accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple, prompt_replacement
    )
def save_weights(self, file, updated_embs, save_dtype, metadata):

View File

@@ -388,14 +388,14 @@ def train(args):
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
example_tuple = (latents, batch["captions"])
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple
)
# 指定ステップごとにモデルを保存
@@ -459,7 +459,7 @@ def train(args):
vae,
)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)
is_main_process = accelerator.is_main_process
if is_main_process:

View File

@@ -39,6 +39,7 @@ from library.custom_train_functions import (
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments
from accelerate.utils import gather_object, gather
setup_logging()
import logging
@@ -131,8 +132,8 @@ class NetworkTrainer:
if param.grad is not None:
param.grad = accelerator.reduce(param.grad, reduction="mean")
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None):
    # Thin delegation to the shared sampling helper, forwarding the optional
    # (latents, captions) example tuple used for caption-based sample prompts.
    train_util.sample_images(
        accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple
    )
def train(self, args):
session_id = random.randint(0, 2**32)
@@ -1030,11 +1031,15 @@ class NetworkTrainer:
keys_scaled, mean_norm, maximum_norm = None, None, None
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
example_tuple = (latents.detach().clone(), batch["captions"])
if args.sample_every_n_steps is not None and global_step % args.sample_every_n_steps == 0:
accelerator.wait_for_everyone()
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -1090,7 +1095,9 @@ class NetworkTrainer:
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
if args.sample_every_n_epochs is not None and (epoch + 1)% args.sample_every_n_epochs == 0:
accelerator.wait_for_everyone()
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)
# end of epoch
@@ -1233,6 +1240,7 @@ def setup_parser() -> argparse.ArgumentParser:
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ未指定時と同じ。initial_epochを上書きする",
)
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")

View File

@@ -122,9 +122,9 @@ class TextualInversionTrainer:
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None, prompt_replacement=None):
    # Delegate to the shared sampling helper.
    # Fix: `example_tuple=None, prompt_replacement` was a SyntaxError
    # (non-default parameter after a default one); prompt_replacement now
    # defaults to None, which existing positional callers are unaffected by.
    train_util.sample_images(
        accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple, prompt_replacement
    )
def save_weights(self, file, updated_embs, save_dtype, metadata):
@@ -627,6 +627,7 @@ class TextualInversionTrainer:
index_no_updates
]
example_tuple = (latents, captions)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
@@ -642,6 +643,7 @@ class TextualInversionTrainer:
tokenizer_or_list,
text_encoder_or_list,
unet,
example_tuple,
prompt_replacement,
)
@@ -714,7 +716,6 @@ class TextualInversionTrainer:
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
self.sample_images(
accelerator,
args,
@@ -725,6 +726,7 @@ class TextualInversionTrainer:
tokenizer_or_list,
text_encoder_or_list,
unet,
example_tuple,
prompt_replacement,
)