mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
persistent_workersを有効にした際にキャプションが変化しなくなるバグ修正
This commit is contained in:
@@ -62,6 +62,7 @@ def train(args):
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||||
|
config_util.blueprint_args_conflict(args,blueprint)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
@@ -259,13 +260,13 @@ def train(args):
|
|||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
train_dataset_group.set_current_epoch(epoch + 1)
|
train_dataset_group.set_current_epoch(epoch + 1)
|
||||||
|
train_dataset_group.set_current_step(global_step)
|
||||||
|
|
||||||
for m in training_models:
|
for m in training_models:
|
||||||
m.train()
|
m.train()
|
||||||
|
|
||||||
loss_total = 0
|
loss_total = 0
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
train_dataset_group.set_current_step(global_step)
|
|
||||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
|
|||||||
@@ -497,6 +497,14 @@ def load_user_config(file: str) -> dict:
|
|||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
def blueprint_args_conflict(args,blueprint:Blueprint):
|
||||||
|
# train_dataset_group.set_current_epoch()とtrain_dataset_group.set_current_step()がWorkerを生成するタイミングで適用される影響で、persistent_workers有効時はずっと一定になってしまうため無効にする
|
||||||
|
for b in blueprint.dataset_group.datasets:
|
||||||
|
for t in b.subsets:
|
||||||
|
if args.persistent_data_loader_workers and (t.params.caption_dropout_every_n_epochs > 0 or t.params.token_warmup_step>0):
|
||||||
|
print("Warning: %s: caption_dropout_every_n_epochs and token_warmup_step is ignored because --persistent_data_loader_workers option is used / --persistent_data_loader_workersオプションが使われているため、caption_dropout_every_n_epochs及びtoken_warmup_stepは無視されます。"%(t.params.image_dir))
|
||||||
|
t.params.caption_dropout_every_n_epochs = 0
|
||||||
|
t.params.token_warmup_step = 0
|
||||||
|
|
||||||
# for config test
|
# for config test
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ def train(args):
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||||
|
config_util.blueprint_args_conflict(args,blueprint)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
|
||||||
if args.no_token_padding:
|
if args.no_token_padding:
|
||||||
@@ -233,6 +234,7 @@ def train(args):
|
|||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
train_dataset_group.set_current_epoch(epoch + 1)
|
train_dataset_group.set_current_epoch(epoch + 1)
|
||||||
|
train_dataset_group.set_current_step(global_step)
|
||||||
|
|
||||||
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
||||||
unet.train()
|
unet.train()
|
||||||
@@ -241,7 +243,6 @@ def train(args):
|
|||||||
text_encoder.train()
|
text_encoder.train()
|
||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
train_dataset_group.set_current_step(global_step)
|
|
||||||
# 指定したステップ数でText Encoderの学習を止める
|
# 指定したステップ数でText Encoderの学習を止める
|
||||||
if global_step == args.stop_text_encoder_training:
|
if global_step == args.stop_text_encoder_training:
|
||||||
print(f"stop text encoder training at step {global_step}")
|
print(f"stop text encoder training at step {global_step}")
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ def train(args):
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||||
|
config_util.blueprint_args_conflict(args,blueprint)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
@@ -501,13 +502,13 @@ def train(args):
|
|||||||
if is_main_process:
|
if is_main_process:
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
train_dataset_group.set_current_epoch(epoch + 1)
|
train_dataset_group.set_current_epoch(epoch + 1)
|
||||||
|
train_dataset_group.set_current_step(global_step)
|
||||||
|
|
||||||
metadata["ss_epoch"] = str(epoch + 1)
|
metadata["ss_epoch"] = str(epoch + 1)
|
||||||
|
|
||||||
network.on_epoch_start(text_encoder, unet)
|
network.on_epoch_start(text_encoder, unet)
|
||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
train_dataset_group.set_current_step(global_step)
|
|
||||||
with accelerator.accumulate(network):
|
with accelerator.accumulate(network):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
|
|||||||
@@ -183,6 +183,7 @@ def train(args):
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||||
|
config_util.blueprint_args_conflict(args,blueprint)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
|
||||||
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
||||||
@@ -335,12 +336,12 @@ def train(args):
|
|||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
train_dataset_group.set_current_epoch(epoch + 1)
|
train_dataset_group.set_current_epoch(epoch + 1)
|
||||||
|
train_dataset_group.set_current_step(global_step)
|
||||||
|
|
||||||
text_encoder.train()
|
text_encoder.train()
|
||||||
|
|
||||||
loss_total = 0
|
loss_total = 0
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
train_dataset_group.set_current_step(global_step)
|
|
||||||
with accelerator.accumulate(text_encoder):
|
with accelerator.accumulate(text_encoder):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user