diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 5fd5e05e..35b4ede6 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -123,17 +123,17 @@ def train(args):
   if init_token_id is not None:
     for token_id in token_ids:
       token_embeds[token_id] = token_embeds[init_token_id]
-      print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
+      # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
 
   # load weights
   if args.weights is not None:
     embeddings = load_weights(args.weights)
     assert len(token_ids) == len(
         embeddings), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
-    print(token_ids, embeddings.size())
+    # print(token_ids, embeddings.size())
     for token_id, embedding in zip(token_ids, embeddings):
       token_embeds[token_id] = embedding
-      print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
+      # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
     print(f"weighs loaded")
 
   print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
@@ -215,10 +215,15 @@ def train(args):
   # dataloaderを準備する
   # DataLoaderのプロセス数:0はメインプロセスになる
-  n_workers = min(8, os.cpu_count() - 1)      # cpu_count-1 ただし最大8
+  n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)      # cpu_count-1 ただし最大で指定された数まで
   train_dataloader = torch.utils.data.DataLoader(
       train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
 
+  # 学習ステップ数を計算する
+  if args.max_train_epochs is not None:
+    args.max_train_steps = args.max_train_epochs * len(train_dataloader)
+    print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+
   # lr schedulerを用意する
   lr_scheduler = diffusers.optimization.get_scheduler(
       args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
 
@@ -263,6 +268,8 @@ def train(args):
   # epoch数を計算する
   num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
   num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+  if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
+    args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
 
   # 学習する
   total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -367,7 +374,7 @@ def train(args):
         break
 
     if args.logging_dir is not None:
-      logs = {"epoch_loss": loss_total / len(train_dataloader)}
+      logs = {"loss/epoch": loss_total / len(train_dataloader)}
       accelerator.log(logs, step=epoch+1)
 
     accelerator.wait_for_everyone()
@@ -392,9 +399,9 @@ def train(args):
           print(f"removing old checkpoint: {old_ckpt_file}")
           os.remove(old_ckpt_file)
 
-      saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
+      saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
       if saving and args.save_state:
-        train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1, remove_epoch_no)
+        train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
 
   # end of epoch
 
@@ -448,7 +455,6 @@ def load_weights(file):
  data = torch.load(file, map_location='cpu')
  if type(data) != dict:
    raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")
-  print(data.keys())
 
  if 'string_to_param' in data:  # textual inversion embeddings
    data = data['string_to_param']
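
Notes on the hunks above, with small Python sketches. The helper names in the sketches are hypothetical and do not appear in the script; the formulas are the ones in the diff.

The dataloader hunk replaces the hard-coded worker cap of 8 with the user-supplied --max_data_loader_n_workers, still reserving one core for the main process:

import os

# Sketch of the worker-count clamp (hypothetical helper; the script
# computes this inline): cap at the user-supplied maximum, but never
# exceed cpu_count - 1 so the main process keeps a core.
def resolve_n_workers(max_data_loader_n_workers: int) -> int:
    return min(max_data_loader_n_workers, os.cpu_count() - 1)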
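
The new --max_train_epochs branch converts an epoch budget into the step budget (args.max_train_steps) that the lr scheduler and training loop already consume, counting one step per dataloader batch. A minimal sketch of that conversion:

# Sketch of the epoch-to-step override; steps_for_epochs is a
# hypothetical helper, while the script writes the result back to
# args.max_train_steps before building the lr scheduler.
def steps_for_epochs(max_train_epochs: int, batches_per_epoch: int) -> int:
    return max_train_epochs * batches_per_epoch

# e.g. 10 epochs over a dataloader yielding 400 batches -> 4000 steps
assert steps_for_epochs(10, 400) == 4000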
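
The save_n_epoch_ratio hunk derives save_every_n_epochs from the total epoch count: the ratio is roughly the number of checkpoints to spread across the run, and the `or 1` clamp means a ratio larger than the epoch count falls back to saving every epoch rather than never saving. A sketch, again with a hypothetical helper name (the script mutates args.save_every_n_epochs in place):

import math

# Map the requested checkpoint ratio to a save interval in epochs.
def resolve_save_every_n_epochs(num_train_epochs: int, save_n_epoch_ratio: int) -> int:
    # floor(10 / 3) == 3 -> save every 3rd epoch, about 3 checkpoints;
    # floor(2 / 5) == 0, which `or 1` turns into saving every epoch.
    return math.floor(num_train_epochs / save_n_epoch_ratio) or 1

assert resolve_save_every_n_epochs(10, 3) == 3
assert resolve_save_every_n_epochs(2, 5) == 1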