Add example-image generation from training captions/latents

This commit is contained in:
DKnight54
2025-01-29 18:46:52 +08:00
parent b1b1c19be1
commit 39a375139a
8 changed files with 83 additions and 18 deletions

View File

@@ -404,14 +404,14 @@ def train(args):
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
example_tuple = (latents, batch["captions"])
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple
)
# 指定ステップごとにモデルを保存
@@ -474,7 +474,7 @@ def train(args):
vae,
)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)
is_main_process = accelerator.is_main_process
if is_main_process:

View File

@@ -5431,6 +5431,7 @@ def sample_images_common(
tokenizer,
text_encoder,
unet,
example_tuple=None,
prompt_replacement=None,
controlnet=None,
):
@@ -5527,7 +5528,18 @@ def sample_images_common(
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad():
idx = 0
for prompt_dict in prompts:
if '__caption__' in prompt_dict.get("prompt") and example_tuple:
while example_tuple[1][idx] == '':
idx = (idx + 1) % len(example_tuple[1])
if idx == 0:
break
prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', 'example_tuple[1][idx]')
prompt_dict["height"] = example_tuple[0].shape[2] * 8
prompt_dict["width"] = example_tuple[0].shape[3] * 8
prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0)
idx = (idx + 1) % len(example_tuple[1])
sample_image_inference(
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
)
@@ -5558,6 +5570,42 @@ def sample_images_common(
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)
def draw_text_on_image(text, max_width, text_color="black"):
    """Render *text*, word-wrapped, onto a white RGB image of width ``max_width``.

    Used to build a caption banner that is pasted above the sample/original
    image pair in ``sample_image_inference``.

    Args:
        text: The caption string to render.
        max_width: Width in pixels of the returned image.
        text_color: Fill color passed to ``ImageDraw.text`` (default "black").

    Returns:
        A ``PIL.Image.Image`` in RGB mode, ``max_width`` wide and tall enough
        to hold every wrapped line plus a 10 px margin on all sides.
    """
    # Local import keeps this helper import-free for callers that never sample.
    from PIL import ImageDraw, ImageFont, Image

    font = ImageFont.load_default()
    margin = 10  # pixels of padding on each edge

    def _wrap(text, font, usable_width):
        # Greedy word wrap measured in rendered pixels (getbbox()[2] is the
        # right edge, i.e. the rendered width of the string).
        lines = []
        current = ""
        for word in text.split(" "):
            candidate = current + word + " "
            # Accept an over-long single word on its own line instead of
            # emitting a spurious empty line (bug in the original when the
            # first word alone exceeded usable_width).
            if font.getbbox(candidate)[2] <= usable_width or not current:
                current = candidate
            else:
                lines.append(current)
                current = word + " "
        lines.append(current)
        return lines

    lines = _wrap(text, font, max_width - margin)

    # Use ONE line-height for both the allocated image height and the y
    # advance while drawing. The original summed measured bbox heights for
    # the image but advanced y by a fixed 20 px per line, so the last lines
    # could be clipped. getbbox("Ag")[3] spans ascender to descender for the
    # loaded font; +2 adds a little leading.
    line_height = font.getbbox("Ag")[3] + 2
    text_height = line_height * len(lines) + 2 * margin

    text_image = Image.new("RGB", (max_width, text_height), "white")
    draw = ImageDraw.Draw(text_image)
    y = margin
    for line in lines:
        draw.text((margin, y), line, font=font, fill=text_color)
        y += line_height
    return text_image
def sample_image_inference(
accelerator: Accelerator,
@@ -5634,7 +5682,16 @@ def sample_image_inference(
torch.cuda.empty_cache()
image = pipeline.latents_to_image(latents)[0]
if "original_lantent" in prompt_dict:
original_latent = prompt_dict.get("original_lantent")
original_image = pipeline.latents_to_image(original_latent)[0]
text_image = draw_text_on_image(f"caption: {prompt}", image.width * 2)
new_image = Image.new('RGB', (original_image.width + image.width, original_image.height + text_image.height))
new_image.paste(original_image, (0, text_image.height))
new_image.paste(image, (original_image.width, text_image.height))
new_image.paste(text_image, (0, 0))
image = new_image
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
# but adding 'enum' to the filename should be enough

View File

@@ -740,6 +740,7 @@ def train(args):
accelerator.backward(loss)
if not (args.fused_backward_pass or args.fused_optimizer_groups):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
@@ -757,6 +758,8 @@ def train(args):
for i in range(1, len(optimizers)):
lr_schedulers[i].step()
example_tuple = (latents, batch["captions"])
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
@@ -772,6 +775,7 @@ def train(args):
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
example_tuple,
)
# 指定ステップごとにモデルを保存
@@ -854,6 +858,7 @@ def train(args):
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
example_tuple,
)
is_main_process = accelerator.is_main_process

View File

@@ -164,8 +164,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None):
sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple)
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -81,9 +81,9 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None, prompt_replacement):
sdxl_train_util.sample_images(
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple, prompt_replacement
)
def save_weights(self, file, updated_embs, save_dtype, metadata):

View File

@@ -388,14 +388,14 @@ def train(args):
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
example_tuple = (latents, batch["captions"])
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple
)
# 指定ステップごとにモデルを保存
@@ -459,7 +459,7 @@ def train(args):
vae,
)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)
is_main_process = accelerator.is_main_process
if is_main_process:

View File

@@ -131,8 +131,8 @@ class NetworkTrainer:
if param.grad is not None:
param.grad = accelerator.reduce(param.grad, reduction="mean")
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None):
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple)
def train(self, args):
session_id = random.randint(0, 2**32)
@@ -1022,11 +1022,12 @@ class NetworkTrainer:
keys_scaled, mean_norm, maximum_norm = None, None, None
# Checks if the accelerator has performed an optimization step behind the scenes
example_tuple = (latents, batch["captions"])
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -1082,7 +1083,7 @@ class NetworkTrainer:
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)
# end of epoch

View File

@@ -122,9 +122,9 @@ class TextualInversionTrainer:
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None, prompt_replacement):
train_util.sample_images(
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple, prompt_replacement
)
def save_weights(self, file, updated_embs, save_dtype, metadata):
@@ -627,6 +627,7 @@ class TextualInversionTrainer:
index_no_updates
]
example_tuple = (latents, captions)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
@@ -642,6 +643,7 @@ class TextualInversionTrainer:
tokenizer_or_list,
text_encoder_or_list,
unet,
example_tuple,
prompt_replacement,
)
@@ -714,7 +716,6 @@ class TextualInversionTrainer:
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
self.sample_images(
accelerator,
args,
@@ -725,6 +726,7 @@ class TextualInversionTrainer:
tokenizer_or_list,
text_encoder_or_list,
unet,
example_tuple,
prompt_replacement,
)