fix sample_images type

This commit is contained in:
minux302
2024-11-17 11:09:05 +00:00
parent b2660bbe74
commit 35778f0218
2 changed files with 15 additions and 18 deletions

View File

@@ -444,8 +444,7 @@ def train(args):
clean_memory_on_device(accelerator.device)
# if args.deepspeed:
if True:
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -644,7 +643,6 @@ def train(args):
t5_attn_mask = None
with accelerator.autocast():
print("control start")
block_samples, block_single_samples = controlnet(
img=packed_noisy_model_input,
img_ids=img_ids,
@@ -656,8 +654,6 @@ def train(args):
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
print("control end")
print("dit start")
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = flux(
img=packed_noisy_model_input,
@@ -763,18 +759,19 @@ def train(args):
accelerator.wait_for_everyone()
optimizer_eval_fn()
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
args,
True,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(flux),
)
# TODO: save cn models
# if args.save_every_n_epochs is not None:
# if accelerator.is_main_process:
# flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
# args,
# True,
# accelerator,
# save_dtype,
# epoch,
# num_train_epochs,
# global_step,
# accelerator.unwrap_model(flux),
# )
flux_train_utils.sample_images(
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs