@@ -1898,12 +1898,28 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
def add_training_arguments ( parser : argparse . ArgumentParser , support_dreambooth : bool ) :
parser . add_argument ( " --output_dir " , type = str , default = None , help = " directory to output trained model / 学習後のモデル出力先ディレクトリ " )
parser . add_argument ( " --output_name " , type = str , default = None , help = " base name of trained model file / 学習後のモデルの拡張子を除くファイル名 " )
parser . add_argument ( " --huggingface_repo_id " , type = str , default = None , help = " huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名 " )
parser . add_argument ( " --huggingface_repo_type " , type = str , default = None , help = " huggingface repo typ e to upload / huggingfaceにアップロードするリポジトリの種類 " )
parser . add_argument ( " --huggingface_path_in_repo " , type = str , default = None , help = " huggingface model path to upload files / huggingfaceにアップロードするファイルのパス " )
parser . add_argument (
" --huggingface_repo_id " , type = str , default = None , help = " huggingface repo nam e to upload / huggingfaceにアップロードするリポジトリ名 "
)
parser . add_argument (
" --huggingface_repo_type " , type = str , default = None , help = " huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類 "
)
parser . add_argument (
" --huggingface_path_in_repo " ,
type = str ,
default = None ,
help = " huggingface model path to upload files / huggingfaceにアップロードするファイルのパス " ,
)
parser . add_argument ( " --huggingface_token " , type = str , default = None , help = " huggingface token / huggingfaceのトークン " )
parser . add_argument ( " --huggingface_repo_visibility " , type = str , default = None , help = " huggingface repository visibility / huggingfaceにアップロードするリポジトリの公開設定 " )
parser . add_argument ( " --save_state_to_huggingface " , action = " store_true " , help = " save state to huggingface / huggingfaceにstateを保存する " )
parser . add_argument (
" --huggingface_repo_visibility " ,
type = str ,
default = None ,
help = " huggingface repository visibility ( ' public ' for public, ' private ' or None for private) / huggingfaceにアップロードするリポジトリの公開設定( ' public ' で公開、 ' private ' またはNoneで非公開) " ,
)
parser . add_argument (
" --save_state_to_huggingface " , action = " store_true " , help = " save state to huggingface / huggingfaceにstateを保存する "
)
parser . add_argument (
" --resume_from_huggingface " ,
action = " store_true " ,
@@ -2278,10 +2294,17 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
# region utils
def resume ( accelerator , args ) :
if args . resume :
print ( f " resume training from state: { args . resume } " )
if args . resume_from_huggingface :
def resume_from_local_or_hf_if_specified ( accelerator , args ) :
if not args . resume :
return
if not args . resume_from_huggingface :
print ( f " resume training from local state: { args . resume } " )
accelerator . load_state ( args . resume )
return
print ( f " resume training from huggingface state: { args . resume } " )
repo_id = args . resume . split ( " / " ) [ 0 ] + " / " + args . resume . split ( " / " ) [ 1 ]
path_in_repo = " / " . join ( args . resume . split ( " / " ) [ 2 : ] )
revision = None
@@ -2293,9 +2316,7 @@ def resume(accelerator, args):
repo_type = " model "
else :
path_in_repo , revision , repo_type = divided
print (
f " Downloading state from huggingface: { repo_id } / { path_in_repo } @ { revision } "
)
print( f " Downloading state from huggingface: { repo_id } / { path_in_repo } @ { revision } " )
list_files = huggingface_util . list_dir (
repo_id = repo_id ,
@@ -2318,15 +2339,11 @@ def resume(accelerator, args):
return await asyncio . get_event_loop ( ) . run_in_executor ( None , task )
loop = asyncio . get_event_loop ( )
results = loop . run_until_complete (
asyncio . gather (
* [ download ( filename = filename . rfilename ) for filename in list_files ]
)
)
results = loop . run_until_complete ( asyncio . gather ( * [ download ( filename = filename . rfilename ) for filename in list_files ] ) )
if len ( results ) == 0 :
raise ValueError ( " No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした " )
dirname = os . path . dirname ( results [ 0 ] )
accelerator . load_state ( dirname )
else :
accelerator . load_state ( args . resume )
def get_optimizer ( args , trainable_params ) :
@@ -2713,7 +2730,7 @@ def prepare_dtype(args: argparse.Namespace):
return weight_dtype , save_dtype
def load_target_model ( args : argparse . Namespace , weight_dtype , device = ' cpu' ) :
def load_target_model ( args : argparse . Namespace , weight_dtype , device = " cpu" ) :
name_or_path = args . pretrained_model_name_or_path
name_or_path = os . readlink ( name_or_path ) if os . path . islink ( name_or_path ) else name_or_path
load_stable_diffusion_format = os . path . isfile ( name_or_path ) # determine SD or Diffusers
@@ -2883,6 +2900,7 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
state_dir = os . path . join ( args . output_dir , EPOCH_STATE_NAME . format ( model_name , epoch_no ) )
accelerator . save_state ( state_dir )
if args . save_state_to_huggingface :
print ( " uploading state to huggingface. " )
huggingface_util . upload ( args , state_dir , " / " + EPOCH_STATE_NAME . format ( model_name , epoch_no ) )
last_n_epochs = args . save_last_n_epochs_state if args . save_last_n_epochs_state else args . save_last_n_epochs
@@ -2894,6 +2912,17 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
shutil . rmtree ( state_dir_old )
def save_state_on_train_end ( args : argparse . Namespace , accelerator ) :
print ( " saving last state. " )
os . makedirs ( args . output_dir , exist_ok = True )
model_name = DEFAULT_LAST_OUTPUT_NAME if args . output_name is None else args . output_name
state_dir = os . path . join ( args . output_dir , LAST_STATE_NAME . format ( model_name ) )
accelerator . save_state ( state_dir )
if args . save_state_to_huggingface :
print ( " uploading last state to huggingface. " )
huggingface_util . upload ( args , state_dir , " / " + LAST_STATE_NAME . format ( model_name ) )
def save_sd_model_on_train_end (
args : argparse . Namespace ,
src_path : str ,
@@ -2932,13 +2961,6 @@ def save_sd_model_on_train_end(
huggingface_util . upload ( args , out_dir , " / " + model_name , force_sync_upload = True )
def save_state_on_train_end ( args : argparse . Namespace , accelerator ) :
print ( " saving last state. " )
os . makedirs ( args . output_dir , exist_ok = True )
model_name = DEFAULT_LAST_OUTPUT_NAME if args . output_name is None else args . output_name
accelerator . save_state ( os . path . join ( args . output_dir , LAST_STATE_NAME . format ( model_name ) ) )
# scheduler:
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120