mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
show warning message for sample images in XTI
This commit is contained in:
@@ -83,6 +83,11 @@ def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
|
||||
print(
|
||||
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
|
||||
)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
|
||||
if args.seed is not None:
|
||||
@@ -123,7 +128,24 @@ def train(args):
|
||||
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
||||
|
||||
token_strings_XTI = []
|
||||
XTI_layers = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']
|
||||
XTI_layers = [
|
||||
"IN01",
|
||||
"IN02",
|
||||
"IN04",
|
||||
"IN05",
|
||||
"IN07",
|
||||
"IN08",
|
||||
"MID",
|
||||
"OUT03",
|
||||
"OUT04",
|
||||
"OUT05",
|
||||
"OUT06",
|
||||
"OUT07",
|
||||
"OUT08",
|
||||
"OUT09",
|
||||
"OUT10",
|
||||
"OUT11",
|
||||
]
|
||||
for layer_name in XTI_layers:
|
||||
token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
|
||||
|
||||
@@ -193,10 +215,10 @@ def train(args):
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
|
||||
current_epoch = Value('i',0)
|
||||
current_step = Value('i',0)
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
|
||||
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
|
||||
|
||||
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
||||
if use_template:
|
||||
@@ -273,7 +295,9 @@ def train(args):
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
@@ -350,7 +374,7 @@ def train(args):
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch+1
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
text_encoder.train()
|
||||
|
||||
@@ -371,7 +395,12 @@ def train(args):
|
||||
# Get the text embedding for conditioning
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
|
||||
encoder_hidden_states = torch.stack([train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype) for s in torch.split(input_ids, 1, dim=1)])
|
||||
encoder_hidden_states = torch.stack(
|
||||
[
|
||||
train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype)
|
||||
for s in torch.split(input_ids, 1, dim=1)
|
||||
]
|
||||
)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
@@ -511,7 +540,24 @@ def train(args):
|
||||
def save_weights(file, updated_embs, save_dtype):
|
||||
updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1])
|
||||
updated_embs = updated_embs.chunk(16)
|
||||
XTI_layers = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']
|
||||
XTI_layers = [
|
||||
"IN01",
|
||||
"IN02",
|
||||
"IN04",
|
||||
"IN05",
|
||||
"IN07",
|
||||
"IN08",
|
||||
"MID",
|
||||
"OUT03",
|
||||
"OUT04",
|
||||
"OUT05",
|
||||
"OUT06",
|
||||
"OUT07",
|
||||
"OUT08",
|
||||
"OUT09",
|
||||
"OUT10",
|
||||
"OUT11",
|
||||
]
|
||||
state_dict = {}
|
||||
for i, layer_name in enumerate(XTI_layers):
|
||||
state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype)
|
||||
|
||||
Reference in New Issue
Block a user