From cb53a7733415aaeb424f6e2e170e729414f5de9a Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Thu, 30 Mar 2023 21:33:57 +0900
Subject: [PATCH] show warning message for sample images in XTI

---
 train_textual_inversion_XTI.py | 68 ++++++++++++++++++++++++++++------
 1 file changed, 57 insertions(+), 11 deletions(-)

diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 8d6ff430..74e9bc2e 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -83,6 +83,11 @@ def train(args):
     train_util.verify_training_args(args)
     train_util.prepare_dataset_args(args, True)
 
+    if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
+        print(
+            "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
+        )
+
     cache_latents = args.cache_latents
 
     if args.seed is not None:
@@ -123,7 +128,24 @@ def train(args):
         assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
 
     token_strings_XTI = []
-    XTI_layers = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']
+    XTI_layers = [
+        "IN01",
+        "IN02",
+        "IN04",
+        "IN05",
+        "IN07",
+        "IN08",
+        "MID",
+        "OUT03",
+        "OUT04",
+        "OUT05",
+        "OUT06",
+        "OUT07",
+        "OUT08",
+        "OUT09",
+        "OUT10",
+        "OUT11",
+    ]
     for layer_name in XTI_layers:
         token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
 
@@ -193,10 +215,10 @@ def train(args):
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
     train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
-    current_epoch = Value('i',0)
-    current_step = Value('i',0)
+    current_epoch = Value("i", 0)
+    current_step = Value("i", 0)
     ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
-    collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
+    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
 
     # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
     if use_template:
@@ -273,7 +295,9 @@ def train(args):
 
     # 学習ステップ数を計算する
    if args.max_train_epochs is not None:
-        args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
+        args.max_train_steps = args.max_train_epochs * math.ceil(
+            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+        )
         print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
     # データセット側にも学習ステップを送信
@@ -350,7 +374,7 @@ def train(args):
 
     for epoch in range(num_train_epochs):
         print(f"epoch {epoch+1}/{num_train_epochs}")
-        current_epoch.value = epoch+1
+        current_epoch.value = epoch + 1
 
         text_encoder.train()
 
@@ -371,7 +395,12 @@ def train(args):
                 # Get the text embedding for conditioning
                 input_ids = batch["input_ids"].to(accelerator.device)
                 # weight_dtype) use float instead of fp16/bf16 because text encoder is float
-                encoder_hidden_states = torch.stack([train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype) for s in torch.split(input_ids, 1, dim=1)])
+                encoder_hidden_states = torch.stack(
+                    [
+                        train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype)
+                        for s in torch.split(input_ids, 1, dim=1)
+                    ]
+                )
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents, device=latents.device)
@@ -398,9 +427,9 @@ def train(args):
 
                 loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                 loss = loss.mean([1, 2, 3])
-                
+
                 if args.min_snr_gamma:
-                  loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
 
                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight
                 loss = loss * loss_weights
@@ -511,7 +540,24 @@ def save_weights(file, updated_embs, save_dtype):
     updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1])
     updated_embs = updated_embs.chunk(16)
-    XTI_layers = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']
+    XTI_layers = [
+        "IN01",
+        "IN02",
+        "IN04",
+        "IN05",
+        "IN07",
+        "IN08",
+        "MID",
+        "OUT03",
+        "OUT04",
+        "OUT05",
+        "OUT06",
+        "OUT07",
+        "OUT08",
+        "OUT09",
+        "OUT10",
+        "OUT11",
+    ]
     state_dict = {}
     for i, layer_name in enumerate(XTI_layers):
         state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype)
 
@@ -540,7 +586,7 @@ def load_weights(file):
 
     if len(data.values()) != 16:
         raise ValueError(f"NOT XTI: {file}")
-    
+
     emb = torch.concat([x for x in data.values()])
 
     return emb
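Editor's note on the min_snr_gamma branch touched in the hunk at @@ -398,9 +427,9 @@: the apply_snr_weight helper called there implements the Min-SNR-gamma loss weighting (Hang et al., 2023), which scales each sample's loss by min(SNR(t), gamma) / SNR(t). The following is a minimal sketch of that weighting, not the repository's implementation; it assumes an epsilon-prediction objective and a diffusers-style scheduler exposing an alphas_cumprod tensor, and the helper name min_snr_weight is hypothetical.

import torch

def min_snr_weight(loss: torch.Tensor, timesteps: torch.Tensor,
                   alphas_cumprod: torch.Tensor, gamma: float) -> torch.Tensor:
    # loss: per-sample loss of shape [B]; timesteps: integer tensor of shape [B]
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t) under the DDPM forward process
    alpha_bar = alphas_cumprod.to(loss.device)[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)
    # Clamp the per-timestep weight at gamma so very noisy (low-SNR) timesteps
    # cannot dominate the gradient: weight = min(SNR(t), gamma) / SNR(t)
    return loss * (torch.clamp(snr, max=gamma) / snr)

Under these assumptions the call would mirror the patched call site, loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma), except that the scheduler's alphas_cumprod tensor is passed directly instead of the scheduler object.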