diff --git a/XTI_hijack.py b/XTI_hijack.py
index f39cc8e7..36b5d3f2 100644
--- a/XTI_hijack.py
+++ b/XTI_hijack.py
@@ -2,132 +2,123 @@ import torch
 from typing import Union, List, Optional, Dict, Any, Tuple
 from diffusers.models.unet_2d_condition import UNet2DConditionOutput
-def unet_forward_XTI(self,
-    sample: torch.FloatTensor,
-    timestep: Union[torch.Tensor, float, int],
-    encoder_hidden_states: torch.Tensor,
-    class_labels: Optional[torch.Tensor] = None,
-    return_dict: bool = True,
-) -> Union[UNet2DConditionOutput, Tuple]:
-    r"""
-    Args:
-        sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
-        timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
-        encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
-        return_dict (`bool`, *optional*, defaults to `True`):
-            Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+from library.original_unet import SampleOutput

-    Returns:
-        [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
-        [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
-        returning a tuple, the first element is the sample tensor.
-    """
-    # By default samples have to be AT least a multiple of the overall upsampling factor.
-    # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
-    # However, the upsampling interpolation output size can be forced to fit any upsampling size
-    # on the fly if necessary.
-    default_overall_up_factor = 2**self.num_upsamplers
-    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
-    forward_upsample_size = False
-    upsample_size = None
+
+def unet_forward_XTI(
+    self,
+    sample: torch.FloatTensor,
+    timestep: Union[torch.Tensor, float, int],
+    encoder_hidden_states: torch.Tensor,
+    class_labels: Optional[torch.Tensor] = None,
+    return_dict: bool = True,
+) -> Union[SampleOutput, Tuple]:
+    r"""
+    Args:
+        sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+        timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+        encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+        return_dict (`bool`, *optional*, defaults to `True`):
+            Whether or not to return a `SampleOutput` instead of a plain tuple.

-    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-        logger.info("Forward upsample size to force interpolation output size.")
-        forward_upsample_size = True
+    Returns:
+        `SampleOutput` or `tuple`:
+        `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+    """
+    # By default samples have to be at least a multiple of the overall upsampling factor.
+    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+    # However, the upsampling interpolation output size can be forced to fit any upsampling size
+    # on the fly if necessary.
+    # In other words, samples must by default be a multiple of 2**num_upsamplers, i.e. of 64.
+    # Quality probably degrades for other sizes, so keeping dimensions divisible by 64 is best.
+    default_overall_up_factor = 2**self.num_upsamplers

-    # 0. center input if necessary
-    if self.config.center_input_sample:
-        sample = 2 * sample - 1.0
+    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`,
+    # i.e. when not divisible by 64 the target size is passed to the upsamplers
+    forward_upsample_size = False
+    upsample_size = None

-    # 1. time
-    timesteps = timestep
-    if not torch.is_tensor(timesteps):
-        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
-        # This would be a good case for the `match` statement (Python 3.10+)
-        is_mps = sample.device.type == "mps"
-        if isinstance(timestep, float):
-            dtype = torch.float32 if is_mps else torch.float64
-        else:
-            dtype = torch.int32 if is_mps else torch.int64
-        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
-    elif len(timesteps.shape) == 0:
-        timesteps = timesteps[None].to(sample.device)
+    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+        # logger.info("Forward upsample size to force interpolation output size.")
+        forward_upsample_size = True

-    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-    timesteps = timesteps.expand(sample.shape[0])
+    # 1. time
+    timesteps = timestep
+    timesteps = self.handle_unusual_timesteps(sample, timesteps)  # only irregular timestep inputs need handling

-    t_emb = self.time_proj(timesteps)
+    t_emb = self.time_proj(timesteps)

-    # timesteps does not contain any weights and will always return f32 tensors
-    # but time_embedding might actually be running in fp16. so we need to cast here.
-    # there might be better ways to encapsulate this.
-    t_emb = t_emb.to(dtype=self.dtype)
-    emb = self.time_embedding(t_emb)
+    # timesteps does not contain any weights and will always return f32 tensors
+    # but time_embedding might actually be running in fp16. so we need to cast here.
+    # there might be better ways to encapsulate this (casting inside time_proj might be simpler).
+    t_emb = t_emb.to(dtype=self.dtype)
+    emb = self.time_embedding(t_emb)

-    if self.config.num_class_embeds is not None:
-        if class_labels is None:
-            raise ValueError("class_labels should be provided when num_class_embeds > 0")
-        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
-        emb = emb + class_emb
+    # 2. pre-process
+    sample = self.conv_in(sample)

-    # 2. pre-process
-    sample = self.conv_in(sample)
+    # 3. down
+    down_block_res_samples = (sample,)
+    down_i = 0
+    for downsample_block in self.down_blocks:
+        # the down blocks could be made to always take encoder_hidden_states in forward,
+        # but this branching is perhaps easier to follow
+        if downsample_block.has_cross_attention:
+            sample, res_samples = downsample_block(
+                hidden_states=sample,
+                temb=emb,
+                encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2],
+            )
+            down_i += 2
+        else:
+            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

-    # 3. down
-    down_block_res_samples = (sample,)
-    down_i = 0
-    for downsample_block in self.down_blocks:
-        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
-            sample, res_samples = downsample_block(
-                hidden_states=sample,
-                temb=emb,
-                encoder_hidden_states=encoder_hidden_states[down_i:down_i+2],
-            )
-            down_i += 2
-        else:
-            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+        down_block_res_samples += res_samples

-        down_block_res_samples += res_samples
+    # 4. mid
+    sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])

-    # 4. mid
-    sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])
+    # 5. up
+    up_i = 7
+    for i, upsample_block in enumerate(self.up_blocks):
+        is_final_block = i == len(self.up_blocks) - 1

-    # 5. up
-    up_i = 7
-    for i, upsample_block in enumerate(self.up_blocks):
-        is_final_block = i == len(self.up_blocks) - 1
+        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]  # skip connection

-        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
-        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+        # if we have not reached the final block and need to forward the upsample size, we do it here
+        if not is_final_block and forward_upsample_size:
+            upsample_size = down_block_res_samples[-1].shape[2:]

-        # if we have not reached the final block and need to forward the
-        # upsample size, we do it here
-        if not is_final_block and forward_upsample_size:
-            upsample_size = down_block_res_samples[-1].shape[2:]
+        if upsample_block.has_cross_attention:
+            sample = upsample_block(
+                hidden_states=sample,
+                temb=emb,
+                res_hidden_states_tuple=res_samples,
+                encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3],
+                upsample_size=upsample_size,
+            )
+            up_i += 3
+        else:
+            sample = upsample_block(
+                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+            )

-        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
-            sample = upsample_block(
-                hidden_states=sample,
-                temb=emb,
-                res_hidden_states_tuple=res_samples,
-                encoder_hidden_states=encoder_hidden_states[up_i:up_i+3],
-                upsample_size=upsample_size,
-            )
-            up_i += 3
-        else:
-            sample = upsample_block(
-                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
-            )
-    # 6. post-process
-    sample = self.conv_norm_out(sample)
-    sample = self.conv_act(sample)
-    sample = self.conv_out(sample)
+    # 6. post-process
+    sample = self.conv_norm_out(sample)
+    sample = self.conv_act(sample)
+    sample = self.conv_out(sample)

-    if not return_dict:
-        return (sample,)
+    if not return_dict:
+        return (sample,)
+
+    return SampleOutput(sample=sample)

-    return UNet2DConditionOutput(sample=sample)

 def downblock_forward_XTI(
     self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
@@ -166,6 +157,7 @@ def downblock_forward_XTI(

     return hidden_states, output_states

+
 def upblock_forward_XTI(
     self,
     hidden_states,
@@ -199,11 +191,11 @@ def upblock_forward_XTI(
         else:
             hidden_states = resnet(hidden_states, temb)
             hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
-
+
         i += 1

     if self.upsamplers is not None:
         for upsampler in self.upsamplers:
             hidden_states = upsampler(hidden_states, upsample_size)

-    return hidden_states
\ No newline at end of file
+    return hidden_states
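
For reference, the indexing scheme in the hijacked forward above: XTI (P+) supplies one text embedding per cross-attention layer instead of a single embedding for the whole U-Net. With the SD 1.x layout (three cross-attention down blocks with two attention layers each, one mid block, three cross-attention up blocks with three attention layers each) that makes 16 embeddings: indices 0-5 feed the down blocks in pairs, index 6 feeds the mid block, and indices 7-15 feed the up blocks in triples. The snippet below only illustrates that slicing and is not part of the patch; the `embeds` list is hypothetical.

    # illustrative sketch of the slicing done by unet_forward_XTI (hypothetical data)
    embeds = [f"layer_{i}" for i in range(16)]     # one entry per cross-attention layer
    down = [embeds[i : i + 2] for i in (0, 2, 4)]  # down blocks consume indices 0-5 in pairs
    mid = embeds[6]                                # the mid block uses index 6
    up = [embeds[i : i + 3] for i in (7, 10, 13)]  # up blocks consume indices 7-15 in triples
    print(down, mid, up)
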
diff --git a/fine_tune.py b/fine_tune.py
index 201d4952..881845c5 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -91,7 +91,7 @@ def train(args):

     # acceleratorを準備する
     print("prepare accelerator")
-    accelerator, unwrap_model = train_util.prepare_accelerator(args)
+    accelerator = train_util.prepare_accelerator(args)

     # mixed precisionに対応した型を用意しておき適宜castする
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -385,8 +385,8 @@ def train(args):
                         epoch,
                         num_train_epochs,
                         global_step,
-                        unwrap_model(text_encoder),
-                        unwrap_model(unet),
+                        accelerator.unwrap_model(text_encoder),
+                        accelerator.unwrap_model(unet),
                         vae,
                     )
@@ -428,8 +428,8 @@ def train(args):
                 epoch,
                 num_train_epochs,
                 global_step,
-                unwrap_model(text_encoder),
-                unwrap_model(unet),
+                accelerator.unwrap_model(text_encoder),
+                accelerator.unwrap_model(unet),
                 vae,
             )
@@ -437,8 +437,8 @@ def train(args):

     is_main_process = accelerator.is_main_process
     if is_main_process:
-        unet = unwrap_model(unet)
-        text_encoder = unwrap_model(text_encoder)
+        unet = accelerator.unwrap_model(unet)
+        text_encoder = accelerator.unwrap_model(text_encoder)

     accelerator.end_training()
diff --git a/library/train_util.py b/library/train_util.py
index 7d7eb325..3ae5d0f3 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2904,23 +2904,9 @@ def prepare_accelerator(args: argparse.Namespace):
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=log_with,
-        logging_dir=logging_dir,
+        project_dir=logging_dir,
     )
-
-    # accelerateの互換性問題を解決する
-    accelerator_0_15 = True
-    try:
-        accelerator.unwrap_model("dummy", True)
-        print("Using accelerator 0.15.0 or above.")
-    except TypeError:
-        accelerator_0_15 = False
-
-    def unwrap_model(model):
-        if accelerator_0_15:
-            return accelerator.unwrap_model(model, True)
-        return accelerator.unwrap_model(model)
-
-    return accelerator, unwrap_model
+    return accelerator


 def prepare_dtype(args: argparse.Namespace):
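
With the runtime version probe gone, prepare_accelerator now assumes an accelerate release where Accelerator accepts `project_dir` (the older `logging_dir` keyword was deprecated upstream and later removed) and where `unwrap_model` works without the extra positional flag. A minimal sketch of the assumed API, with a hypothetical directory name and a throwaway model:

    import torch
    from accelerate import Accelerator

    # `project_dir` replaces the deprecated `logging_dir` constructor argument
    accelerator = Accelerator(project_dir="logs")
    net = accelerator.prepare(torch.nn.Linear(2, 2))
    net = accelerator.unwrap_model(net)  # replaces the old unwrap_model() closure
    print(type(net))                     # <class 'torch.nn.modules.linear.Linear'>
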
diff --git a/train_db.py b/train_db.py
index c81a092d..09f8d361 100644
--- a/train_db.py
+++ b/train_db.py
@@ -95,7 +95,7 @@ def train(args):
             f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
         )

-    accelerator, unwrap_model = train_util.prepare_accelerator(args)
+    accelerator = train_util.prepare_accelerator(args)

     # mixed precisionに対応した型を用意しておき適宜castする
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -372,8 +372,8 @@ def train(args):
                         epoch,
                         num_train_epochs,
                         global_step,
-                        unwrap_model(text_encoder),
-                        unwrap_model(unet),
+                        accelerator.unwrap_model(text_encoder),
+                        accelerator.unwrap_model(unet),
                         vae,
                     )
@@ -420,8 +420,8 @@ def train(args):
                 epoch,
                 num_train_epochs,
                 global_step,
-                unwrap_model(text_encoder),
-                unwrap_model(unet),
+                accelerator.unwrap_model(text_encoder),
+                accelerator.unwrap_model(unet),
                 vae,
             )
@@ -429,8 +429,8 @@ def train(args):

     is_main_process = accelerator.is_main_process
     if is_main_process:
-        unet = unwrap_model(unet)
-        text_encoder = unwrap_model(text_encoder)
+        unet = accelerator.unwrap_model(unet)
+        text_encoder = accelerator.unwrap_model(text_encoder)

     accelerator.end_training()
diff --git a/train_network.py b/train_network.py
index b62aef7e..7c74ae5d 100644
--- a/train_network.py
+++ b/train_network.py
@@ -150,7 +150,7 @@ def train(args):

     # acceleratorを準備する
     print("preparing accelerator")
-    accelerator, unwrap_model = train_util.prepare_accelerator(args)
+    accelerator = train_util.prepare_accelerator(args)
     is_main_process = accelerator.is_main_process

     # mixed precisionに対応した型を用意しておき適宜castする
@@ -702,7 +702,7 @@ def train(args):
                     accelerator.wait_for_everyone()
                     if accelerator.is_main_process:
                         ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
-                        save_model(ckpt_name, unwrap_model(network), global_step, epoch)
+                        save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)

                         if args.save_state:
                             train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
@@ -744,7 +744,7 @@ def train(args):
             saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
             if is_main_process and saving:
                 ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
-                save_model(ckpt_name, unwrap_model(network), global_step, epoch + 1)
+                save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)

                 remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
                 if remove_epoch_no is not None:
@@ -762,7 +762,7 @@ def train(args):
     metadata["ss_training_finished_at"] = str(time.time())

     if is_main_process:
-        network = unwrap_model(network)
+        network = accelerator.unwrap_model(network)

     accelerator.end_training()
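
The reason every save site unwraps first: accelerator.prepare() can wrap the network in a container such as DistributedDataParallel, which nests the original module under .module and prefixes every state_dict key accordingly, so saving the wrapper would produce a checkpoint with mangled keys. A self-contained sketch of the effect, using a stand-in wrapper rather than real DDP (which needs a process group):

    import torch

    class Wrapper(torch.nn.Module):  # stand-in for DDP, which nests the model under .module
        def __init__(self, module):
            super().__init__()
            self.module = module

    net = torch.nn.Linear(4, 4)
    print(list(Wrapper(net).state_dict()))  # ['module.weight', 'module.bias'] - prefixed
    print(list(net.state_dict()))           # ['weight', 'bias'] - what save_model should see
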
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 8be0703d..9dd846bd 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -98,7 +98,7 @@ def train(args):

     # acceleratorを準備する
     print("prepare accelerator")
-    accelerator, unwrap_model = train_util.prepare_accelerator(args)
+    accelerator = train_util.prepare_accelerator(args)

     # mixed precisionに対応した型を用意しておき適宜castする
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -291,7 +291,7 @@ def train(args):
     index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
     # print(len(index_no_updates), torch.sum(index_no_updates))
-    orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
+    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()

     # Freeze all parameters except for the token embeddings in text encoder
     text_encoder.requires_grad_(True)
@@ -440,7 +440,7 @@ def train(args):
             # Let's make sure we don't update any embedding weights besides the newly added token
             with torch.no_grad():
-                unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
+                accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
                     index_no_updates
                 ]

@@ -457,7 +457,9 @@ def train(args):
                 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                     accelerator.wait_for_everyone()
                     if accelerator.is_main_process:
-                        updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
+                        updated_embs = (
+                            accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
+                        )

                         ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
                         save_model(ckpt_name, updated_embs, global_step, epoch)
@@ -493,7 +495,7 @@ def train(args):

         accelerator.wait_for_everyone()

-        updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
+        updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()

         if args.save_every_n_epochs is not None:
             saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
@@ -517,7 +519,7 @@ def train(args):

     is_main_process = accelerator.is_main_process
     if is_main_process:
-        text_encoder = unwrap_model(text_encoder)
+        text_encoder = accelerator.unwrap_model(text_encoder)

     accelerator.end_training()
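
The orig_embeds_params / index_no_updates pattern above is how textual inversion keeps every embedding except the new tokens frozen: the whole embedding matrix receives gradient updates, and after each optimizer step the original rows are copied back over every token below the first new token id. A self-contained sketch with hypothetical sizes (vocabulary of 10, the last two tokens trainable):

    import torch

    emb = torch.nn.Embedding(10, 4)          # stand-in for the text encoder input embeddings
    orig = emb.weight.data.detach().clone()  # snapshot taken before training
    index_no_updates = torch.arange(10) < 8  # True for every token except the two new ones

    # ... an optimizer step may have touched all rows; restore the frozen ones ...
    with torch.no_grad():
        emb.weight[index_no_updates] = orig[index_no_updates]
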
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 7b734f28..1ea6dfc6 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -11,6 +11,7 @@ import torch
 from accelerate.utils import set_seed
 import diffusers
 from diffusers import DDPMScheduler
+import library
 import library.train_util as train_util
 import library.huggingface_util as huggingface_util
@@ -20,7 +21,14 @@ from library.config_util import (
     BlueprintGenerator,
 )
 import library.custom_train_functions as custom_train_functions
-from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training, pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction
+from library.custom_train_functions import (
+    apply_snr_weight,
+    prepare_scheduler_for_custom_training,
+    pyramid_noise_like,
+    apply_noise_offset,
+    scale_v_prediction_loss_like_noise_prediction,
+)
+import library.original_unet as original_unet
 from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI

 imagenet_templates_small = [
@@ -98,7 +106,7 @@ def train(args):

     # acceleratorを準備する
     print("prepare accelerator")
-    accelerator, unwrap_model = train_util.prepare_accelerator(args)
+    accelerator = train_util.prepare_accelerator(args)

     # mixed precisionに対応した型を用意しておき適宜castする
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -257,9 +265,9 @@ def train(args):
     # モデルに xformers とか memory efficient attention を組み込む
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

-    diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
-    diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
-    diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
+    original_unet.UNet2DConditionModel.forward = unet_forward_XTI
+    original_unet.CrossAttnDownBlock2D.forward = downblock_forward_XTI
+    original_unet.CrossAttnUpBlock2D.forward = upblock_forward_XTI

     # 学習を準備する
     if cache_latents:
@@ -319,7 +327,7 @@ def train(args):
     index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
     # print(len(index_no_updates), torch.sum(index_no_updates))
-    orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
+    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()

     # Freeze all parameters except for the token embeddings in text encoder
     text_encoder.requires_grad_(True)
@@ -473,7 +481,7 @@ def train(args):
             # Let's make sure we don't update any embedding weights besides the newly added token
             with torch.no_grad():
-                unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
+                accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
                     index_no_updates
                 ]

@@ -490,7 +498,13 @@ def train(args):
                 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                     accelerator.wait_for_everyone()
                     if accelerator.is_main_process:
-                        updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
+                        updated_embs = (
+                            accelerator.unwrap_model(text_encoder)
+                            .get_input_embeddings()
+                            .weight[token_ids_XTI]
+                            .data.detach()
+                            .clone()
+                        )

                         ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
                         save_model(ckpt_name, updated_embs, global_step, epoch)
@@ -526,7 +540,7 @@ def train(args):

         accelerator.wait_for_everyone()

-        updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
+        updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()

         if args.save_every_n_epochs is not None:
             saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
@@ -551,7 +565,7 @@ def train(args):

     is_main_process = accelerator.is_main_process
     if is_main_process:
-        text_encoder = unwrap_model(text_encoder)
+        text_encoder = accelerator.unwrap_model(text_encoder)

     accelerator.end_training()
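
A closing note on the retargeted monkey patch: the forwards are now assigned onto the classes in library.original_unet rather than the diffusers ones, presumably because the trainers run the repository's own U-Net implementation, so patching diffusers.models would no longer affect the model actually in use. SampleOutput only has to expose a .sample attribute to mirror diffusers' UNet2DConditionOutput; a minimal class along these lines (a sketch of what library.original_unet is assumed to provide) suffices:

    class SampleOutput:
        # minimal stand-in for diffusers' UNet2DConditionOutput: only .sample is needed
        def __init__(self, sample):
            self.sample = sample

    out = SampleOutput(sample=None)
    print(out.sample)
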