mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
show warning message for sample images in XTI
This commit is contained in:
@@ -83,6 +83,11 @@ def train(args):
|
|||||||
train_util.verify_training_args(args)
|
train_util.verify_training_args(args)
|
||||||
train_util.prepare_dataset_args(args, True)
|
train_util.prepare_dataset_args(args, True)
|
||||||
|
|
||||||
|
if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
|
||||||
|
print(
|
||||||
|
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
|
||||||
|
)
|
||||||
|
|
||||||
cache_latents = args.cache_latents
|
cache_latents = args.cache_latents
|
||||||
|
|
||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
@@ -123,7 +128,24 @@ def train(args):
|
|||||||
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
||||||
|
|
||||||
token_strings_XTI = []
|
token_strings_XTI = []
|
||||||
XTI_layers = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']
|
XTI_layers = [
|
||||||
|
"IN01",
|
||||||
|
"IN02",
|
||||||
|
"IN04",
|
||||||
|
"IN05",
|
||||||
|
"IN07",
|
||||||
|
"IN08",
|
||||||
|
"MID",
|
||||||
|
"OUT03",
|
||||||
|
"OUT04",
|
||||||
|
"OUT05",
|
||||||
|
"OUT06",
|
||||||
|
"OUT07",
|
||||||
|
"OUT08",
|
||||||
|
"OUT09",
|
||||||
|
"OUT10",
|
||||||
|
"OUT11",
|
||||||
|
]
|
||||||
for layer_name in XTI_layers:
|
for layer_name in XTI_layers:
|
||||||
token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
|
token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
|
||||||
|
|
||||||
@@ -193,8 +215,8 @@ def train(args):
|
|||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
|
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
|
||||||
current_epoch = Value('i',0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value('i',0)
|
current_step = Value("i", 0)
|
||||||
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||||
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
||||||
|
|
||||||
@@ -273,7 +295,9 @@ def train(args):
|
|||||||
|
|
||||||
# 学習ステップ数を計算する
|
# 学習ステップ数を計算する
|
||||||
if args.max_train_epochs is not None:
|
if args.max_train_epochs is not None:
|
||||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||||
|
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||||
|
)
|
||||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
# データセット側にも学習ステップを送信
|
# データセット側にも学習ステップを送信
|
||||||
@@ -371,7 +395,12 @@ def train(args):
|
|||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
input_ids = batch["input_ids"].to(accelerator.device)
|
input_ids = batch["input_ids"].to(accelerator.device)
|
||||||
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
|
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
|
||||||
encoder_hidden_states = torch.stack([train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype) for s in torch.split(input_ids, 1, dim=1)])
|
encoder_hidden_states = torch.stack(
|
||||||
|
[
|
||||||
|
train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype)
|
||||||
|
for s in torch.split(input_ids, 1, dim=1)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
# Sample noise that we'll add to the latents
|
# Sample noise that we'll add to the latents
|
||||||
noise = torch.randn_like(latents, device=latents.device)
|
noise = torch.randn_like(latents, device=latents.device)
|
||||||
@@ -511,7 +540,24 @@ def train(args):
|
|||||||
def save_weights(file, updated_embs, save_dtype):
|
def save_weights(file, updated_embs, save_dtype):
|
||||||
updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1])
|
updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1])
|
||||||
updated_embs = updated_embs.chunk(16)
|
updated_embs = updated_embs.chunk(16)
|
||||||
XTI_layers = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']
|
XTI_layers = [
|
||||||
|
"IN01",
|
||||||
|
"IN02",
|
||||||
|
"IN04",
|
||||||
|
"IN05",
|
||||||
|
"IN07",
|
||||||
|
"IN08",
|
||||||
|
"MID",
|
||||||
|
"OUT03",
|
||||||
|
"OUT04",
|
||||||
|
"OUT05",
|
||||||
|
"OUT06",
|
||||||
|
"OUT07",
|
||||||
|
"OUT08",
|
||||||
|
"OUT09",
|
||||||
|
"OUT10",
|
||||||
|
"OUT11",
|
||||||
|
]
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
for i, layer_name in enumerate(XTI_layers):
|
for i, layer_name in enumerate(XTI_layers):
|
||||||
state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype)
|
state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user