mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 23:01:22 +00:00
Typos and lingering is_train
This commit is contained in:
@@ -535,7 +535,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
shuffle_caption: {subset.shuffle_caption}
|
||||
keep_tokens: {subset.keep_tokens}
|
||||
caption_dropout_rate: {subset.caption_dropout_rate}
|
||||
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
|
||||
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
|
||||
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
||||
caption_prefix: {subset.caption_prefix}
|
||||
caption_suffix: {subset.caption_suffix}
|
||||
|
||||
@@ -2092,7 +2092,6 @@ class FineTuningDataset(BaseDataset):
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset: bool,
|
||||
is_train: bool,
|
||||
validation_seed: int,
|
||||
validation_split: float,
|
||||
) -> None:
|
||||
@@ -2312,7 +2311,6 @@ class ControlNetDataset(BaseDataset):
|
||||
def __init__(
|
||||
self,
|
||||
subsets: Sequence[ControlNetSubset],
|
||||
is_train: bool,
|
||||
batch_size: int,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
@@ -2362,7 +2360,6 @@ class ControlNetDataset(BaseDataset):
|
||||
|
||||
self.dreambooth_dataset_delegate = DreamBoothDataset(
|
||||
db_subsets,
|
||||
is_train,
|
||||
batch_size,
|
||||
resolution,
|
||||
network_multiplier,
|
||||
@@ -2382,7 +2379,6 @@ class ControlNetDataset(BaseDataset):
|
||||
self.batch_size = batch_size
|
||||
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
|
||||
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
|
||||
self.is_train = is_train
|
||||
self.validation_split = validation_split
|
||||
self.validation_seed = validation_seed
|
||||
|
||||
|
||||
@@ -380,11 +380,11 @@ class NetworkTrainer:
|
||||
else:
|
||||
return typing.cast(torch.IntTensor, torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(latents.device))
|
||||
|
||||
choosen_timesteps_list = pick_timesteps_list()
|
||||
chosen_timesteps_list = pick_timesteps_list()
|
||||
total_loss = torch.zeros((batch_size, 1)).to(latents.device)
|
||||
|
||||
# Use input timesteps_list or use described timesteps above
|
||||
for fixed_timestep in choosen_timesteps_list:
|
||||
for fixed_timestep in chosen_timesteps_list:
|
||||
fixed_timestep = typing.cast(torch.IntTensor, fixed_timestep)
|
||||
|
||||
# Predict the noise residual
|
||||
@@ -447,7 +447,7 @@ class NetworkTrainer:
|
||||
|
||||
total_loss += loss
|
||||
|
||||
return total_loss / len(choosen_timesteps_list)
|
||||
return total_loss / len(chosen_timesteps_list)
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
|
||||
Reference in New Issue
Block a user