From 534059dea517d44de387e7d467d64209f9dcfba2 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 3 Jan 2025 01:18:15 -0500 Subject: [PATCH] Typos and lingering is_train --- library/config_util.py | 2 +- library/train_util.py | 4 ---- train_network.py | 6 +++--- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index a09d2c7c..418c179d 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -535,7 +535,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu shuffle_caption: {subset.shuffle_caption} keep_tokens: {subset.keep_tokens} caption_dropout_rate: {subset.caption_dropout_rate} - caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs} + caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs} caption_tag_dropout_rate: {subset.caption_tag_dropout_rate} caption_prefix: {subset.caption_prefix} caption_suffix: {subset.caption_suffix} diff --git a/library/train_util.py b/library/train_util.py index bf1b6731..220d4702 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2092,7 +2092,6 @@ class FineTuningDataset(BaseDataset): bucket_reso_steps: int, bucket_no_upscale: bool, debug_dataset: bool, - is_train: bool, validation_seed: int, validation_split: float, ) -> None: @@ -2312,7 +2311,6 @@ class ControlNetDataset(BaseDataset): def __init__( self, subsets: Sequence[ControlNetSubset], - is_train: bool, batch_size: int, resolution, network_multiplier: float, @@ -2362,7 +2360,6 @@ class ControlNetDataset(BaseDataset): self.dreambooth_dataset_delegate = DreamBoothDataset( db_subsets, - is_train, batch_size, resolution, network_multiplier, @@ -2382,7 +2379,6 @@ class ControlNetDataset(BaseDataset): self.batch_size = batch_size self.num_train_images = self.dreambooth_dataset_delegate.num_train_images self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images - self.is_train = is_train self.validation_split = validation_split self.validation_seed = validation_seed diff --git a/train_network.py b/train_network.py index 99b9717a..4bcfc0ac 100644 --- a/train_network.py +++ b/train_network.py @@ -380,11 +380,11 @@ class NetworkTrainer: else: return typing.cast(torch.IntTensor, torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(latents.device)) - choosen_timesteps_list = pick_timesteps_list() + chosen_timesteps_list = pick_timesteps_list() total_loss = torch.zeros((batch_size, 1)).to(latents.device) # Use input timesteps_list or use described timesteps above - for fixed_timestep in choosen_timesteps_list: + for fixed_timestep in chosen_timesteps_list: fixed_timestep = typing.cast(torch.IntTensor, fixed_timestep) # Predict the noise residual @@ -447,7 +447,7 @@ class NetworkTrainer: total_loss += loss - return total_loss / len(choosen_timesteps_list) + return total_loss / len(chosen_timesteps_list) def train(self, args): session_id = random.randint(0, 2**32)