From 35778f021897796410372aed8540547ba317c2a3 Mon Sep 17 00:00:00 2001 From: minux302 Date: Sun, 17 Nov 2024 11:09:05 +0000 Subject: [PATCH] fix sample_images type --- flux_train_control_net.py | 31 ++++++++++++++----------------- library/flux_train_utils.py | 2 +- 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 205ff6b6..791900d1 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -444,8 +444,7 @@ def train(args): clean_memory_on_device(accelerator.device) - # if args.deepspeed: - if True: + if args.deepspeed: ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -644,7 +643,6 @@ def train(args): t5_attn_mask = None with accelerator.autocast(): - print("control start") block_samples, block_single_samples = controlnet( img=packed_noisy_model_input, img_ids=img_ids, @@ -656,8 +654,6 @@ def train(args): guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) - print("control end") - print("dit start") # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) model_pred = flux( img=packed_noisy_model_input, @@ -763,18 +759,19 @@ def train(args): accelerator.wait_for_everyone() optimizer_eval_fn() - if args.save_every_n_epochs is not None: - if accelerator.is_main_process: - flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( - args, - True, - accelerator, - save_dtype, - epoch, - num_train_epochs, - global_step, - accelerator.unwrap_model(flux), - ) + # TODO: save cn models + # if args.save_every_n_epochs is not None: + # if accelerator.is_main_process: + # flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( + # args, + # True, + # accelerator, + # save_dtype, + # epoch, + # num_train_epochs, + # global_step, + # accelerator.unwrap_model(flux), + # ) flux_train_utils.sample_images( accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index d82bde91..de2ee030 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -235,7 +235,7 @@ def sample_image_inference( with accelerator.autocast(), torch.no_grad(): x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) - x = x.float() + # x = x.float() # TODO: check x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width) # latent to image