mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
remove workaround for accelerator=0.15, fix XTI
This commit is contained in:
@@ -2,75 +2,62 @@ import torch
|
|||||||
from typing import Union, List, Optional, Dict, Any, Tuple
|
from typing import Union, List, Optional, Dict, Any, Tuple
|
||||||
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||||
|
|
||||||
def unet_forward_XTI(self,
|
from library.original_unet import SampleOutput
|
||||||
|
|
||||||
|
|
||||||
|
def unet_forward_XTI(
|
||||||
|
self,
|
||||||
sample: torch.FloatTensor,
|
sample: torch.FloatTensor,
|
||||||
timestep: Union[torch.Tensor, float, int],
|
timestep: Union[torch.Tensor, float, int],
|
||||||
encoder_hidden_states: torch.Tensor,
|
encoder_hidden_states: torch.Tensor,
|
||||||
class_labels: Optional[torch.Tensor] = None,
|
class_labels: Optional[torch.Tensor] = None,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
) -> Union[Dict, Tuple]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
||||||
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
||||||
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
||||||
return_dict (`bool`, *optional*, defaults to `True`):
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
Whether or not to return a dict instead of a plain tuple.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
`SampleOutput` or `tuple`:
|
||||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
||||||
returning a tuple, the first element is the sample tensor.
|
|
||||||
"""
|
"""
|
||||||
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||||
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
||||||
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||||
# on the fly if necessary.
|
# on the fly if necessary.
|
||||||
|
# デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
|
||||||
|
# ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
|
||||||
|
# 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
|
||||||
default_overall_up_factor = 2**self.num_upsamplers
|
default_overall_up_factor = 2**self.num_upsamplers
|
||||||
|
|
||||||
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||||
|
# 64で割り切れないときはupsamplerにサイズを伝える
|
||||||
forward_upsample_size = False
|
forward_upsample_size = False
|
||||||
upsample_size = None
|
upsample_size = None
|
||||||
|
|
||||||
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||||
logger.info("Forward upsample size to force interpolation output size.")
|
# logger.info("Forward upsample size to force interpolation output size.")
|
||||||
forward_upsample_size = True
|
forward_upsample_size = True
|
||||||
|
|
||||||
# 0. center input if necessary
|
|
||||||
if self.config.center_input_sample:
|
|
||||||
sample = 2 * sample - 1.0
|
|
||||||
|
|
||||||
# 1. time
|
# 1. time
|
||||||
timesteps = timestep
|
timesteps = timestep
|
||||||
if not torch.is_tensor(timesteps):
|
timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
|
||||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
|
||||||
# This would be a good case for the `match` statement (Python 3.10+)
|
|
||||||
is_mps = sample.device.type == "mps"
|
|
||||||
if isinstance(timestep, float):
|
|
||||||
dtype = torch.float32 if is_mps else torch.float64
|
|
||||||
else:
|
|
||||||
dtype = torch.int32 if is_mps else torch.int64
|
|
||||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
|
||||||
elif len(timesteps.shape) == 0:
|
|
||||||
timesteps = timesteps[None].to(sample.device)
|
|
||||||
|
|
||||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
|
||||||
timesteps = timesteps.expand(sample.shape[0])
|
|
||||||
|
|
||||||
t_emb = self.time_proj(timesteps)
|
t_emb = self.time_proj(timesteps)
|
||||||
|
|
||||||
# timesteps does not contain any weights and will always return f32 tensors
|
# timesteps does not contain any weights and will always return f32 tensors
|
||||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||||
# there might be better ways to encapsulate this.
|
# there might be better ways to encapsulate this.
|
||||||
|
# timestepsは重みを含まないので常にfloat32のテンソルを返す
|
||||||
|
# しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
|
||||||
|
# time_projでキャストしておけばいいんじゃね?
|
||||||
t_emb = t_emb.to(dtype=self.dtype)
|
t_emb = t_emb.to(dtype=self.dtype)
|
||||||
emb = self.time_embedding(t_emb)
|
emb = self.time_embedding(t_emb)
|
||||||
|
|
||||||
if self.config.num_class_embeds is not None:
|
|
||||||
if class_labels is None:
|
|
||||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
|
||||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
|
||||||
emb = emb + class_emb
|
|
||||||
|
|
||||||
# 2. pre-process
|
# 2. pre-process
|
||||||
sample = self.conv_in(sample)
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
@@ -78,11 +65,13 @@ def unet_forward_XTI(self,
|
|||||||
down_block_res_samples = (sample,)
|
down_block_res_samples = (sample,)
|
||||||
down_i = 0
|
down_i = 0
|
||||||
for downsample_block in self.down_blocks:
|
for downsample_block in self.down_blocks:
|
||||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
||||||
|
# まあこちらのほうがわかりやすいかもしれない
|
||||||
|
if downsample_block.has_cross_attention:
|
||||||
sample, res_samples = downsample_block(
|
sample, res_samples = downsample_block(
|
||||||
hidden_states=sample,
|
hidden_states=sample,
|
||||||
temb=emb,
|
temb=emb,
|
||||||
encoder_hidden_states=encoder_hidden_states[down_i:down_i+2],
|
encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2],
|
||||||
)
|
)
|
||||||
down_i += 2
|
down_i += 2
|
||||||
else:
|
else:
|
||||||
@@ -99,19 +88,19 @@ def unet_forward_XTI(self,
|
|||||||
is_final_block = i == len(self.up_blocks) - 1
|
is_final_block = i == len(self.up_blocks) - 1
|
||||||
|
|
||||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
|
||||||
|
|
||||||
# if we have not reached the final block and need to forward the
|
# if we have not reached the final block and need to forward the upsample size, we do it here
|
||||||
# upsample size, we do it here
|
# 前述のように最後のブロック以外ではupsample_sizeを伝える
|
||||||
if not is_final_block and forward_upsample_size:
|
if not is_final_block and forward_upsample_size:
|
||||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||||
|
|
||||||
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
if upsample_block.has_cross_attention:
|
||||||
sample = upsample_block(
|
sample = upsample_block(
|
||||||
hidden_states=sample,
|
hidden_states=sample,
|
||||||
temb=emb,
|
temb=emb,
|
||||||
res_hidden_states_tuple=res_samples,
|
res_hidden_states_tuple=res_samples,
|
||||||
encoder_hidden_states=encoder_hidden_states[up_i:up_i+3],
|
encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3],
|
||||||
upsample_size=upsample_size,
|
upsample_size=upsample_size,
|
||||||
)
|
)
|
||||||
up_i += 3
|
up_i += 3
|
||||||
@@ -119,6 +108,7 @@ def unet_forward_XTI(self,
|
|||||||
sample = upsample_block(
|
sample = upsample_block(
|
||||||
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. post-process
|
# 6. post-process
|
||||||
sample = self.conv_norm_out(sample)
|
sample = self.conv_norm_out(sample)
|
||||||
sample = self.conv_act(sample)
|
sample = self.conv_act(sample)
|
||||||
@@ -127,7 +117,8 @@ def unet_forward_XTI(self,
|
|||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (sample,)
|
return (sample,)
|
||||||
|
|
||||||
return UNet2DConditionOutput(sample=sample)
|
return SampleOutput(sample=sample)
|
||||||
|
|
||||||
|
|
||||||
def downblock_forward_XTI(
|
def downblock_forward_XTI(
|
||||||
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
|
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
|
||||||
@@ -166,6 +157,7 @@ def downblock_forward_XTI(
|
|||||||
|
|
||||||
return hidden_states, output_states
|
return hidden_states, output_states
|
||||||
|
|
||||||
|
|
||||||
def upblock_forward_XTI(
|
def upblock_forward_XTI(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
|||||||
14
fine_tune.py
14
fine_tune.py
@@ -91,7 +91,7 @@ def train(args):
|
|||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("prepare accelerator")
|
print("prepare accelerator")
|
||||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
accelerator = train_util.prepare_accelerator(args)
|
||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
@@ -385,8 +385,8 @@ def train(args):
|
|||||||
epoch,
|
epoch,
|
||||||
num_train_epochs,
|
num_train_epochs,
|
||||||
global_step,
|
global_step,
|
||||||
unwrap_model(text_encoder),
|
accelerator.unwrap_model(text_encoder),
|
||||||
unwrap_model(unet),
|
accelerator.unwrap_model(unet),
|
||||||
vae,
|
vae,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -428,8 +428,8 @@ def train(args):
|
|||||||
epoch,
|
epoch,
|
||||||
num_train_epochs,
|
num_train_epochs,
|
||||||
global_step,
|
global_step,
|
||||||
unwrap_model(text_encoder),
|
accelerator.unwrap_model(text_encoder),
|
||||||
unwrap_model(unet),
|
accelerator.unwrap_model(unet),
|
||||||
vae,
|
vae,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -437,8 +437,8 @@ def train(args):
|
|||||||
|
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
unet = unwrap_model(unet)
|
unet = accelerator.unwrap_model(unet)
|
||||||
text_encoder = unwrap_model(text_encoder)
|
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
|
|||||||
@@ -2904,23 +2904,9 @@ def prepare_accelerator(args: argparse.Namespace):
|
|||||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
mixed_precision=args.mixed_precision,
|
mixed_precision=args.mixed_precision,
|
||||||
log_with=log_with,
|
log_with=log_with,
|
||||||
logging_dir=logging_dir,
|
project_dir=logging_dir,
|
||||||
)
|
)
|
||||||
|
return accelerator
|
||||||
# accelerateの互換性問題を解決する
|
|
||||||
accelerator_0_15 = True
|
|
||||||
try:
|
|
||||||
accelerator.unwrap_model("dummy", True)
|
|
||||||
print("Using accelerator 0.15.0 or above.")
|
|
||||||
except TypeError:
|
|
||||||
accelerator_0_15 = False
|
|
||||||
|
|
||||||
def unwrap_model(model):
|
|
||||||
if accelerator_0_15:
|
|
||||||
return accelerator.unwrap_model(model, True)
|
|
||||||
return accelerator.unwrap_model(model)
|
|
||||||
|
|
||||||
return accelerator, unwrap_model
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_dtype(args: argparse.Namespace):
|
def prepare_dtype(args: argparse.Namespace):
|
||||||
|
|||||||
14
train_db.py
14
train_db.py
@@ -95,7 +95,7 @@ def train(args):
|
|||||||
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
|
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
|
||||||
)
|
)
|
||||||
|
|
||||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
accelerator = train_util.prepare_accelerator(args)
|
||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
@@ -372,8 +372,8 @@ def train(args):
|
|||||||
epoch,
|
epoch,
|
||||||
num_train_epochs,
|
num_train_epochs,
|
||||||
global_step,
|
global_step,
|
||||||
unwrap_model(text_encoder),
|
accelerator.unwrap_model(text_encoder),
|
||||||
unwrap_model(unet),
|
accelerator.unwrap_model(unet),
|
||||||
vae,
|
vae,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -420,8 +420,8 @@ def train(args):
|
|||||||
epoch,
|
epoch,
|
||||||
num_train_epochs,
|
num_train_epochs,
|
||||||
global_step,
|
global_step,
|
||||||
unwrap_model(text_encoder),
|
accelerator.unwrap_model(text_encoder),
|
||||||
unwrap_model(unet),
|
accelerator.unwrap_model(unet),
|
||||||
vae,
|
vae,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -429,8 +429,8 @@ def train(args):
|
|||||||
|
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
unet = unwrap_model(unet)
|
unet = accelerator.unwrap_model(unet)
|
||||||
text_encoder = unwrap_model(text_encoder)
|
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ def train(args):
|
|||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("preparing accelerator")
|
print("preparing accelerator")
|
||||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
accelerator = train_util.prepare_accelerator(args)
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
@@ -702,7 +702,7 @@ def train(args):
|
|||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||||
save_model(ckpt_name, unwrap_model(network), global_step, epoch)
|
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
|
||||||
|
|
||||||
if args.save_state:
|
if args.save_state:
|
||||||
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||||
@@ -744,7 +744,7 @@ def train(args):
|
|||||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||||
if is_main_process and saving:
|
if is_main_process and saving:
|
||||||
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
||||||
save_model(ckpt_name, unwrap_model(network), global_step, epoch + 1)
|
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
|
||||||
|
|
||||||
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
||||||
if remove_epoch_no is not None:
|
if remove_epoch_no is not None:
|
||||||
@@ -762,7 +762,7 @@ def train(args):
|
|||||||
metadata["ss_training_finished_at"] = str(time.time())
|
metadata["ss_training_finished_at"] = str(time.time())
|
||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
network = unwrap_model(network)
|
network = accelerator.unwrap_model(network)
|
||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ def train(args):
|
|||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("prepare accelerator")
|
print("prepare accelerator")
|
||||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
accelerator = train_util.prepare_accelerator(args)
|
||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
@@ -291,7 +291,7 @@ def train(args):
|
|||||||
|
|
||||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
||||||
# print(len(index_no_updates), torch.sum(index_no_updates))
|
# print(len(index_no_updates), torch.sum(index_no_updates))
|
||||||
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||||
|
|
||||||
# Freeze all parameters except for the token embeddings in text encoder
|
# Freeze all parameters except for the token embeddings in text encoder
|
||||||
text_encoder.requires_grad_(True)
|
text_encoder.requires_grad_(True)
|
||||||
@@ -440,7 +440,7 @@ def train(args):
|
|||||||
|
|
||||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
|
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
|
||||||
index_no_updates
|
index_no_updates
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -457,7 +457,9 @@ def train(args):
|
|||||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
updated_embs = (
|
||||||
|
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||||
|
)
|
||||||
|
|
||||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||||
save_model(ckpt_name, updated_embs, global_step, epoch)
|
save_model(ckpt_name, updated_embs, global_step, epoch)
|
||||||
@@ -493,7 +495,7 @@ def train(args):
|
|||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||||
|
|
||||||
if args.save_every_n_epochs is not None:
|
if args.save_every_n_epochs is not None:
|
||||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||||
@@ -517,7 +519,7 @@ def train(args):
|
|||||||
|
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
text_encoder = unwrap_model(text_encoder)
|
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import torch
|
|||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
import diffusers
|
import diffusers
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
|
import library
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
import library.huggingface_util as huggingface_util
|
import library.huggingface_util as huggingface_util
|
||||||
@@ -20,7 +21,14 @@ from library.config_util import (
|
|||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
import library.custom_train_functions as custom_train_functions
|
import library.custom_train_functions as custom_train_functions
|
||||||
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training, pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction
|
from library.custom_train_functions import (
|
||||||
|
apply_snr_weight,
|
||||||
|
prepare_scheduler_for_custom_training,
|
||||||
|
pyramid_noise_like,
|
||||||
|
apply_noise_offset,
|
||||||
|
scale_v_prediction_loss_like_noise_prediction,
|
||||||
|
)
|
||||||
|
import library.original_unet as original_unet
|
||||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||||
|
|
||||||
imagenet_templates_small = [
|
imagenet_templates_small = [
|
||||||
@@ -98,7 +106,7 @@ def train(args):
|
|||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("prepare accelerator")
|
print("prepare accelerator")
|
||||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
accelerator = train_util.prepare_accelerator(args)
|
||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
@@ -257,9 +265,9 @@ def train(args):
|
|||||||
|
|
||||||
# モデルに xformers とか memory efficient attention を組み込む
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||||
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
|
original_unet.UNet2DConditionModel.forward = unet_forward_XTI
|
||||||
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
|
original_unet.CrossAttnDownBlock2D.forward = downblock_forward_XTI
|
||||||
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
|
original_unet.CrossAttnUpBlock2D.forward = upblock_forward_XTI
|
||||||
|
|
||||||
# 学習を準備する
|
# 学習を準備する
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
@@ -319,7 +327,7 @@ def train(args):
|
|||||||
|
|
||||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
|
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
|
||||||
# print(len(index_no_updates), torch.sum(index_no_updates))
|
# print(len(index_no_updates), torch.sum(index_no_updates))
|
||||||
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||||
|
|
||||||
# Freeze all parameters except for the token embeddings in text encoder
|
# Freeze all parameters except for the token embeddings in text encoder
|
||||||
text_encoder.requires_grad_(True)
|
text_encoder.requires_grad_(True)
|
||||||
@@ -473,7 +481,7 @@ def train(args):
|
|||||||
|
|
||||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
|
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
|
||||||
index_no_updates
|
index_no_updates
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -490,7 +498,13 @@ def train(args):
|
|||||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
updated_embs = (
|
||||||
|
accelerator.unwrap_model(text_encoder)
|
||||||
|
.get_input_embeddings()
|
||||||
|
.weight[token_ids_XTI]
|
||||||
|
.data.detach()
|
||||||
|
.clone()
|
||||||
|
)
|
||||||
|
|
||||||
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
||||||
save_model(ckpt_name, updated_embs, global_step, epoch)
|
save_model(ckpt_name, updated_embs, global_step, epoch)
|
||||||
@@ -526,7 +540,7 @@ def train(args):
|
|||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
||||||
|
|
||||||
if args.save_every_n_epochs is not None:
|
if args.save_every_n_epochs is not None:
|
||||||
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
||||||
@@ -551,7 +565,7 @@ def train(args):
|
|||||||
|
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
text_encoder = unwrap_model(text_encoder)
|
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||||
|
|
||||||
accelerator.end_training()
|
accelerator.end_training()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user