mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix sample_images type
This commit is contained in:
@@ -444,8 +444,7 @@ def train(args):
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# if args.deepspeed:
|
||||
if True:
|
||||
if args.deepspeed:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet)
|
||||
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
@@ -644,7 +643,6 @@ def train(args):
|
||||
t5_attn_mask = None
|
||||
|
||||
with accelerator.autocast():
|
||||
print("control start")
|
||||
block_samples, block_single_samples = controlnet(
|
||||
img=packed_noisy_model_input,
|
||||
img_ids=img_ids,
|
||||
@@ -656,8 +654,6 @@ def train(args):
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
print("control end")
|
||||
print("dit start")
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = flux(
|
||||
img=packed_noisy_model_input,
|
||||
@@ -763,18 +759,19 @@ def train(args):
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
optimizer_eval_fn()
|
||||
if args.save_every_n_epochs is not None:
|
||||
if accelerator.is_main_process:
|
||||
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
True,
|
||||
accelerator,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(flux),
|
||||
)
|
||||
# TODO: save cn models
|
||||
# if args.save_every_n_epochs is not None:
|
||||
# if accelerator.is_main_process:
|
||||
# flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
|
||||
# args,
|
||||
# True,
|
||||
# accelerator,
|
||||
# save_dtype,
|
||||
# epoch,
|
||||
# num_train_epochs,
|
||||
# global_step,
|
||||
# accelerator.unwrap_model(flux),
|
||||
# )
|
||||
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
||||
|
||||
Reference in New Issue
Block a user