mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add save/load hook to remove U-Net/TEs from state
This commit is contained in:
@@ -471,6 +471,31 @@ class NetworkTrainer:
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# before resuming make hook for saving/loading to save/load the network weights only
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
# pop weights of other models than network to save only network weights
|
||||
if accelerator.is_main_process:
|
||||
remove_indices = []
|
||||
for i,model in enumerate(models):
|
||||
if not isinstance(model, type(accelerator.unwrap_model(network))):
|
||||
remove_indices.append(i)
|
||||
for i in reversed(remove_indices):
|
||||
weights.pop(i)
|
||||
# print(f"save model hook: {len(weights)} weights will be saved")
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
# remove models except network
|
||||
remove_indices = []
|
||||
for i, model in enumerate(models):
|
||||
if not isinstance(model, type(accelerator.unwrap_model(network))):
|
||||
remove_indices.append(i)
|
||||
for i in reversed(remove_indices):
|
||||
models.pop(i)
|
||||
# print(f"load model hook: {len(models)} models will be loaded")
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
# resumeする
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user