mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Unify controlnet parameters name and change scripts name. (#1821)
* Update sd3_train.py
* add freeze block lr
* Update train_util.py
* update
* Revert "add freeze block lr"
This reverts commit 8b1653548f.
# Conflicts:
# library/train_util.py
# sd3_train.py
* use same control net model path
* use controlnet_model_name_or_path
This commit is contained in:
@@ -265,7 +265,7 @@ def train(args):
|
|||||||
# load controlnet
|
# load controlnet
|
||||||
controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype
|
controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype
|
||||||
controlnet = flux_utils.load_controlnet(
|
controlnet = flux_utils.load_controlnet(
|
||||||
args.controlnet, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors
|
args.controlnet_model_name_or_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors
|
||||||
)
|
)
|
||||||
controlnet.train()
|
controlnet.train()
|
||||||
|
|
||||||
|
|||||||
@@ -564,7 +564,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
|||||||
)
|
)
|
||||||
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
|
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--controlnet",
|
"--controlnet_model_name_or_path",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)"
|
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)"
|
||||||
|
|||||||
@@ -184,12 +184,12 @@ def train(args):
|
|||||||
|
|
||||||
# make control net
|
# make control net
|
||||||
logger.info("make ControlNet")
|
logger.info("make ControlNet")
|
||||||
if args.controlnet_model_path:
|
if args.controlnet_model_name_or_path:
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
control_net = SdxlControlNet()
|
control_net = SdxlControlNet()
|
||||||
|
|
||||||
logger.info(f"load ControlNet from {args.controlnet_model_path}")
|
logger.info(f"load ControlNet from {args.controlnet_model_name_or_path}")
|
||||||
filename = args.controlnet_model_path
|
filename = args.controlnet_model_name_or_path
|
||||||
if os.path.splitext(filename)[1] == ".safetensors":
|
if os.path.splitext(filename)[1] == ".safetensors":
|
||||||
state_dict = load_file(filename)
|
state_dict = load_file(filename)
|
||||||
else:
|
else:
|
||||||
@@ -675,7 +675,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
sdxl_train_util.add_sdxl_training_arguments(parser)
|
sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--controlnet_model_path",
|
"--controlnet_model_name_or_path",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="controlnet model name or path / controlnetのモデル名またはパス",
|
help="controlnet model name or path / controlnetのモデル名またはパス",
|
||||||
|
|||||||
Reference in New Issue
Block a user