mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add stochastic rounding, fix single block
This commit is contained in:
@@ -277,7 +277,10 @@ def train(args):
|
||||
training_models = []
|
||||
params_to_optimize = []
|
||||
training_models.append(flux)
|
||||
params_to_optimize.append({"params": list(flux.parameters()), "lr": args.learning_rate})
|
||||
name_and_params = list(flux.named_parameters())
|
||||
# single param group for now
|
||||
params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate})
|
||||
param_names = [[n for n, _ in name_and_params]]
|
||||
|
||||
# calculate number of trainable parameters
|
||||
n_params = 0
|
||||
@@ -433,17 +436,89 @@ def train(args):
|
||||
import library.adafactor_fused
|
||||
|
||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||
for param_group in optimizer.param_groups:
|
||||
for parameter in param_group["params"]:
|
||||
|
||||
double_blocks_to_swap = args.double_blocks_to_swap
|
||||
single_blocks_to_swap = args.single_blocks_to_swap
|
||||
num_double_blocks = len(flux.double_blocks)
|
||||
num_single_blocks = len(flux.single_blocks)
|
||||
handled_double_block_indices = set()
|
||||
handled_single_block_indices = set()
|
||||
|
||||
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
|
||||
for parameter, param_name in zip(param_group["params"], param_name_group):
|
||||
if parameter.requires_grad:
|
||||
grad_hook = None
|
||||
|
||||
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
    """Per-parameter hook: optionally clip, call ``optimizer.step_param``, clear the grad.

    ``param_group`` is bound as a default argument so the hook keeps the group
    that was current when the hook was defined.
    """
    # Clip only when gradients are being synchronized and clipping is enabled
    # (a max_grad_norm of 0.0 disables it).
    if accelerator.sync_gradients:
        max_norm = args.max_grad_norm
        if max_norm != 0.0:
            accelerator.clip_grad_norm_(tensor, max_norm)

    optimizer.step_param(tensor, param_group)

    # Drop the gradient as soon as this parameter has been stepped.
    tensor.grad = None
|
||||
if double_blocks_to_swap:
|
||||
if param_name.startswith("double_blocks"):
|
||||
block_idx = int(param_name.split(".")[1])
|
||||
if (
|
||||
block_idx not in handled_double_block_indices
|
||||
and block_idx >= (num_double_blocks - double_blocks_to_swap) - 1
|
||||
and block_idx < num_double_blocks - 1
|
||||
):
|
||||
# swap next (already backpropagated) block
|
||||
handled_double_block_indices.add(block_idx)
|
||||
block_idx_cpu = block_idx + 1
|
||||
block_idx_cuda = double_blocks_to_swap - (num_double_blocks - block_idx_cpu)
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(__grad_hook)
|
||||
# create swap hook
|
||||
def create_double_swap_grad_hook(bidx, bidx_cuda):
    """Build a grad hook that steps one parameter and then swaps two double blocks.

    Args:
        bidx: index into ``flux.double_blocks`` of the block to move to the CPU.
        bidx_cuda: index into ``flux.double_blocks`` of the block to move to
            ``accelerator.device``.

    Returns:
        A hook taking the parameter tensor, intended for
        ``register_post_accumulate_grad_hook``.
    """

    # Bind the current param_group as a default argument, matching the fallback
    # hook. A plain closure would read the enclosing loop variable when the hook
    # fires, i.e. the last group iterated, not the one this parameter belongs to.
    def __grad_hook(tensor: torch.Tensor, param_group=param_group):
        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
            accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
        optimizer.step_param(tensor, param_group)
        tensor.grad = None

        # swap blocks if necessary
        flux.double_blocks[bidx].to("cpu")
        flux.double_blocks[bidx_cuda].to(accelerator.device)
        # print(f"Move double block {bidx} to cpu and {bidx_cuda} to device")

    return __grad_hook
|
||||
|
||||
grad_hook = create_double_swap_grad_hook(block_idx_cpu, block_idx_cuda)
|
||||
if single_blocks_to_swap:
|
||||
if param_name.startswith("single_blocks"):
|
||||
block_idx = int(param_name.split(".")[1])
|
||||
if (
|
||||
block_idx not in handled_single_block_indices
|
||||
and block_idx >= (num_single_blocks - single_blocks_to_swap) - 1
|
||||
and block_idx < num_single_blocks - 1
|
||||
):
|
||||
handled_single_block_indices.add(block_idx)
|
||||
block_idx_cpu = block_idx + 1
|
||||
block_idx_cuda = single_blocks_to_swap - (num_single_blocks - block_idx_cpu)
|
||||
# print(param_name, block_idx_cpu, block_idx_cuda)
|
||||
|
||||
# create swap hook
|
||||
def create_single_swap_grad_hook(bidx, bidx_cuda):
    """Build a grad hook that steps one parameter and then swaps two single blocks.

    Args:
        bidx: index into ``flux.single_blocks`` of the block to move to the CPU.
        bidx_cuda: index into ``flux.single_blocks`` of the block to move to
            ``accelerator.device``.

    Returns:
        A hook taking the parameter tensor, intended for
        ``register_post_accumulate_grad_hook``.
    """

    # Bind the current param_group as a default argument, matching the fallback
    # hook. A plain closure would read the enclosing loop variable when the hook
    # fires, i.e. the last group iterated, not the one this parameter belongs to.
    def __grad_hook(tensor: torch.Tensor, param_group=param_group):
        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
            accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
        optimizer.step_param(tensor, param_group)
        tensor.grad = None

        # swap blocks if necessary
        flux.single_blocks[bidx].to("cpu")
        flux.single_blocks[bidx_cuda].to(accelerator.device)
        # print(f"Move single block {bidx} to cpu and {bidx_cuda} to device")

    return __grad_hook
|
||||
|
||||
grad_hook = create_single_swap_grad_hook(block_idx_cpu, block_idx_cuda)
|
||||
|
||||
if grad_hook is None:
|
||||
|
||||
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
    """Fallback hook for parameters that do not trigger a block swap.

    Optionally clips the gradient, calls ``optimizer.step_param`` for this
    parameter, then clears its gradient. ``param_group`` is bound as a default
    argument so the hook keeps the group that was current at definition time.
    """
    # Clip only when gradients are being synchronized and clipping is enabled
    # (a max_grad_norm of 0.0 disables it).
    if accelerator.sync_gradients:
        max_norm = args.max_grad_norm
        if max_norm != 0.0:
            accelerator.clip_grad_norm_(tensor, max_norm)

    optimizer.step_param(tensor, param_group)

    # Drop the gradient as soon as this parameter has been stepped.
    tensor.grad = None
|
||||
|
||||
grad_hook = __grad_hook
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(grad_hook)
|
||||
|
||||
elif args.blockwise_fused_optimizers:
|
||||
# prepare for additional optimizers and lr schedulers
|
||||
|
||||
Reference in New Issue
Block a user