Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 22:35:09 +00:00)
fix sample_images type
This commit is contained in:
@@ -444,8 +444,7 @@ def train(args):
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# if args.deepspeed:
|
||||
if True:
|
||||
if args.deepspeed:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet)
|
||||
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
@@ -644,7 +643,6 @@ def train(args):
|
||||
t5_attn_mask = None
|
||||
|
||||
with accelerator.autocast():
|
||||
print("control start")
|
||||
block_samples, block_single_samples = controlnet(
|
||||
img=packed_noisy_model_input,
|
||||
img_ids=img_ids,
|
||||
@@ -656,8 +654,6 @@ def train(args):
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
print("control end")
|
||||
print("dit start")
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = flux(
|
||||
img=packed_noisy_model_input,
|
||||
@@ -763,18 +759,19 @@ def train(args):
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
optimizer_eval_fn()
|
||||
if args.save_every_n_epochs is not None:
|
||||
if accelerator.is_main_process:
|
||||
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
True,
|
||||
accelerator,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(flux),
|
||||
)
|
||||
# TODO: save cn models
|
||||
# if args.save_every_n_epochs is not None:
|
||||
# if accelerator.is_main_process:
|
||||
# flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
|
||||
# args,
|
||||
# True,
|
||||
# accelerator,
|
||||
# save_dtype,
|
||||
# epoch,
|
||||
# num_train_epochs,
|
||||
# global_step,
|
||||
# accelerator.unwrap_model(flux),
|
||||
# )
|
||||
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
||||
|
||||
@@ -235,7 +235,7 @@ def sample_image_inference(
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
|
||||
|
||||
x = x.float()
|
||||
# x = x.float() # TODO: check
|
||||
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
||||
|
||||
# latent to image
|
||||
|
||||
Reference in New Issue
Block a user