diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 11a756ea..c0a8c6c1 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2120,7 +2120,7 @@ def main(args): iter_seed = random.randint(0, 0x7FFFFFFF) # バッチ処理の関数 - def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): + def process_batch(batch: List[BatchData], distributed_state, highres_fix, highres_1st=False): batch_size = len(batch) # highres_fixの処理 @@ -2170,7 +2170,7 @@ def main(args): batch_1st.append(BatchData(is_1st_latent, global_count, base, ext_1st)) pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする - images_1st = process_batch(batch_1st, True, True) + images_1st = process_batch(batch_1st, distributed_state, True, True) # 2nd stageのバッチを作成して以下処理する logger.info("process 2nd stage") @@ -2381,60 +2381,72 @@ def main(args): return images # save image - highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" - ds_str = time.strftime("%Y%m%d", time.localtime()) - ts_str = time.strftime("%H%M%S", time.localtime()) - for i, (image, globalcount, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( - zip(images, global_counter, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) - ): - if highres_fix: - seed -= 1 # record original seed - metadata = PngInfo() - metadata.add_text("prompt", prompt) - metadata.add_text("seed", str(seed)) - metadata.add_text("sampler", args.sampler) - metadata.add_text("steps", str(steps)) - metadata.add_text("scale", str(scale)) - if negative_prompt is not None: - metadata.add_text("negative-prompt", negative_prompt) - if negative_scale is not None: - metadata.add_text("negative-scale", str(negative_scale)) - if clip_prompt is not None: - metadata.add_text("clip-prompt", clip_prompt) - if raw_prompt is not None: - metadata.add_text("raw-prompt", raw_prompt) - metadata.add_text("original-height", str(original_height)) - metadata.add_text("original-width", str(original_width)) - metadata.add_text("original-height-negative", str(original_height_negative)) - metadata.add_text("original-width-negative", str(original_width_negative)) - metadata.add_text("crop-top", str(crop_top)) - metadata.add_text("crop-left", str(crop_left)) - - if args.use_original_file_name and init_images is not None: - if type(init_images) is list: - fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + distributed_state.wait_for_everyone() + all_images=gather_object(images) + all_global_counter = gather_object(global_counter) + all_prompts = gather_object(prompts) + all_negative_prompts = gather_object(negative_prompts) + all_seeds = gather_object(seeds) + all_clip_prompts = gather_object(clip_prompts) + all_raw_prompts = gather_object(raw_prompts) + all_init_images = gather_object(init_images) + if distributed_state.is_main_process: + + highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" + ds_str = time.strftime("%Y%m%d", time.localtime()) + ts_str = time.strftime("%H%M%S", time.localtime()) + for i, (image, globalcount, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( + zip(all_images, all_global_counter, all_prompts, all_negative_prompts, all_seeds, all_clip_prompts, all_raw_prompts) + ): + if highres_fix: + seed -= 1 # record original seed + metadata = PngInfo() + metadata.add_text("prompt", prompt) + metadata.add_text("seed", str(seed)) + metadata.add_text("sampler", args.sampler) + metadata.add_text("steps", str(steps)) + metadata.add_text("scale", str(scale)) + if negative_prompt is not None: + metadata.add_text("negative-prompt", negative_prompt) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) + if clip_prompt is not None: + metadata.add_text("clip-prompt", clip_prompt) + if raw_prompt is not None: + metadata.add_text("raw-prompt", raw_prompt) + metadata.add_text("original-height", str(original_height)) + metadata.add_text("original-width", str(original_width)) + metadata.add_text("original-height-negative", str(original_height_negative)) + metadata.add_text("original-width-negative", str(original_width_negative)) + metadata.add_text("crop-top", str(crop_top)) + metadata.add_text("crop-left", str(crop_left)) + + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{globalcount:02d}_{highres_prefix}{step_first + i + 1:06d}.png" else: - fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" - elif args.sequential_file_name: - fln = f"im_{globalcount:02d}_{highres_prefix}{step_first + i + 1:06d}.png" - else: - fln = f"im_{ds_str}_{globalcount:02d}_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" - - image.save(os.path.join(args.outdir, fln), pnginfo=metadata) - - if not args.no_preview and not highres_1st and args.interactive: - try: - import cv2 - - for prompt, image in zip(prompts, images): - cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ - cv2.waitKey() - cv2.destroyAllWindows() - except ImportError: - logger.error( - "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" - ) + fln = f"im_{ds_str}_{ts_str}_{globalcount:02d}_{highres_prefix}{i:03d}_{seed}.png" + logger.info(f"Saving image {global_count}: {fln}\nPrompt: {prompt}") + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + + if not args.no_preview and not highres_1st and args.interactive: + try: + import cv2 + + for prompt, image in zip(prompts, images): + cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ + cv2.waitKey() + cv2.destroyAllWindows() + except ImportError: + logger.error( + "opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません" + ) + distributed_state.wait_for_everyone() return images # 画像生成のプロンプトが一周するまでのループ @@ -2922,7 +2934,7 @@ def main(args): logger.info(f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:\nbatch_list:") for i in range(len(batches[j])): logger.info(f"Image: {batches[j][i].global_count}\nDevice {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}") - prev_image = process_batch(batch_list[j], highres_fix)[0] + prev_image = process_batch(batch_list[j], distributed_state, highres_fix)[0] distributed_state.wait_for_everyone() #for i in range(len(data_loader)):