diff --git a/fine_tune.py b/fine_tune.py
index c79f97d2..c59ffa14 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -404,14 +404,14 @@ def train(args):
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
-
+            example_tuple = (latents, batch["captions"])
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 global_step += 1

                 train_util.sample_images(
-                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
+                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple
                 )

                 # save the model every specified number of steps
@@ -474,7 +474,7 @@ def train(args):
                     vae,
                 )

-        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)

    is_main_process = accelerator.is_main_process
    if is_main_process:
diff --git a/library/train_util.py b/library/train_util.py
index 100ef475..0df9c1fc 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -5431,6 +5431,7 @@ def sample_images_common(
     tokenizer,
     text_encoder,
     unet,
+    example_tuple=None,
     prompt_replacement=None,
     controlnet=None,
 ):
@@ -5527,7 +5528,18 @@ def sample_images_common(
     if distributed_state.num_processes <= 1:
         # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
         with torch.no_grad():
+            idx = 0
             for prompt_dict in prompts:
+                if "__caption__" in prompt_dict.get("prompt", "") and example_tuple is not None:
+                    while example_tuple[1][idx] == "":  # skip empty captions
+                        idx = (idx + 1) % len(example_tuple[1])
+                        if idx == 0:  # wrapped around: every caption in the batch is empty, give up
+                            break
+                    prompt_dict["prompt"] = prompt_dict.get("prompt").replace("__caption__", example_tuple[1][idx])
+                    prompt_dict["height"] = example_tuple[0].shape[2] * 8  # latent -> pixel scale
+                    prompt_dict["width"] = example_tuple[0].shape[3] * 8
+                    prompt_dict["original_latent"] = example_tuple[0][idx].unsqueeze(0)
+                    idx = (idx + 1) % len(example_tuple[1])
                 sample_image_inference(
                     accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
                 )
@@ -5558,6 +5570,40 @@ def sample_images_common(
         torch.cuda.set_rng_state(cuda_rng_state)
     vae.to(org_vae_device)

+def draw_text_on_image(text, max_width, text_color="black"):
+    # renders `text` as a word-wrapped banner image of width `max_width`
+    from PIL import Image, ImageDraw, ImageFont
+
+    font = ImageFont.load_default()
+    line_height = 20  # vertical advance per rendered line
+
+    def wrap_text(text, font, max_width):
+        # greedy word wrap: extend the current line while it still fits within max_width
+        words = text.split(" ")
+        lines = []
+        current_line = ""
+        for word in words:
+            test_line = current_line + word + " "
+            if font.getbbox(test_line)[2] <= max_width:
+                current_line = test_line
+            else:
+                lines.append(current_line)
+                current_line = word + " "
+        lines.append(current_line)
+        return lines
+
+    lines = wrap_text(text, font, max_width - 10)
+    text_height = line_height * len(lines) + 20  # 10 px margin top and bottom
+    text_image = Image.new("RGB", (max_width, text_height), "white")
+    text_draw = ImageDraw.Draw(text_image)
+
+    y_text = 10
+    for line in lines:
+        text_draw.text((10, y_text), line, font=font, fill=text_color)
+        y_text += line_height
+
+    return text_image
+

 def sample_image_inference(
     accelerator: Accelerator,
@@ -5634,7 +5682,16 @@ def sample_image_inference(
         torch.cuda.empty_cache()

     image = pipeline.latents_to_image(latents)[0]
-
+    if "original_latent" in prompt_dict:
+        original_latent = prompt_dict.get("original_latent")
+        original_image = pipeline.latents_to_image(original_latent)[0]
+        text_image = draw_text_on_image(f"caption: {prompt}", image.width * 2)
+        new_image = Image.new("RGB", (original_image.width + image.width, original_image.height + text_image.height))
+        new_image.paste(original_image, (0, text_image.height))  # left: decoded training latent
+        new_image.paste(image, (original_image.width, text_image.height))  # right: newly sampled image
+        new_image.paste(text_image, (0, 0))  # caption banner across the top
+        image = new_image
+
     # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
     # but adding 'enum' to the filename should be enough

diff --git a/sdxl_train.py b/sdxl_train.py
index b533b274..7779d226 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -740,6 +740,7 @@ def train(args):
                 accelerator.backward(loss)
+
                 if not (args.fused_backward_pass or args.fused_optimizer_groups):
                     if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                         params_to_clip = []
@@ -757,6 +758,8 @@ def train(args):
                         for i in range(1, len(optimizers)):
                             lr_schedulers[i].step()
+
+            example_tuple = (latents, batch["captions"])
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
@@ -772,6 +775,7 @@ def train(args):
                         [tokenizer1, tokenizer2],
                         [text_encoder1, text_encoder2],
                         unet,
+                        example_tuple,
                     )

                 # save the model every specified number of steps
@@ -854,6 +858,7 @@ def train(args):
             [tokenizer1, tokenizer2],
             [text_encoder1, text_encoder2],
             unet,
+            example_tuple,
         )

     is_main_process = accelerator.is_main_process
diff --git a/sdxl_train_network.py b/sdxl_train_network.py
index 83969bb1..3eafc152 100644
--- a/sdxl_train_network.py
+++ b/sdxl_train_network.py
@@ -164,8 +164,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
             noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
         return noise_pred

-    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
-        sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
+    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None):
+        sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple)


 def setup_parser() -> argparse.ArgumentParser:
diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py
index 5df739e2..de75a0aa 100644
--- a/sdxl_train_textual_inversion.py
+++ b/sdxl_train_textual_inversion.py
@@ -81,9 +81,9 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
             noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
         return noise_pred

-    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
+    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None, prompt_replacement=None):
         sdxl_train_util.sample_images(
-            accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
+            accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple, prompt_replacement
         )

     def save_weights(self, file, updated_embs, save_dtype, metadata):
diff --git a/train_db.py b/train_db.py
index e7cf3cde..d4f558b1 100644
--- a/train_db.py
+++ b/train_db.py
@@ -388,14 +388,14 @@ def train(args):
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
-
+            example_tuple = (latents, batch["captions"])
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
                 global_step += 1

                 train_util.sample_images(
-                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
+                    accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple
                 )

                 # save the model every specified number of steps
@@ -459,7 +459,7 @@ def train(args):
                     vae,
                 )

-        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+        train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)

    is_main_process = accelerator.is_main_process
    if is_main_process:
diff --git a/train_network.py b/train_network.py
index 7bf125dc..2554b230 100644
--- a/train_network.py
+++ b/train_network.py
@@ -131,8 +131,8 @@ class NetworkTrainer:
             if param.grad is not None:
                 param.grad = accelerator.reduce(param.grad, reduction="mean")

-    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
-        train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
+    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None):
+        train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple)

     def train(self, args):
         session_id = random.randint(0, 2**32)
@@ -1022,11 +1022,12 @@ class NetworkTrainer:
                     keys_scaled, mean_norm, maximum_norm = None, None, None

                 # Checks if the accelerator has performed an optimization step behind the scenes
+                example_tuple = (latents, batch["captions"])
                 if accelerator.sync_gradients:
                     progress_bar.update(1)
                     global_step += 1

-                    self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+                    self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)

                     # save the model every specified number of steps
                     if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -1082,7 +1083,7 @@ class NetworkTrainer:
             if args.save_state:
                 train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)

-            self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+            self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, example_tuple)

             # end of epoch
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 37349da7..067d44ca 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -122,9 +122,9 @@ class TextualInversionTrainer:
             noise_pred = unet(noisy_latents, timesteps, text_conds).sample
         return noise_pred

-    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
+    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple=None, prompt_replacement=None):
         train_util.sample_images(
-            accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
+            accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, example_tuple, prompt_replacement
         )

     def save_weights(self, file, updated_embs, save_dtype, metadata):
@@ -627,6 +627,7 @@ class TextualInversionTrainer:
                             index_no_updates
                         ]

+                example_tuple = (latents, captions)
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
                     progress_bar.update(1)
@@ -642,6 +643,7 @@ class TextualInversionTrainer:
                         tokenizer_or_list,
                         text_encoder_or_list,
                         unet,
+                        example_tuple,
                         prompt_replacement,
                     )
@@ -714,7 +716,6 @@ class TextualInversionTrainer:
             if args.save_state:
                 train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
-
             self.sample_images(
                 accelerator,
                 args,
@@ -725,6 +726,7 @@ class TextualInversionTrainer:
                 tokenizer_or_list,
                 text_encoder_or_list,
                 unet,
+                example_tuple,
                 prompt_replacement,
             )
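
For reference, the collage geometry assembled in sample_image_inference can be sanity-checked in isolation. The sketch below is illustrative only and not part of the patch: the image sizes are invented, and the three Image.new calls stand in for the decoded training latent, the freshly sampled image, and the draw_text_on_image banner. It reproduces the same paste arithmetic, so the expected output size can be verified offline.

# standalone layout check (hypothetical script, not part of the patch)
from PIL import Image

original_image = Image.new("RGB", (512, 512), "gray")   # stand-in for pipeline.latents_to_image(original_latent)[0]
image = Image.new("RGB", (512, 512), "white")           # stand-in for the freshly sampled image
text_image = Image.new("RGB", (image.width * 2, 40), "yellow")  # stand-in for the caption banner

new_image = Image.new("RGB", (original_image.width + image.width, original_image.height + text_image.height))
new_image.paste(original_image, (0, text_image.height))            # left: decoded training latent
new_image.paste(image, (original_image.width, text_image.height))  # right: newly sampled image
new_image.paste(text_image, (0, 0))                                # caption banner across the top
print(new_image.size)  # (1024, 552)

To exercise the new placeholder, a sample-prompts entry simply contains the literal token __caption__ (hypothetical example: "photo of __caption__ --s 28"); sample_images_common substitutes it with a non-empty caption from the current batch, overrides the requested height/width from the latent shape, and attaches the matching latent so the sample is rendered next to its training image.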