mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 08:21:46 +00:00
adding example generation
This commit is contained in:
@@ -404,14 +404,14 @@ def train(args):
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
example_tuple = (latents, batch["captions"])
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_util.sample_images(
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
@@ -474,7 +474,7 @@ def train(args):
|
||||
vae,
|
||||
)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
|
||||
@@ -5431,6 +5431,7 @@ def sample_images_common(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
unet,
|
||||
example_tuple=None,
|
||||
prompt_replacement=None,
|
||||
controlnet=None,
|
||||
):
|
||||
@@ -5527,7 +5528,18 @@ def sample_images_common(
|
||||
if distributed_state.num_processes <= 1:
|
||||
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
||||
with torch.no_grad():
|
||||
idx = 0
|
||||
for prompt_dict in prompts:
|
||||
if '__caption__' in prompt_dict.get("prompt") and example_tuple:
|
||||
while example_tuple[1][idx] == '':
|
||||
idx = (idx + 1) % len(example_tuple[1])
|
||||
if idx == 0:
|
||||
break
|
||||
prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', 'example_tuple[1][idx]')
|
||||
prompt_dict["height"] = example_tuple[0].shape[2] * 8
|
||||
prompt_dict["width"] = example_tuple[0].shape[3] * 8
|
||||
prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0)
|
||||
idx = (idx + 1) % len(example_tuple[1])
|
||||
sample_image_inference(
|
||||
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
|
||||
)
|
||||
@@ -5558,6 +5570,42 @@ def sample_images_common(
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
vae.to(org_vae_device)
|
||||
|
||||
def draw_text_on_image(text, max_width, text_color="black"):
|
||||
from PIL import ImageDraw, ImageFont, Image
|
||||
import textwrap
|
||||
|
||||
font = ImageFont.load_default()
|
||||
space_width = font.getbbox(' ')[2]
|
||||
font_size = 20
|
||||
|
||||
def wrap_text(text, font, max_width):
|
||||
words = text.split(' ')
|
||||
lines = []
|
||||
current_line = ""
|
||||
for word in words:
|
||||
test_line = current_line + word + " "
|
||||
if font.getbbox(test_line)[2] <= max_width:
|
||||
current_line = test_line
|
||||
else:
|
||||
lines.append(current_line)
|
||||
current_line = word + " "
|
||||
lines.append(current_line)
|
||||
return lines
|
||||
|
||||
lines = wrap_text(text, font, max_width - 10)
|
||||
text_height = sum([font.getbbox(line)[3] - font.getbbox(line)[1] for line in lines]) + 20
|
||||
text_image = Image.new('RGB', (max_width, text_height), 'white')
|
||||
text_draw = ImageDraw.Draw(text_image)
|
||||
|
||||
y_text = 10
|
||||
for line in lines:
|
||||
bbox = text_draw.textbbox((0, 0), line, font=font)
|
||||
height = bbox[3] - bbox[1]
|
||||
text_draw.text((10, y_text), line, font=font, fill=text_color)
|
||||
y_text += font_size
|
||||
|
||||
return text_image
|
||||
|
||||
|
||||
def sample_image_inference(
|
||||
accelerator: Accelerator,
|
||||
@@ -5634,7 +5682,16 @@ def sample_image_inference(
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
image = pipeline.latents_to_image(latents)[0]
|
||||
|
||||
if "original_lantent" in prompt_dict:
|
||||
original_latent = prompt_dict.get("original_lantent")
|
||||
original_image = pipeline.latents_to_image(original_latent)[0]
|
||||
text_image = draw_text_on_image(f"caption: {prompt}", image.width * 2)
|
||||
new_image = Image.new('RGB', (original_image.width + image.width, original_image.height + text_image.height))
|
||||
new_image.paste(original_image, (0, text_image.height))
|
||||
new_image.paste(image, (original_image.width, text_image.height))
|
||||
new_image.paste(text_image, (0, 0))
|
||||
image = new_image
|
||||
|
||||
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
|
||||
# but adding 'enum' to the filename should be enough
|
||||
|
||||
|
||||
@@ -740,6 +740,7 @@ def train(args):
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
|
||||
if not (args.fused_backward_pass or args.fused_optimizer_groups):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = []
|
||||
@@ -757,6 +758,8 @@ def train(args):
|
||||
for i in range(1, len(optimizers)):
|
||||
lr_schedulers[i].step()
|
||||
|
||||
|
||||
example_tuple = (latents, batch["captions"])
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
@@ -772,6 +775,7 @@ def train(args):
|
||||
[tokenizer1, tokenizer2],
|
||||
[text_encoder1, text_encoder2],
|
||||
unet,
|
||||
example_tuple,
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
@@ -854,6 +858,7 @@ def train(args):
|
||||
[tokenizer1, tokenizer2],
|
||||
[text_encoder1, text_encoder2],
|
||||
unet,
|
||||
example_tuple,
|
||||
)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
@@ -164,8 +164,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||
sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None):
|
||||
sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
@@ -81,9 +81,9 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
||||
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
||||
return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None, prompt_replacement):
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple, prompt_replacement
|
||||
)
|
||||
|
||||
def save_weights(self, file, updated_embs, save_dtype, metadata):
|
||||
|
||||
@@ -388,14 +388,14 @@ def train(args):
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
example_tuple = (latents, batch["captions"])
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_util.sample_images(
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
@@ -459,7 +459,7 @@ def train(args):
|
||||
vae,
|
||||
)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
|
||||
@@ -131,8 +131,8 @@ class NetworkTrainer:
|
||||
if param.grad is not None:
|
||||
param.grad = accelerator.reduce(param.grad, reduction="mean")
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None):
|
||||
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple)
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
@@ -1022,11 +1022,12 @@ class NetworkTrainer:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
example_tuple = (latents, batch["captions"])
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
@@ -1082,7 +1083,7 @@ class NetworkTrainer:
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)
|
||||
|
||||
# end of epoch
|
||||
|
||||
|
||||
@@ -122,9 +122,9 @@ class TextualInversionTrainer:
|
||||
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
|
||||
return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None, prompt_replacement):
|
||||
train_util.sample_images(
|
||||
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple, prompt_replacement
|
||||
)
|
||||
|
||||
def save_weights(self, file, updated_embs, save_dtype, metadata):
|
||||
@@ -627,6 +627,7 @@ class TextualInversionTrainer:
|
||||
index_no_updates
|
||||
]
|
||||
|
||||
example_tuple = (latents, captions)
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
@@ -642,6 +643,7 @@ class TextualInversionTrainer:
|
||||
tokenizer_or_list,
|
||||
text_encoder_or_list,
|
||||
unet,
|
||||
example_tuple,
|
||||
prompt_replacement,
|
||||
)
|
||||
|
||||
@@ -714,7 +716,6 @@ class TextualInversionTrainer:
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
||||
|
||||
self.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
@@ -725,6 +726,7 @@ class TextualInversionTrainer:
|
||||
tokenizer_or_list,
|
||||
text_encoder_or_list,
|
||||
unet,
|
||||
example_tuple,
|
||||
prompt_replacement,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user