remove workaround for accelerate 0.15, fix XTI

This commit is contained in:
ykume
2023-06-11 18:32:14 +09:00
parent 33a6234b52
commit 0315611b11
7 changed files with 153 additions and 159 deletions

View File

@@ -2,75 +2,62 @@ import torch
from typing import Union, List, Optional, Dict, Any, Tuple from typing import Union, List, Optional, Dict, Any, Tuple
from diffusers.models.unet_2d_condition import UNet2DConditionOutput from diffusers.models.unet_2d_condition import UNet2DConditionOutput
def unet_forward_XTI(self, from library.original_unet import SampleOutput
def unet_forward_XTI(
self,
sample: torch.FloatTensor, sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int], timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True, return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]: ) -> Union[Dict, Tuple]:
r""" r"""
Args: Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Whether or not to return a dict instead of a plain tuple.
Returns: Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: `SampleOutput` or `tuple`:
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
returning a tuple, the first element is the sample tensor.
""" """
# By default samples have to be at least a multiple of the overall upsampling factor. # By default samples have to be at least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size # However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary. # on the fly if necessary.
# デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
# ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
# 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
default_overall_up_factor = 2**self.num_upsamplers default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
# 64で割り切れないときはupsamplerにサイズを伝える
forward_upsample_size = False forward_upsample_size = False
upsample_size = None upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
logger.info("Forward upsample size to force interpolation output size.") # logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True forward_upsample_size = True
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time # 1. time
timesteps = timestep timesteps = timestep
if not torch.is_tensor(timesteps): timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps) t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors # timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here. # but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this. # there might be better ways to encapsulate this.
# timestepsは重みを含まないので常にfloat32のテンソルを返す
# しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
# time_projでキャストしておけばいいんじゃね
t_emb = t_emb.to(dtype=self.dtype) t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb) emb = self.time_embedding(t_emb)
if self.config.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process # 2. pre-process
sample = self.conv_in(sample) sample = self.conv_in(sample)
@@ -78,7 +65,9 @@ def unet_forward_XTI(self,
down_block_res_samples = (sample,) down_block_res_samples = (sample,)
down_i = 0 down_i = 0
for downsample_block in self.down_blocks: for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
# まあこちらのほうがわかりやすいかもしれない
if downsample_block.has_cross_attention:
sample, res_samples = downsample_block( sample, res_samples = downsample_block(
hidden_states=sample, hidden_states=sample,
temb=emb, temb=emb,
@@ -99,14 +88,14 @@ def unet_forward_XTI(self,
is_final_block = i == len(self.up_blocks) - 1 is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :] res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
# if we have not reached the final block and need to forward the # if we have not reached the final block and need to forward the upsample size, we do it here
# upsample size, we do it here # 前述のように最後のブロック以外ではupsample_sizeを伝える
if not is_final_block and forward_upsample_size: if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:] upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: if upsample_block.has_cross_attention:
sample = upsample_block( sample = upsample_block(
hidden_states=sample, hidden_states=sample,
temb=emb, temb=emb,
@@ -119,6 +108,7 @@ def unet_forward_XTI(self,
sample = upsample_block( sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
) )
# 6. post-process # 6. post-process
sample = self.conv_norm_out(sample) sample = self.conv_norm_out(sample)
sample = self.conv_act(sample) sample = self.conv_act(sample)
@@ -127,7 +117,8 @@ def unet_forward_XTI(self,
if not return_dict: if not return_dict:
return (sample,) return (sample,)
return UNet2DConditionOutput(sample=sample) return SampleOutput(sample=sample)
def downblock_forward_XTI( def downblock_forward_XTI(
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
@@ -166,6 +157,7 @@ def downblock_forward_XTI(
return hidden_states, output_states return hidden_states, output_states
def upblock_forward_XTI( def upblock_forward_XTI(
self, self,
hidden_states, hidden_states,

View File

@@ -91,7 +91,7 @@ def train(args):
# acceleratorを準備する # acceleratorを準備する
print("prepare accelerator") print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args) accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする # mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args) weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -385,8 +385,8 @@ def train(args):
epoch, epoch,
num_train_epochs, num_train_epochs,
global_step, global_step,
unwrap_model(text_encoder), accelerator.unwrap_model(text_encoder),
unwrap_model(unet), accelerator.unwrap_model(unet),
vae, vae,
) )
@@ -428,8 +428,8 @@ def train(args):
epoch, epoch,
num_train_epochs, num_train_epochs,
global_step, global_step,
unwrap_model(text_encoder), accelerator.unwrap_model(text_encoder),
unwrap_model(unet), accelerator.unwrap_model(unet),
vae, vae,
) )
@@ -437,8 +437,8 @@ def train(args):
is_main_process = accelerator.is_main_process is_main_process = accelerator.is_main_process
if is_main_process: if is_main_process:
unet = unwrap_model(unet) unet = accelerator.unwrap_model(unet)
text_encoder = unwrap_model(text_encoder) text_encoder = accelerator.unwrap_model(text_encoder)
accelerator.end_training() accelerator.end_training()

View File

@@ -2904,23 +2904,9 @@ def prepare_accelerator(args: argparse.Namespace):
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision, mixed_precision=args.mixed_precision,
log_with=log_with, log_with=log_with,
logging_dir=logging_dir, project_dir=logging_dir,
) )
return accelerator
# accelerateの互換性問題を解決する
accelerator_0_15 = True
try:
accelerator.unwrap_model("dummy", True)
print("Using accelerator 0.15.0 or above.")
except TypeError:
accelerator_0_15 = False
def unwrap_model(model):
if accelerator_0_15:
return accelerator.unwrap_model(model, True)
return accelerator.unwrap_model(model)
return accelerator, unwrap_model
def prepare_dtype(args: argparse.Namespace): def prepare_dtype(args: argparse.Namespace):

View File

@@ -95,7 +95,7 @@ def train(args):
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデルU-NetおよびText Encoderの学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です" f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデルU-NetおよびText Encoderの学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
) )
accelerator, unwrap_model = train_util.prepare_accelerator(args) accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする # mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args) weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -372,8 +372,8 @@ def train(args):
epoch, epoch,
num_train_epochs, num_train_epochs,
global_step, global_step,
unwrap_model(text_encoder), accelerator.unwrap_model(text_encoder),
unwrap_model(unet), accelerator.unwrap_model(unet),
vae, vae,
) )
@@ -420,8 +420,8 @@ def train(args):
epoch, epoch,
num_train_epochs, num_train_epochs,
global_step, global_step,
unwrap_model(text_encoder), accelerator.unwrap_model(text_encoder),
unwrap_model(unet), accelerator.unwrap_model(unet),
vae, vae,
) )
@@ -429,8 +429,8 @@ def train(args):
is_main_process = accelerator.is_main_process is_main_process = accelerator.is_main_process
if is_main_process: if is_main_process:
unet = unwrap_model(unet) unet = accelerator.unwrap_model(unet)
text_encoder = unwrap_model(text_encoder) text_encoder = accelerator.unwrap_model(text_encoder)
accelerator.end_training() accelerator.end_training()

View File

@@ -150,7 +150,7 @@ def train(args):
# acceleratorを準備する # acceleratorを準備する
print("preparing accelerator") print("preparing accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args) accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process is_main_process = accelerator.is_main_process
# mixed precisionに対応した型を用意しておき適宜castする # mixed precisionに対応した型を用意しておき適宜castする
@@ -702,7 +702,7 @@ def train(args):
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if accelerator.is_main_process: if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, unwrap_model(network), global_step, epoch) save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
if args.save_state: if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step) train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
@@ -744,7 +744,7 @@ def train(args):
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving: if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, unwrap_model(network), global_step, epoch + 1) save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None: if remove_epoch_no is not None:
@@ -762,7 +762,7 @@ def train(args):
metadata["ss_training_finished_at"] = str(time.time()) metadata["ss_training_finished_at"] = str(time.time())
if is_main_process: if is_main_process:
network = unwrap_model(network) network = accelerator.unwrap_model(network)
accelerator.end_training() accelerator.end_training()

View File

@@ -98,7 +98,7 @@ def train(args):
# acceleratorを準備する # acceleratorを準備する
print("prepare accelerator") print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args) accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする # mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args) weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -291,7 +291,7 @@ def train(args):
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
# print(len(index_no_updates), torch.sum(index_no_updates)) # print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder # Freeze all parameters except for the token embeddings in text encoder
text_encoder.requires_grad_(True) text_encoder.requires_grad_(True)
@@ -440,7 +440,7 @@ def train(args):
# Let's make sure we don't update any embedding weights besides the newly added token # Let's make sure we don't update any embedding weights besides the newly added token
with torch.no_grad(): with torch.no_grad():
unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
index_no_updates index_no_updates
] ]
@@ -457,7 +457,9 @@ def train(args):
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if accelerator.is_main_process: if accelerator.is_main_process:
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() updated_embs = (
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
)
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, updated_embs, global_step, epoch) save_model(ckpt_name, updated_embs, global_step, epoch)
@@ -493,7 +495,7 @@ def train(args):
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
if args.save_every_n_epochs is not None: if args.save_every_n_epochs is not None:
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
@@ -517,7 +519,7 @@ def train(args):
is_main_process = accelerator.is_main_process is_main_process = accelerator.is_main_process
if is_main_process: if is_main_process:
text_encoder = unwrap_model(text_encoder) text_encoder = accelerator.unwrap_model(text_encoder)
accelerator.end_training() accelerator.end_training()

View File

@@ -11,6 +11,7 @@ import torch
from accelerate.utils import set_seed from accelerate.utils import set_seed
import diffusers import diffusers
from diffusers import DDPMScheduler from diffusers import DDPMScheduler
import library
import library.train_util as train_util import library.train_util as train_util
import library.huggingface_util as huggingface_util import library.huggingface_util as huggingface_util
@@ -20,7 +21,14 @@ from library.config_util import (
BlueprintGenerator, BlueprintGenerator,
) )
import library.custom_train_functions as custom_train_functions import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training, pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction from library.custom_train_functions import (
apply_snr_weight,
prepare_scheduler_for_custom_training,
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
)
import library.original_unet as original_unet
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
imagenet_templates_small = [ imagenet_templates_small = [
@@ -98,7 +106,7 @@ def train(args):
# acceleratorを準備する # acceleratorを準備する
print("prepare accelerator") print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args) accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする # mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args) weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -257,9 +265,9 @@ def train(args):
# モデルに xformers とか memory efficient attention を組み込む # モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI original_unet.UNet2DConditionModel.forward = unet_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI original_unet.CrossAttnDownBlock2D.forward = downblock_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI original_unet.CrossAttnUpBlock2D.forward = upblock_forward_XTI
# 学習を準備する # 学習を準備する
if cache_latents: if cache_latents:
@@ -319,7 +327,7 @@ def train(args):
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0] index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
# print(len(index_no_updates), torch.sum(index_no_updates)) # print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder # Freeze all parameters except for the token embeddings in text encoder
text_encoder.requires_grad_(True) text_encoder.requires_grad_(True)
@@ -473,7 +481,7 @@ def train(args):
# Let's make sure we don't update any embedding weights besides the newly added token # Let's make sure we don't update any embedding weights besides the newly added token
with torch.no_grad(): with torch.no_grad():
unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
index_no_updates index_no_updates
] ]
@@ -490,7 +498,13 @@ def train(args):
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if accelerator.is_main_process: if accelerator.is_main_process:
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone() updated_embs = (
accelerator.unwrap_model(text_encoder)
.get_input_embeddings()
.weight[token_ids_XTI]
.data.detach()
.clone()
)
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, updated_embs, global_step, epoch) save_model(ckpt_name, updated_embs, global_step, epoch)
@@ -526,7 +540,7 @@ def train(args):
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone() updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
if args.save_every_n_epochs is not None: if args.save_every_n_epochs is not None:
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
@@ -551,7 +565,7 @@ def train(args):
is_main_process = accelerator.is_main_process is_main_process = accelerator.is_main_process
if is_main_process: if is_main_process:
text_encoder = unwrap_model(text_encoder) text_encoder = accelerator.unwrap_model(text_encoder)
accelerator.end_training() accelerator.end_training()