mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
work cn load and validation
This commit is contained in:
@@ -266,7 +266,7 @@ def train(args):
|
||||
flux.to(accelerator.device)
|
||||
|
||||
# load controlnet
|
||||
controlnet = flux_utils.load_controlnet()
|
||||
controlnet = flux_utils.load_controlnet(args.controlnet, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
|
||||
controlnet.train()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
@@ -568,7 +568,7 @@ def train(args):
|
||||
|
||||
# For --sample_at_first
|
||||
optimizer_eval_fn()
|
||||
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
|
||||
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet)
|
||||
optimizer_train_fn()
|
||||
if len(accelerator.trackers) > 0:
|
||||
# log empty object to commit the sample images to wandb
|
||||
@@ -718,7 +718,7 @@ def train(args):
|
||||
|
||||
optimizer_eval_fn()
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
||||
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
@@ -774,7 +774,7 @@ def train(args):
|
||||
# )
|
||||
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
||||
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
|
||||
)
|
||||
optimizer_train_fn()
|
||||
|
||||
@@ -850,18 +850,6 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--controlnet_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="controlnet model name or path / controlnetのモデル名またはパス",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--conditioning_data_dir",
|
||||
# type=str,
|
||||
# default=None,
|
||||
# help="conditioning data directory / 条件付けデータのディレクトリ",
|
||||
# )
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user