work cn load and validation

This commit is contained in:
minux302
2024-11-18 12:47:01 +00:00
parent 35778f0218
commit 4dd4cd6ec8
4 changed files with 37 additions and 32 deletions

View File

@@ -266,7 +266,7 @@ def train(args):
flux.to(accelerator.device)
# load controlnet
controlnet = flux_utils.load_controlnet()
controlnet = flux_utils.load_controlnet(args.controlnet, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
controlnet.train()
if args.gradient_checkpointing:
@@ -568,7 +568,7 @@ def train(args):
# For --sample_at_first
optimizer_eval_fn()
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet)
optimizer_train_fn()
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
@@ -718,7 +718,7 @@ def train(args):
optimizer_eval_fn()
flux_train_utils.sample_images(
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
)
# 指定ステップごとにモデルを保存
@@ -774,7 +774,7 @@ def train(args):
# )
flux_train_utils.sample_images(
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
)
optimizer_train_fn()
@@ -850,18 +850,6 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
)
parser.add_argument(
"--controlnet_model_name_or_path",
type=str,
default=None,
help="controlnet model name or path / controlnetのモデル名またはパス",
)
# parser.add_argument(
# "--conditioning_data_dir",
# type=str,
# default=None,
# help="conditioning data directory / 条件付けデータのディレクトリ",
# )
return parser