mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
typo修正、stepをglobal_stepに修正、バグ修正
This commit is contained in:
@@ -265,8 +265,8 @@ def train(args):
|
|||||||
|
|
||||||
loss_total = 0
|
loss_total = 0
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
train_dataset_group.set_current_step(global_step)
|
||||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||||
train_dataset_group.set_current_step(step + 1)
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
latents = batch["latents"].to(accelerator.device)
|
latents = batch["latents"].to(accelerator.device)
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ class BaseSubsetParams:
|
|||||||
caption_dropout_every_n_epochs: int = 0
|
caption_dropout_every_n_epochs: int = 0
|
||||||
caption_tag_dropout_rate: float = 0.0
|
caption_tag_dropout_rate: float = 0.0
|
||||||
token_warmup_min: int = 1
|
token_warmup_min: int = 1
|
||||||
token_warmup_step: Union[float,int] = 0
|
token_warmup_step: float = 0
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DreamBoothSubsetParams(BaseSubsetParams):
|
class DreamBoothSubsetParams(BaseSubsetParams):
|
||||||
@@ -140,7 +140,7 @@ class ConfigSanitizer:
|
|||||||
"shuffle_caption": bool,
|
"shuffle_caption": bool,
|
||||||
"keep_tokens": int,
|
"keep_tokens": int,
|
||||||
"token_warmup_min": int,
|
"token_warmup_min": int,
|
||||||
"token_warmup_step": Union[float,int],
|
"token_warmup_step": Any(float,int),
|
||||||
}
|
}
|
||||||
# DO means DropOut
|
# DO means DropOut
|
||||||
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
||||||
|
|||||||
@@ -2046,7 +2046,7 @@ def add_dataset_arguments(
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--token_warmup_steps",
|
"--token_warmup_step",
|
||||||
type=float,
|
type=float,
|
||||||
default=0,
|
default=0,
|
||||||
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
|
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
|
||||||
|
|||||||
@@ -241,6 +241,7 @@ def train(args):
|
|||||||
text_encoder.train()
|
text_encoder.train()
|
||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
train_dataset_group.set_current_step(global_step)
|
||||||
# 指定したステップ数でText Encoderの学習を止める
|
# 指定したステップ数でText Encoderの学習を止める
|
||||||
if global_step == args.stop_text_encoder_training:
|
if global_step == args.stop_text_encoder_training:
|
||||||
print(f"stop text encoder training at step {global_step}")
|
print(f"stop text encoder training at step {global_step}")
|
||||||
@@ -249,7 +250,6 @@ def train(args):
|
|||||||
text_encoder.requires_grad_(False)
|
text_encoder.requires_grad_(False)
|
||||||
|
|
||||||
with accelerator.accumulate(unet):
|
with accelerator.accumulate(unet):
|
||||||
train_dataset_group.set_current_step(step + 1)
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# latentに変換
|
# latentに変換
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
|
|||||||
@@ -507,8 +507,8 @@ def train(args):
|
|||||||
network.on_epoch_start(text_encoder, unet)
|
network.on_epoch_start(text_encoder, unet)
|
||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
train_dataset_group.set_current_step(global_step)
|
||||||
with accelerator.accumulate(network):
|
with accelerator.accumulate(network):
|
||||||
train_dataset_group.set_current_step(step + 1)
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
latents = batch["latents"].to(accelerator.device)
|
latents = batch["latents"].to(accelerator.device)
|
||||||
|
|||||||
@@ -340,8 +340,8 @@ def train(args):
|
|||||||
|
|
||||||
loss_total = 0
|
loss_total = 0
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
train_dataset_group.set_current_step(global_step)
|
||||||
with accelerator.accumulate(text_encoder):
|
with accelerator.accumulate(text_encoder):
|
||||||
train_dataset_group.set_current_step(step + 1)
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
latents = batch["latents"].to(accelerator.device)
|
latents = batch["latents"].to(accelerator.device)
|
||||||
|
|||||||
Reference in New Issue
Block a user