Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-16 17:02:45 +00:00)
Update lumina_train.py
lumina_train.py (112)
@@ -361,65 +361,73 @@ def train(args):
     # 学習に必要なクラスを準備する (prepare the classes needed for training)
     accelerator.print("prepare optimizer, data loader etc.")

-    if args.blockwise_fused_optimizers:
-        # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
-        # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
-        # This balances memory usage and management complexity.
+    # if args.blockwise_fused_optimizers:
+    #     # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
+    #     # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
+    #     # This balances memory usage and management complexity.

-        # split params into groups. currently different learning rates are not supported
-        grouped_params = []
-        param_group = {}
-        for group in params_to_optimize:
-            named_parameters = list(nextdit.named_parameters())
-            assert len(named_parameters) == len(
-                group["params"]
-            ), "number of parameters does not match"
-            for p, np in zip(group["params"], named_parameters):
-                # determine target layer and block index for each parameter
-                block_type = "other"  # double, single or other
-                if np[0].startswith("double_blocks"):
-                    block_index = int(np[0].split(".")[1])
-                    block_type = "double"
-                elif np[0].startswith("single_blocks"):
-                    block_index = int(np[0].split(".")[1])
-                    block_type = "single"
-                else:
-                    block_index = -1
+    #     # split params into groups. currently different learning rates are not supported
+    #     grouped_params = []
+    #     param_group = {}
+    #     for group in params_to_optimize:
+    #         named_parameters = list(nextdit.named_parameters())
+    #         assert len(named_parameters) == len(
+    #             group["params"]
+    #         ), "number of parameters does not match"
+    #         for p, np in zip(group["params"], named_parameters):
+    #             # determine target layer and block index for each parameter
+    #             block_type = "other"  # double, single or other
+    #             if np[0].startswith("double_blocks"):
+    #                 block_index = int(np[0].split(".")[1])
+    #                 block_type = "double"
+    #             elif np[0].startswith("single_blocks"):
+    #                 block_index = int(np[0].split(".")[1])
+    #                 block_type = "single"
+    #             else:
+    #                 block_index = -1

-                param_group_key = (block_type, block_index)
-                if param_group_key not in param_group:
-                    param_group[param_group_key] = []
-                param_group[param_group_key].append(p)
+    #             param_group_key = (block_type, block_index)
+    #             if param_group_key not in param_group:
+    #                 param_group[param_group_key] = []
+    #             param_group[param_group_key].append(p)

-        block_types_and_indices = []
-        for param_group_key, param_group in param_group.items():
-            block_types_and_indices.append(param_group_key)
-            grouped_params.append({"params": param_group, "lr": args.learning_rate})
+    #     block_types_and_indices = []
+    #     for param_group_key, param_group in param_group.items():
+    #         block_types_and_indices.append(param_group_key)
+    #         grouped_params.append({"params": param_group, "lr": args.learning_rate})

-            num_params = 0
-            for p in param_group:
-                num_params += p.numel()
-            accelerator.print(f"block {param_group_key}: {num_params} parameters")
+    #         num_params = 0
+    #         for p in param_group:
+    #             num_params += p.numel()
+    #         accelerator.print(f"block {param_group_key}: {num_params} parameters")

-        # prepare optimizers for each group
-        optimizers = []
-        for group in grouped_params:
-            _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
-            optimizers.append(optimizer)
-        optimizer = optimizers[0]  # avoid error in the following code
+    #     # prepare optimizers for each group
+    #     optimizers = []
+    #     for group in grouped_params:
+    #         _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
+    #         optimizers.append(optimizer)
+    #     optimizer = optimizers[0]  # avoid error in the following code

-        logger.info(
-            f"using {len(optimizers)} optimizers for blockwise fused optimizers"
-        )
+    #     logger.info(
+    #         f"using {len(optimizers)} optimizers for blockwise fused optimizers"
+    #     )

-        if train_util.is_schedulefree_optimizer(optimizers[0], args):
-            raise ValueError(
-                "Schedule-free optimizer is not supported with blockwise fused optimizers"
-            )
-        optimizer_train_fn = lambda: None  # dummy function
-        optimizer_eval_fn = lambda: None  # dummy function
-    else:
-        _, _, optimizer = train_util.get_optimizer(
+    #     if train_util.is_schedulefree_optimizer(optimizers[0], args):
+    #         raise ValueError(
+    #             "Schedule-free optimizer is not supported with blockwise fused optimizers"
+    #         )
+    #     optimizer_train_fn = lambda: None  # dummy function
+    #     optimizer_eval_fn = lambda: None  # dummy function
+    # else:
+    #     _, _, optimizer = train_util.get_optimizer(
+    #         args, trainable_params=params_to_optimize
+    #     )
+    #     optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
+    #         optimizer, args
+    #     )

+    #Currently when using blockwise_fused_optimizers the weight of model is not updated.
+    _, _, optimizer = train_util.get_optimizer(
+        args, trainable_params=params_to_optimize
+    )
+    optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
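For context on the branch that this commit comments out: the linked PyTorch tutorial fuses the optimizer step into the backward pass by registering post-accumulate-grad hooks on each parameter, and the removed code adapts that idea by keeping one optimizer per model block so a block's gradients can be released as soon as that block has been updated. Below is a minimal, self-contained sketch of that pattern; the toy model, block keys, learning rate, and bookkeeping are illustrative assumptions, not code from lumina_train.py or train_util.

import torch
import torch.nn as nn

# Toy stand-in for the real model (hypothetical, for illustration only).
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Group parameters by "block"; here the key is simply the top-level submodule index,
# in the same spirit as the double_blocks.N / single_blocks.N grouping in the diff above.
blocks = {}
for name, param in model.named_parameters():
    blocks.setdefault(name.split(".")[0], []).append(param)

# One optimizer per block instead of a single optimizer for the whole model.
optimizers = {key: torch.optim.AdamW(params, lr=1e-4) for key, params in blocks.items()}

param_to_block = {param: key for key, params in blocks.items() for param in params}
pending = {key: len(params) for key, params in blocks.items()}  # grads not yet ready per block

def make_hook(param):
    def hook(_ready_param):
        key = param_to_block[param]
        pending[key] -= 1
        if pending[key] == 0:
            # All gradients of this block are ready: update the block now and free its
            # gradients, instead of holding every gradient until one global optimizer.step().
            optimizers[key].step()
            optimizers[key].zero_grad(set_to_none=True)
            pending[key] = len(blocks[key])  # reset for the next iteration
    return hook

# register_post_accumulate_grad_hook (PyTorch >= 2.1) fires once a parameter's gradient
# has been fully accumulated during backward().
for param in model.parameters():
    if param.requires_grad:
        param.register_post_accumulate_grad_hook(make_hook(param))

# One training step: backward() applies the per-block updates through the hooks,
# so there is no separate optimizer.step() / zero_grad() afterwards.
x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

The pending-count bookkeeping above is only one way to decide when a block is ready to step; the property that matters is that each block's gradients are freed immediately after its update instead of being held for a single global optimizer.step(), which is where the memory saving comes from. Note that this simple sketch does not handle gradient accumulation across micro-batches.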