typo修正、stepをglobal_stepに修正、バグ修正

This commit is contained in:
u-haru
2023-03-23 09:53:14 +09:00
parent a9b26b73e0
commit 447c56bf50
6 changed files with 7 additions and 7 deletions

View File

@@ -265,8 +265,8 @@ def train(args):
loss_total = 0 loss_total = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
train_dataset_group.set_current_step(global_step)
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
train_dataset_group.set_current_step(step + 1)
with torch.no_grad(): with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device) latents = batch["latents"].to(accelerator.device)

View File

@@ -57,7 +57,7 @@ class BaseSubsetParams:
caption_dropout_every_n_epochs: int = 0 caption_dropout_every_n_epochs: int = 0
caption_tag_dropout_rate: float = 0.0 caption_tag_dropout_rate: float = 0.0
token_warmup_min: int = 1 token_warmup_min: int = 1
token_warmup_step: Union[float,int] = 0 token_warmup_step: float = 0
@dataclass @dataclass
class DreamBoothSubsetParams(BaseSubsetParams): class DreamBoothSubsetParams(BaseSubsetParams):
@@ -140,7 +140,7 @@ class ConfigSanitizer:
"shuffle_caption": bool, "shuffle_caption": bool,
"keep_tokens": int, "keep_tokens": int,
"token_warmup_min": int, "token_warmup_min": int,
"token_warmup_step": Union[float,int], "token_warmup_step": Any(float,int),
} }
# DO means DropOut # DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = { DO_SUBSET_ASCENDABLE_SCHEMA = {

View File

@@ -2046,7 +2046,7 @@ def add_dataset_arguments(
) )
parser.add_argument( parser.add_argument(
"--token_warmup_steps", "--token_warmup_step",
type=float, type=float,
default=0, default=0,
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N（N<1ならN*max_train_steps）ステップでタグ長が最大になる。デフォルトは0（最初から最大）", help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N（N<1ならN*max_train_steps）ステップでタグ長が最大になる。デフォルトは0（最初から最大）",

View File

@@ -241,6 +241,7 @@ def train(args):
text_encoder.train() text_encoder.train()
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
train_dataset_group.set_current_step(global_step)
# 指定したステップ数でText Encoderの学習を止める # 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training: if global_step == args.stop_text_encoder_training:
print(f"stop text encoder training at step {global_step}") print(f"stop text encoder training at step {global_step}")
@@ -249,7 +250,6 @@ def train(args):
text_encoder.requires_grad_(False) text_encoder.requires_grad_(False)
with accelerator.accumulate(unet): with accelerator.accumulate(unet):
train_dataset_group.set_current_step(step + 1)
with torch.no_grad(): with torch.no_grad():
# latentに変換 # latentに変換
if cache_latents: if cache_latents:

View File

@@ -507,8 +507,8 @@ def train(args):
network.on_epoch_start(text_encoder, unet) network.on_epoch_start(text_encoder, unet)
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
train_dataset_group.set_current_step(global_step)
with accelerator.accumulate(network): with accelerator.accumulate(network):
train_dataset_group.set_current_step(step + 1)
with torch.no_grad(): with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device) latents = batch["latents"].to(accelerator.device)

View File

@@ -340,8 +340,8 @@ def train(args):
loss_total = 0 loss_total = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
train_dataset_group.set_current_step(global_step)
with accelerator.accumulate(text_encoder): with accelerator.accumulate(text_encoder):
train_dataset_group.set_current_step(step + 1)
with torch.no_grad(): with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device) latents = batch["latents"].to(accelerator.device)