mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix sample_images type
This commit is contained in:
@@ -444,8 +444,7 @@ def train(args):
|
|||||||
|
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
# if args.deepspeed:
|
if args.deepspeed:
|
||||||
if True:
|
|
||||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet)
|
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet)
|
||||||
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
|
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
|
||||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
@@ -644,7 +643,6 @@ def train(args):
|
|||||||
t5_attn_mask = None
|
t5_attn_mask = None
|
||||||
|
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
print("control start")
|
|
||||||
block_samples, block_single_samples = controlnet(
|
block_samples, block_single_samples = controlnet(
|
||||||
img=packed_noisy_model_input,
|
img=packed_noisy_model_input,
|
||||||
img_ids=img_ids,
|
img_ids=img_ids,
|
||||||
@@ -656,8 +654,6 @@ def train(args):
|
|||||||
guidance=guidance_vec,
|
guidance=guidance_vec,
|
||||||
txt_attention_mask=t5_attn_mask,
|
txt_attention_mask=t5_attn_mask,
|
||||||
)
|
)
|
||||||
print("control end")
|
|
||||||
print("dit start")
|
|
||||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||||
model_pred = flux(
|
model_pred = flux(
|
||||||
img=packed_noisy_model_input,
|
img=packed_noisy_model_input,
|
||||||
@@ -763,18 +759,19 @@ def train(args):
|
|||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
optimizer_eval_fn()
|
optimizer_eval_fn()
|
||||||
if args.save_every_n_epochs is not None:
|
# TODO: save cn models
|
||||||
if accelerator.is_main_process:
|
# if args.save_every_n_epochs is not None:
|
||||||
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
|
# if accelerator.is_main_process:
|
||||||
args,
|
# flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
|
||||||
True,
|
# args,
|
||||||
accelerator,
|
# True,
|
||||||
save_dtype,
|
# accelerator,
|
||||||
epoch,
|
# save_dtype,
|
||||||
num_train_epochs,
|
# epoch,
|
||||||
global_step,
|
# num_train_epochs,
|
||||||
accelerator.unwrap_model(flux),
|
# global_step,
|
||||||
)
|
# accelerator.unwrap_model(flux),
|
||||||
|
# )
|
||||||
|
|
||||||
flux_train_utils.sample_images(
|
flux_train_utils.sample_images(
|
||||||
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
||||||
|
|||||||
@@ -235,7 +235,7 @@ def sample_image_inference(
|
|||||||
with accelerator.autocast(), torch.no_grad():
|
with accelerator.autocast(), torch.no_grad():
|
||||||
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
|
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
|
||||||
|
|
||||||
x = x.float()
|
# x = x.float() # TODO: check
|
||||||
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
||||||
|
|
||||||
# latent to image
|
# latent to image
|
||||||
|
|||||||
Reference in New Issue
Block a user