Mirror of https://github.com/kohya-ss/sd-scripts.git
Merge 4d24b71c16 into 5462a6bb24
@@ -854,7 +854,7 @@ def get_noisy_model_input_and_timesteps(
         mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
         t = time_shift(mu, 1.0, t)

-        timesteps = t * 1000.0
+        timesteps = 1 - t * 1000.0
         t = t.view(-1, 1, 1, 1)
         noisy_model_input = (1 - t) * noise + t * latents
     else:
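For reference, the branch above implements the usual rectified-flow noising: latents and noise are linearly blended by a per-sample t in [0, 1], and t is separately rescaled into the 0-1000 timestep range passed to the model. A minimal, self-contained sketch of that interpolation (plain tensors only; get_lin_function and time_shift from the hunk are not reimplemented here):

import torch

def make_noisy_model_input(latents: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # t holds one value per sample in [0, 1]; broadcast it over (C, H, W)
    t = t.view(-1, 1, 1, 1)
    # same interpolation as the hunk: pure noise at t=0, clean latents at t=1
    return (1 - t) * noise + t * latents

latents = torch.randn(2, 16, 32, 32)  # toy batch of latents
noise = torch.randn_like(latents)
t = torch.rand(2)
noisy_model_input = make_noisy_model_input(latents, noise, t)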
@@ -1060,6 +1060,7 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
         default=1.0,
         help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
     )
+
     parser.add_argument(
         "--model_prediction_type",
         choices=["raw", "additive", "sigma_scaled"],
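This hunk sits inside the shared argparse setup for the Lumina trainers. A self-contained sketch of the same pattern follows; --sigmoid_scale is assumed here as the name of the argument whose default and help text appear above (the diff does not show it), and the default and help text for --model_prediction_type are placeholders:

import argparse

parser = argparse.ArgumentParser()
# scale factor that only applies when the sigmoid timestep sampler is selected (assumed flag name)
parser.add_argument(
    "--sigmoid_scale",
    type=float,
    default=1.0,
    help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid").',
)
# how the raw network output is interpreted before the loss (placeholder default and help text)
parser.add_argument(
    "--model_prediction_type",
    choices=["raw", "additive", "sigma_scaled"],
    default="raw",
    help="How to interpret the model output before computing the loss.",
)

args = parser.parse_args(["--model_prediction_type", "sigma_scaled"])
print(args.sigmoid_scale, args.model_prediction_type)  # 1.0 sigma_scaled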
@@ -5525,6 +5525,9 @@ def prepare_accelerator(args: argparse.Namespace):
             if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
             else None
         ),
+        (
+            DistributedDataParallelKwargs(find_unused_parameters=True)
+        ),
     ]
     kwargs_handlers = [i for i in kwargs_handlers if i is not None]
     deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
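The three added lines register DistributedDataParallelKwargs(find_unused_parameters=True) as an additional kwargs handler, which accelerate forwards to torch DistributedDataParallel in multi-GPU runs. A minimal sketch of the pattern, assuming only that accelerate is installed (the None filtering mirrors the list comprehension in the hunk; in the hunk itself the new handler appears to be added unconditionally):

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

find_unused = True  # would come from a CLI flag in the real script
kwargs_handlers = [
    # handlers that are not wanted stay None and are filtered out below
    DistributedDataParallelKwargs(find_unused_parameters=True) if find_unused else None,
]
kwargs_handlers = [k for k in kwargs_handlers if k is not None]

accelerator = Accelerator(kwargs_handlers=kwargs_handlers)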
lumina_train.py (120 changed lines)
@@ -361,70 +361,78 @@ def train(args):
     # 学習に必要なクラスを準備する
     accelerator.print("prepare optimizer, data loader etc.")

-    if args.blockwise_fused_optimizers:
-        # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
-        # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
-        # This balances memory usage and management complexity.
+    # if args.blockwise_fused_optimizers:
+    #     # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
+    #     # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
+    #     # This balances memory usage and management complexity.

-        # split params into groups. currently different learning rates are not supported
-        grouped_params = []
-        param_group = {}
-        for group in params_to_optimize:
-            named_parameters = list(nextdit.named_parameters())
-            assert len(named_parameters) == len(
-                group["params"]
-            ), "number of parameters does not match"
-            for p, np in zip(group["params"], named_parameters):
-                # determine target layer and block index for each parameter
-                block_type = "other"  # double, single or other
-                if np[0].startswith("double_blocks"):
-                    block_index = int(np[0].split(".")[1])
-                    block_type = "double"
-                elif np[0].startswith("single_blocks"):
-                    block_index = int(np[0].split(".")[1])
-                    block_type = "single"
-                else:
-                    block_index = -1
+    #     # split params into groups. currently different learning rates are not supported
+    #     grouped_params = []
+    #     param_group = {}
+    #     for group in params_to_optimize:
+    #         named_parameters = list(nextdit.named_parameters())
+    #         assert len(named_parameters) == len(
+    #             group["params"]
+    #         ), "number of parameters does not match"
+    #         for p, np in zip(group["params"], named_parameters):
+    #             # determine target layer and block index for each parameter
+    #             block_type = "other"  # double, single or other
+    #             if np[0].startswith("double_blocks"):
+    #                 block_index = int(np[0].split(".")[1])
+    #                 block_type = "double"
+    #             elif np[0].startswith("single_blocks"):
+    #                 block_index = int(np[0].split(".")[1])
+    #                 block_type = "single"
+    #             else:
+    #                 block_index = -1

-                param_group_key = (block_type, block_index)
-                if param_group_key not in param_group:
-                    param_group[param_group_key] = []
-                param_group[param_group_key].append(p)
+    #             param_group_key = (block_type, block_index)
+    #             if param_group_key not in param_group:
+    #                 param_group[param_group_key] = []
+    #             param_group[param_group_key].append(p)

-        block_types_and_indices = []
-        for param_group_key, param_group in param_group.items():
-            block_types_and_indices.append(param_group_key)
-            grouped_params.append({"params": param_group, "lr": args.learning_rate})
+    #     block_types_and_indices = []
+    #     for param_group_key, param_group in param_group.items():
+    #         block_types_and_indices.append(param_group_key)
+    #         grouped_params.append({"params": param_group, "lr": args.learning_rate})

-            num_params = 0
-            for p in param_group:
-                num_params += p.numel()
-            accelerator.print(f"block {param_group_key}: {num_params} parameters")
+    #         num_params = 0
+    #         for p in param_group:
+    #             num_params += p.numel()
+    #         accelerator.print(f"block {param_group_key}: {num_params} parameters")

-        # prepare optimizers for each group
-        optimizers = []
-        for group in grouped_params:
-            _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
-            optimizers.append(optimizer)
-        optimizer = optimizers[0]  # avoid error in the following code
+    #     # prepare optimizers for each group
+    #     optimizers = []
+    #     for group in grouped_params:
+    #         _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
+    #         optimizers.append(optimizer)
+    #     optimizer = optimizers[0]  # avoid error in the following code

-        logger.info(
-            f"using {len(optimizers)} optimizers for blockwise fused optimizers"
-        )
+    #     logger.info(
+    #         f"using {len(optimizers)} optimizers for blockwise fused optimizers"
+    #     )

-        if train_util.is_schedulefree_optimizer(optimizers[0], args):
-            raise ValueError(
-                "Schedule-free optimizer is not supported with blockwise fused optimizers"
-            )
-        optimizer_train_fn = lambda: None  # dummy function
-        optimizer_eval_fn = lambda: None  # dummy function
-    else:
-        _, _, optimizer = train_util.get_optimizer(
-            args, trainable_params=params_to_optimize
-        )
-        optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
-            optimizer, args
-        )
+    #     if train_util.is_schedulefree_optimizer(optimizers[0], args):
+    #         raise ValueError(
+    #             "Schedule-free optimizer is not supported with blockwise fused optimizers"
+    #         )
+    #     optimizer_train_fn = lambda: None  # dummy function
+    #     optimizer_eval_fn = lambda: None  # dummy function
+    # else:
+    #     _, _, optimizer = train_util.get_optimizer(
+    #         args, trainable_params=params_to_optimize
+    #     )
+    #     optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
+    #         optimizer, args
+    #     )
+
+    # Currently when using blockwise_fused_optimizers the weight of model is not updated.
+    _, _, optimizer = train_util.get_optimizer(
+        args, trainable_params=params_to_optimize
+    )
+    optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
+        optimizer, args
+    )

     # prepare dataloader
     # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
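The disabled branch grouped parameters by name prefix so that each transformer block could get its own fused optimizer (one train_util.get_optimizer call per group). A standalone sketch of that grouping step over a toy module, for illustration only (the real code iterates nextdit.named_parameters() and asserts the counts match params_to_optimize):

import torch
import torch.nn as nn

def group_params_by_block(model: nn.Module) -> dict:
    # bucket parameters by ("double" / "single" / "other", block index), keyed off the parameter name
    groups: dict = {}
    for name, p in model.named_parameters():
        block_type, block_index = "other", -1
        if name.startswith("double_blocks"):
            block_type, block_index = "double", int(name.split(".")[1])
        elif name.startswith("single_blocks"):
            block_type, block_index = "single", int(name.split(".")[1])
        groups.setdefault((block_type, block_index), []).append(p)
    return groups

# toy module with a naming scheme like the one the grouping code expects
model = nn.Module()
model.double_blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
model.head = nn.Linear(4, 4)

for key, params in group_params_by_block(model).items():
    print(key, sum(p.numel() for p in params))  # one optimizer would be built per key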
@@ -743,7 +751,7 @@ def train(args):
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                 model_pred = nextdit(
                     x=noisy_model_input,  # image latents (B, C, H, W)
-                    t=timesteps / 1000,  # timesteps需要除以1000来匹配模型预期
+                    t=1 - timesteps / 1000,  # timesteps需要除以1000来匹配模型预期
                     cap_feats=gemma2_hidden_states,  # Gemma2的hidden states作为caption features
                     cap_mask=gemma2_attn_mask.to(
                         dtype=torch.int32
@@ -268,7 +268,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
         # NextDiT forward expects (x, t, cap_feats, cap_mask)
         model_pred = dit(
             x=img,  # image latents (B, C, H, W)
-            t=timesteps / 1000,  # timesteps需要除以1000来匹配模型预期
+            t=1 - timesteps / 1000,  # timesteps需要除以1000来匹配模型预期
             cap_feats=gemma2_hidden_states,  # Gemma2的hidden states作为caption features
             cap_mask=gemma2_attn_mask.to(dtype=torch.int32),  # Gemma2的attention mask
         )
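Both call sites flip the time value handed to the network: instead of timesteps / 1000 the model now receives 1 - timesteps / 1000, i.e. the same training timestep mapped to the opposite end of the [0, 1] range. A small numeric illustration of the flip only (whether NextDiT expects t near 0 or near 1 at the fully-noised end is a property of the model and is not shown in this diff):

import torch

timesteps = torch.tensor([0.0, 250.0, 500.0, 750.0, 1000.0])

t_before = timesteps / 1000      # 0.00, 0.25, 0.50, 0.75, 1.00
t_after = 1 - timesteps / 1000   # 1.00, 0.75, 0.50, 0.25, 0.00

# the two conventions are mirror images around 0.5
print(torch.allclose(t_before + t_after, torch.ones_like(t_before)))  # True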