Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
typo修正、stepをglobal_stepに修正、バグ修正 (Fix typos, change step to global_step, fix bugs)
This commit is contained in:
@@ -265,8 +265,8 @@ def train(args):
|
||||
|
||||
loss_total = 0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
train_dataset_group.set_current_step(global_step)
|
||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||
train_dataset_group.set_current_step(step + 1)
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
|
||||
@@ -57,7 +57,7 @@ class BaseSubsetParams:
|
||||
caption_dropout_every_n_epochs: int = 0
|
||||
caption_tag_dropout_rate: float = 0.0
|
||||
token_warmup_min: int = 1
|
||||
token_warmup_step: Union[float,int] = 0
|
||||
token_warmup_step: float = 0
|
||||
|
||||
@dataclass
|
||||
class DreamBoothSubsetParams(BaseSubsetParams):
|
||||
@@ -140,7 +140,7 @@ class ConfigSanitizer:
|
||||
"shuffle_caption": bool,
|
||||
"keep_tokens": int,
|
||||
"token_warmup_min": int,
|
||||
"token_warmup_step": Union[float,int],
|
||||
"token_warmup_step": Any(float,int),
|
||||
}
|
||||
# DO means DropOut
|
||||
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
||||
|
||||
@@ -2046,7 +2046,7 @@ def add_dataset_arguments(
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--token_warmup_steps",
|
||||
"--token_warmup_step",
|
||||
type=float,
|
||||
default=0,
|
||||
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
|
||||
|
||||
@@ -241,6 +241,7 @@ def train(args):
|
||||
text_encoder.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
train_dataset_group.set_current_step(global_step)
|
||||
# 指定したステップ数でText Encoderの学習を止める
|
||||
if global_step == args.stop_text_encoder_training:
|
||||
print(f"stop text encoder training at step {global_step}")
|
||||
@@ -249,7 +250,6 @@ def train(args):
|
||||
text_encoder.requires_grad_(False)
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
train_dataset_group.set_current_step(step + 1)
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
if cache_latents:
|
||||
|
||||
@@ -507,8 +507,8 @@ def train(args):
|
||||
network.on_epoch_start(text_encoder, unet)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
train_dataset_group.set_current_step(global_step)
|
||||
with accelerator.accumulate(network):
|
||||
train_dataset_group.set_current_step(step + 1)
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
|
||||
@@ -340,8 +340,8 @@ def train(args):
|
||||
|
||||
loss_total = 0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
train_dataset_group.set_current_step(global_step)
|
||||
with accelerator.accumulate(text_encoder):
|
||||
train_dataset_group.set_current_step(step + 1)
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
|
||||
Reference in New Issue
Block a user