Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 22:35:09 +00:00
add grad_hook after restore state (closes #1344)
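This change moves the checkpoint restore (train_util.resume_from_local_or_hf_if_specified) in front of the --fused_backward_pass setup, so the per-parameter gradient hooks are registered only after the optimizer state has been restored. A minimal sketch of the resulting ordering, assuming an Adafactor variant patched with a step_param method as library.adafactor_fused provides; the helper name and call sites are illustrative:

    import torch

    def register_fused_grad_hooks(accelerator, args, optimizer):
        # 1) restore model/optimizer/RNG state first, e.g. via
        #    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

        # 2) only then register the fused backward pass's per-parameter hooks
        for param_group in optimizer.param_groups:
            for parameter in param_group["params"]:
                if parameter.requires_grad:

                    def grad_hook(p: torch.Tensor, param_group=param_group):
                        # step this parameter immediately, then free its gradient
                        optimizer.step_param(p, param_group)  # assumed: patched in by library.adafactor_fused
                        p.grad = None

                    parameter.register_post_accumulate_grad_hook(grad_hook)

Registering the hooks before the restore is the ordering that #1344 reported as broken.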
@@ -481,6 +481,26 @@ def train(args):
     text_encoder2 = accelerator.prepare(text_encoder2)
     optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

+    # Move the Text Encoders to CPU when caching their outputs
+    if args.cache_text_encoder_outputs:
+        # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
+        text_encoder1.to("cpu", dtype=torch.float32)
+        text_encoder2.to("cpu", dtype=torch.float32)
+        clean_memory_on_device(accelerator.device)
+    else:
+        # make sure Text Encoders are on GPU
+        text_encoder1.to(accelerator.device)
+        text_encoder2.to(accelerator.device)
+
+    # Experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
+    if args.full_fp16:
+        # During deepspeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the deepspeed engine does.
+        # -> But we think it's ok to patch the accelerator even if deepspeed is enabled.
+        train_util.patch_accelerator_for_fp16_training(accelerator)
+
+    # Resume if specified
+    train_util.resume_from_local_or_hf_if_specified(accelerator, args)
+
     if args.fused_backward_pass:
         # use fused optimizer for backward pass: other optimizers will be supported in the future
         import library.adafactor_fused
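For context, the "restore state" step corresponds to accelerate's standard save/load state round trip; a minimal sketch with an illustrative path (the real scripts drive this through --resume and resume_from_local_or_hf_if_specified):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    model, optimizer = accelerator.prepare(model, optimizer)

    # periodically during training:
    accelerator.save_state("state/at-step-1000")  # writes model, optimizer and RNG state

    # on resume: load *before* registering any per-parameter grad hooks
    accelerator.load_state("state/at-step-1000")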
@@ -532,26 +552,6 @@ def train(args):
             parameter_optimizer_map[parameter] = opt_idx
             num_parameters_per_group[opt_idx] += 1

-    # Move the Text Encoders to CPU when caching their outputs
-    if args.cache_text_encoder_outputs:
-        # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
-        text_encoder1.to("cpu", dtype=torch.float32)
-        text_encoder2.to("cpu", dtype=torch.float32)
-        clean_memory_on_device(accelerator.device)
-    else:
-        # make sure Text Encoders are on GPU
-        text_encoder1.to(accelerator.device)
-        text_encoder2.to(accelerator.device)
-
-    # Experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
-    if args.full_fp16:
-        # During deepspeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the deepspeed engine does.
-        # -> But we think it's ok to patch the accelerator even if deepspeed is enabled.
-        train_util.patch_accelerator_for_fp16_training(accelerator)
-
-    # Resume if specified
-    train_util.resume_from_local_or_hf_if_specified(accelerator, args)
-
     # Calculate the number of epochs
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
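The epoch arithmetic in this context counts optimizer updates rather than batches; a worked example with assumed numbers:

    import math

    # assumed: 1,000 batches per epoch, gradient accumulation of 4, max_train_steps of 10,000
    num_update_steps_per_epoch = math.ceil(1000 / 4)                   # 250 updates per epoch
    num_train_epochs = math.ceil(10_000 / num_update_steps_per_epoch)  # 40 epochs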
@@ -589,7 +589,11 @@ def train(args):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
+        accelerator.init_trackers(
+            "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
+            config=train_util.get_sanitized_config_or_none(args),
+            init_kwargs=init_kwargs,
+        )

     # For --sample_at_first
     sdxl_train_util.sample_images(
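The init_kwargs dict passed to accelerator.init_trackers maps each tracker name to its init arguments, which is why a --log_tracker_config file can replace it wholesale; a sketch of an equivalent TOML override (keys illustrative, mirroring the wandb entry set in the context above):

    import toml

    # what a --log_tracker_config file might contain (inlined here for the example)
    config_text = """
    [wandb]
    name = "sdxl-finetune-run1"
    """
    init_kwargs = toml.loads(config_text)  # {"wandb": {"name": "sdxl-finetune-run1"}}
    # accelerator.init_trackers("finetuning", config=None, init_kwargs=init_kwargs)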