fix sample_images type

This commit is contained in:
minux302
2024-11-17 11:09:05 +00:00
parent b2660bbe74
commit 35778f0218
2 changed files with 15 additions and 18 deletions

View File

@@ -444,8 +444,7 @@ def train(args):
clean_memory_on_device(accelerator.device)
# if args.deepspeed: if args.deepspeed:
if True:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -644,7 +643,6 @@ def train(args):
t5_attn_mask = None
with accelerator.autocast():
print("control start")
block_samples, block_single_samples = controlnet(
img=packed_noisy_model_input,
img_ids=img_ids,
@@ -656,8 +654,6 @@ def train(args):
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
print("control end")
print("dit start")
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = flux(
img=packed_noisy_model_input,
@@ -763,18 +759,19 @@ def train(args):
accelerator.wait_for_everyone()
optimizer_eval_fn()
if args.save_every_n_epochs is not None: # TODO: save cn models
if accelerator.is_main_process: # if args.save_every_n_epochs is not None:
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise( # if accelerator.is_main_process:
args, # flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
True, # args,
accelerator, # True,
save_dtype, # accelerator,
epoch, # save_dtype,
num_train_epochs, # epoch,
global_step, # num_train_epochs,
accelerator.unwrap_model(flux), # global_step,
) # accelerator.unwrap_model(flux),
# )
flux_train_utils.sample_images(
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs

View File

@@ -235,7 +235,7 @@ def sample_image_inference(
with accelerator.autocast(), torch.no_grad():
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
x = x.float() # x = x.float() # TODO: check
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
# latent to image