diff --git a/train_network.py b/train_network.py index 7861e740..aa42a3bf 100644 --- a/train_network.py +++ b/train_network.py @@ -24,7 +24,7 @@ from accelerate.utils import set_seed from accelerate import Accelerator from diffusers import DDPMScheduler from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL -from library import deepspeed_utils, model_util, strategy_base, strategy_sd +from library import deepspeed_utils, model_util, sai_model_spec, strategy_base, strategy_sd import library.train_util as train_util from library.train_util import DreamBoothDataset @@ -1718,6 +1718,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) custom_train_functions.add_custom_train_arguments(parser) + sai_model_spec.add_model_spec_arguments(parser) parser.add_argument( "--cpu_offload_checkpointing",