From 806a6237fb44557740c27cb51d15b8d837c26dce Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 18 Feb 2024 21:57:16 +0900 Subject: [PATCH] minor fixes --- README.md | 2 +- stable_cascade_gen_img.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 03b33c36..c9a13afc 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ The code for training the Text Encoder is also written, but it is untested. ### Command line sample ```batch -accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 stable_cascade_train_stage_c.py --mixed_precision bf16 --save_precision bf16 --max_data_loader_n_workers 2 --persistent_data_loader_workers --gradient_checkpointing --learning_rate 1e-4 --optimizer_type adafactor --optimizer_args "scale_parameter=False" "relative_step=False" "warmup_init=False" --max_train_epochs 10 --save_every_n_epochs 1 --save_precision bf16 --output_dir ../output --output_name sc_test - --stage_c_checkpoint_path ../models/stage_c_bf16.safetensors --effnet_checkpoint_path ../models/effnet_encoder.safetensors --previewer_checkpoint_path ../models/previewer.safetensors --dataset_config ../dataset/config_bs1.toml --sample_every_n_epochs 1 --sample_prompts ../dataset/prompts.txt +accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 stable_cascade_train_stage_c.py --mixed_precision bf16 --save_precision bf16 --max_data_loader_n_workers 2 --persistent_data_loader_workers --gradient_checkpointing --learning_rate 1e-4 --optimizer_type adafactor --optimizer_args "scale_parameter=False" "relative_step=False" "warmup_init=False" --max_train_epochs 10 --save_every_n_epochs 1 --save_precision bf16 --output_dir ../output --output_name sc_test - --stage_c_checkpoint_path ../models/stage_c_bf16.safetensors --effnet_checkpoint_path ../models/effnet_encoder.safetensors --previewer_checkpoint_path ../models/previewer.safetensors --dataset_config ../dataset/config_bs1.toml --sample_every_n_epochs 1 --sample_prompts ../dataset/prompts.txt --adaptive_loss_weight ``` ### About the dataset for fine tuning diff --git a/stable_cascade_gen_img.py b/stable_cascade_gen_img.py index abffbe60..b7e5fe4e 100644 --- a/stable_cascade_gen_img.py +++ b/stable_cascade_gen_img.py @@ -59,6 +59,13 @@ def main(args): stage_a = sc_utils.load_stage_a_model(args.stage_a_checkpoint_path, dtype=dtype, device=loading_device) stage_a.eval().requires_grad_(False) + # previewer + if args.previewer_checkpoint_path is not None: + previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=dtype, device=loading_device) + previewer.eval().requires_grad_(False) + else: + previewer = None + # 謎のクラス gdf gdf_c = sc.GDF( schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]), @@ -221,6 +228,18 @@ def main(args): conditions_b["effnet"] = sampled_c unconditions_b["effnet"] = torch.zeros_like(sampled_c) + if previewer is not None: + with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype): + preview = previewer(sampled_c) + preview = preview.clamp(0, 1) + preview = preview.permute(0, 2, 3, 1).squeeze(0) + preview = preview.detach().float().cpu().numpy() + preview = Image.fromarray((preview * 255).astype(np.uint8)) + + timestamp_str = time.strftime("%Y%m%d_%H%M%S") + os.makedirs(args.outdir, exist_ok=True) + preview.save(os.path.join(args.outdir, f"preview_{timestamp_str}.png")) + if args.lowvram: generator_c = generator_c.to(loading_device) device_utils.clean_memory_on_device(device) @@ -274,6 +293,7 @@ if __name__ == "__main__": sc_utils.add_stage_a_arguments(parser) sc_utils.add_stage_b_arguments(parser) sc_utils.add_stage_c_arguments(parser) + sc_utils.add_previewer_arguments(parser) sc_utils.add_text_model_arguments(parser) parser.add_argument("--bf16", action="store_true") parser.add_argument("--fp16", action="store_true")