mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge pull request #1168 from gesen2egee/save_state_on_train_end
Save state on train end
This commit is contained in:
@@ -457,7 +457,7 @@ def train(args):
|
|||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
if args.save_state and is_main_process:
|
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
||||||
train_util.save_state_on_train_end(args, accelerator)
|
train_util.save_state_on_train_end(args, accelerator)
|
||||||
|
|
||||||
del accelerator # この後メモリを使うのでこれは消す
|
del accelerator # この後メモリを使うのでこれは消す
|
||||||
|
|||||||
@@ -2938,6 +2938,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する",
|
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_state_on_train_end",
|
||||||
|
action="store_true",
|
||||||
|
help="save training state additionally (including optimizer states etc.) on train end / optimizerなど学習状態も含めたstateを追加で保存する",
|
||||||
|
)
|
||||||
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
||||||
|
|
||||||
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
||||||
|
|||||||
@@ -712,7 +712,7 @@ def train(args):
|
|||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
if args.save_state: # and is_main_process:
|
if args.save_state or args.save_state_on_train_end:
|
||||||
train_util.save_state_on_train_end(args, accelerator)
|
train_util.save_state_on_train_end(args, accelerator)
|
||||||
|
|
||||||
del accelerator # この後メモリを使うのでこれは消す
|
del accelerator # この後メモリを使うのでこれは消す
|
||||||
|
|||||||
@@ -549,7 +549,7 @@ def train(args):
|
|||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
if is_main_process and args.save_state:
|
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
||||||
train_util.save_state_on_train_end(args, accelerator)
|
train_util.save_state_on_train_end(args, accelerator)
|
||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
|
|||||||
@@ -565,7 +565,7 @@ def train(args):
|
|||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
if is_main_process and args.save_state:
|
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
||||||
train_util.save_state_on_train_end(args, accelerator)
|
train_util.save_state_on_train_end(args, accelerator)
|
||||||
|
|
||||||
# del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく
|
# del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく
|
||||||
|
|||||||
@@ -444,7 +444,7 @@ def train(args):
|
|||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
if args.save_state and is_main_process:
|
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
||||||
train_util.save_state_on_train_end(args, accelerator)
|
train_util.save_state_on_train_end(args, accelerator)
|
||||||
|
|
||||||
del accelerator # この後メモリを使うのでこれは消す
|
del accelerator # この後メモリを使うのでこれは消す
|
||||||
|
|||||||
@@ -940,7 +940,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
if is_main_process and args.save_state:
|
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
||||||
train_util.save_state_on_train_end(args, accelerator)
|
train_util.save_state_on_train_end(args, accelerator)
|
||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
|
|||||||
@@ -732,7 +732,7 @@ class TextualInversionTrainer:
|
|||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
if args.save_state and is_main_process:
|
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
||||||
train_util.save_state_on_train_end(args, accelerator)
|
train_util.save_state_on_train_end(args, accelerator)
|
||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
|
|||||||
@@ -586,7 +586,7 @@ def train(args):
|
|||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
if args.save_state and is_main_process:
|
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
||||||
train_util.save_state_on_train_end(args, accelerator)
|
train_util.save_state_on_train_end(args, accelerator)
|
||||||
|
|
||||||
updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
||||||
|
|||||||
Reference in New Issue
Block a user