diff --git a/README.md b/README.md
index f8a269c7..51183a9a 100644
--- a/README.md
+++ b/README.md
@@ -281,6 +281,30 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
 
 ## Change History
 
+### Dec 3, 2023 / 2023/12/3
+
+- `finetune/tag_images_by_wd14_tagger.py` now supports a separator other than `,` with the `--caption_separator` option. Thanks to KohakuBlueleaf! PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913)
+- Min SNR Gamma with V-prediction (SD 2.1) is fixed. Thanks to feffy380! PR [#934](https://github.com/kohya-ss/sd-scripts/pull/934)
+  - See [#673](https://github.com/kohya-ss/sd-scripts/issues/673) for details.
+- `--min_diff` and `--clamp_quantile` options are added to `networks/extract_lora_from_models.py`. Thanks to wkpark! PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936)
+  - The default values are the same as in the previous version.
+- Deep Shrink hires fix is supported in `sdxl_gen_img.py` and `gen_img_diffusers.py` (see the example after this diff).
+  - The `--ds_timesteps_1` and `--ds_timesteps_2` options specify the timesteps of Deep Shrink for the first and second stages.
+  - The `--ds_depth_1` and `--ds_depth_2` options specify the depth (block index) of Deep Shrink for the first and second stages.
+  - The `--ds_ratio` option specifies the ratio of Deep Shrink. `0.5` means half of the original latent size.
+  - The `--dst1`, `--dst2`, `--dsd1`, `--dsd2` and `--dsr` prompt options are also available.
+
+- `finetune/tag_images_by_wd14_tagger.py` で `--caption_separator` オプションでカンマ以外の区切り文字を指定できるようになりました。KohakuBlueleaf 氏に感謝します。 PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913)
+- V-prediction (SD 2.1) での Min SNR Gamma が修正されました。feffy380 氏に感謝します。 PR [#934](https://github.com/kohya-ss/sd-scripts/pull/934)
+  - 詳細は [#673](https://github.com/kohya-ss/sd-scripts/issues/673) を参照してください。
+- `networks/extract_lora_from_models.py` に `--min_diff` と `--clamp_quantile` オプションが追加されました。wkpark 氏に感謝します。 PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936)
+  - デフォルト値は前のバージョンと同じです。
+- `sdxl_gen_img.py` と `gen_img_diffusers.py` で Deep Shrink hires fix をサポートしました。
+  - `--ds_timesteps_1` と `--ds_timesteps_2` オプションは Deep Shrink の第一段階と第二段階の timesteps を指定します。
+  - `--ds_depth_1` と `--ds_depth_2` オプションは Deep Shrink の第一段階と第二段階の深さ(ブロックの index)を指定します。
+  - `--ds_ratio` オプションは Deep Shrink の比率を指定します。`0.5` を指定すると Deep Shrink 適用時の latent は元のサイズの半分になります。
+  - `--dst1`、`--dst2`、`--dsd1`、`--dsd2`、`--dsr` プロンプトオプションも使用できます。
+
 ### Nov 5, 2023 / 2023/11/5
 
 - `sdxl_train.py` now supports different learning rates for each Text Encoder.
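The Deep Shrink options combine with the existing generation flags. A minimal, untested sketch (the checkpoint path, output directory, prompt, resolution, and the Deep Shrink values are all illustrative placeholders, not defaults confirmed by this patch):

```
python sdxl_gen_img.py --ckpt sdxl_base.safetensors --outdir outputs \
    --W 1536 --H 1536 --steps 30 \
    --ds_depth_1 3 --ds_timesteps_1 650 --ds_ratio 0.5 \
    --prompt "a photograph of a lighthouse at dawn"
```

The per-prompt equivalents (`--dsd1 3 --dst1 650 --dsr 0.5`) can instead be appended to an individual prompt line, which is convenient when generating from a prompt file.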
diff --git a/docs/train_README-ja.md b/docs/train_README-ja.md
index c871f076..d186bf24 100644
--- a/docs/train_README-ja.md
+++ b/docs/train_README-ja.md
@@ -374,6 +374,10 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ
 
   サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。
 
+- `--sample_at_first`
+
+  学習開始前にサンプル出力します。学習前との比較ができます。
+
 - `--sample_prompts`
 
   サンプル出力用プロンプトのファイルを指定します。
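As a quick illustration, `--sample_at_first` is simply added alongside the existing sampling options. A hedged sketch of an invocation (the script choice is arbitrary, and the model/dataset/network arguments are omitted):

```
# the usual model, dataset and network arguments are omitted here
accelerate launch train_network.py \
    --sample_at_first \
    --sample_every_n_steps 200 \
    --sample_prompts prompts.txt --sample_sampler euler_a
```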
diff --git a/fine_tune.py b/fine_tune.py
index b0787677..f72e618b 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -253,9 +253,6 @@ def train(args):
     else:
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
 
-    # transform DDP after prepare
-    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
-
     # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
     if args.full_fp16:
         train_util.patch_accelerator_for_fp16_training(accelerator)
@@ -298,6 +295,9 @@ def train(args):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
+    # For --sample_at_first
+    train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
     loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
diff --git a/library/ipex/__init__.py b/library/ipex/__init__.py
index 43accd9f..662572c8 100644
--- a/library/ipex/__init__.py
+++ b/library/ipex/__init__.py
@@ -4,13 +4,12 @@ import contextlib
 import torch
 import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
 from .hijacks import ipex_hijacks
-from .attention import attention_init
 
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
 
 def ipex_init(): # pylint: disable=too-many-statements
     try:
-        #Replace cuda with xpu:
+        # Replace cuda with xpu:
         torch.cuda.current_device = torch.xpu.current_device
         torch.cuda.current_stream = torch.xpu.current_stream
         torch.cuda.device = torch.xpu.device
@@ -30,6 +29,7 @@ def ipex_init(): # pylint: disable=too-many-statements
         torch.cuda.FloatTensor = torch.xpu.FloatTensor
         torch.Tensor.cuda = torch.Tensor.xpu
         torch.Tensor.is_cuda = torch.Tensor.is_xpu
+        torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
         torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
         torch.cuda._initialized = torch.xpu.lazy_init._initialized
         torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
@@ -90,9 +90,9 @@ def ipex_init(): # pylint: disable=too-many-statements
         torch.cuda.CharStorage = torch.xpu.CharStorage
         torch.cuda.__file__ = torch.xpu.__file__
         torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
-        #torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
+        # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
 
-        #Memory:
+        # Memory:
         torch.cuda.memory = torch.xpu.memory
         if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
             torch.xpu.empty_cache = lambda: None
@@ -112,7 +112,7 @@ def ipex_init(): # pylint: disable=too-many-statements
         torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
         torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
 
-        #RNG:
+        # RNG:
         torch.cuda.get_rng_state = torch.xpu.get_rng_state
         torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
         torch.cuda.set_rng_state = torch.xpu.set_rng_state
@@ -123,7 +123,7 @@ def ipex_init(): # pylint: disable=too-many-statements
         torch.cuda.seed_all = torch.xpu.seed_all
         torch.cuda.initial_seed = torch.xpu.initial_seed
 
-        #AMP:
+        # AMP:
         torch.cuda.amp = torch.xpu.amp
         if not hasattr(torch.cuda.amp, "common"):
             torch.cuda.amp.common = contextlib.nullcontext()
@@ -138,12 +138,12 @@ def ipex_init(): # pylint: disable=too-many-statements
         except Exception: # pylint: disable=broad-exception-caught
             torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
 
-        #C
+        # C
         torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
         ipex._C._DeviceProperties.major = 2023
         ipex._C._DeviceProperties.minor = 2
 
-        #Fix functions with ipex:
+        # Fix functions with ipex:
         torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
         torch._utils._get_available_device_type = lambda: "xpu"
         torch.has_cuda = True
@@ -164,12 +164,17 @@ def ipex_init(): # pylint: disable=too-many-statements
         torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card
 
         ipex_hijacks()
-        attention_init()
-        try:
-            from .diffusers import ipex_diffusers
-            ipex_diffusers()
-        except Exception: # pylint: disable=broad-exception-caught
-            pass
+        if not torch.xpu.has_fp64_dtype():
+            try:
+                from .attention import attention_init
+                attention_init()
+            except Exception: # pylint: disable=broad-exception-caught
+                pass
+            try:
+                from .diffusers import ipex_diffusers
+                ipex_diffusers()
+            except Exception: # pylint: disable=broad-exception-caught
+                pass
     except Exception as e:
         return False, e
     return True, None
diff --git a/library/ipex/attention.py b/library/ipex/attention.py
index 84848b6a..52016466 100644
--- a/library/ipex/attention.py
+++ b/library/ipex/attention.py
@@ -74,6 +74,11 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
         shape_one, batch_size_attention, query_tokens, shape_four = query.shape
         no_shape_one = False
 
+    if query.dtype != key.dtype:
+        key = key.to(dtype=query.dtype)
+    if query.dtype != value.dtype:
+        value = value.to(dtype=query.dtype)
+
     block_multiply = query.element_size()
     slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
     block_size = batch_size_attention * slice_block_size
diff --git a/library/ipex/diffusers.py b/library/ipex/diffusers.py
index 005ee49f..c32af507 100644
--- a/library/ipex/diffusers.py
+++ b/library/ipex/diffusers.py
@@ -1,6 +1,6 @@
 import torch
 import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
-import diffusers #0.21.1 # pylint: disable=import-error
+import diffusers #0.24.0 # pylint: disable=import-error
 from diffusers.models.attention_processor import Attention
 
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
diff --git a/library/ipex/gradscaler.py b/library/ipex/gradscaler.py
index 53021210..6eb56bc2 100644
--- a/library/ipex/gradscaler.py
+++ b/library/ipex/gradscaler.py
@@ -5,6 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un
 
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
 
+device_supports_fp64 = torch.xpu.has_fp64_dtype()
 OptState = ipex.cpu.autocast._grad_scaler.OptState
 _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
 _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
@@ -96,7 +97,10 @@ def unscale_(self, optimizer):
 
     # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
     assert self._scale is not None
-    inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
+    if device_supports_fp64:
+        inv_scale = self._scale.double().reciprocal().float()
+    else:
+        inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
     found_inf = torch.full(
         (1,), 0.0, dtype=torch.float32, device=self._scale.device
     )
diff --git a/library/ipex/hijacks.py b/library/ipex/hijacks.py
index 77ed5419..4a9a3569 100644
--- a/library/ipex/hijacks.py
+++ b/library/ipex/hijacks.py
@@ -89,6 +89,7 @@ def ipex_autocast(*args, **kwargs):
     else:
         return original_autocast(*args, **kwargs)
 
+# Embedding BF16
 original_torch_cat = torch.cat
 def torch_cat(tensor, *args, **kwargs):
     if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
@@ -96,6 +97,7 @@ def torch_cat(tensor, *args, **kwargs):
     else:
         return original_torch_cat(tensor, *args, **kwargs)
 
+# Latent antialias:
 original_interpolate = torch.nn.functional.interpolate
 def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
     if antialias or align_corners is not None:
@@ -115,19 +117,29 @@ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
     else:
         return original_linalg_solve(A, B, *args, **kwargs)
 
+@property
+def is_cuda(self):
+    return self.device.type == 'xpu'
+
 def ipex_hijacks():
+    CondFunc('torch.tensor',
+        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
+        lambda orig_func, *args, device=None, **kwargs: check_device(device))
     CondFunc('torch.Tensor.to',
         lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
         lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
     CondFunc('torch.Tensor.cuda',
         lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
         lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
+    CondFunc('torch.UntypedStorage.__init__',
+        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
+        lambda orig_func, *args, device=None, **kwargs: check_device(device))
+    CondFunc('torch.UntypedStorage.cuda',
+        lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
+        lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
     CondFunc('torch.empty',
        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
        lambda orig_func, *args, device=None, **kwargs: check_device(device))
-    CondFunc('torch.load',
-        lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs),
-        lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location))
     CondFunc('torch.randn',
        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
        lambda orig_func, *args, device=None, **kwargs: check_device(device))
@@ -137,17 +149,19 @@ def ipex_hijacks():
     CondFunc('torch.zeros',
        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
        lambda orig_func, *args, device=None, **kwargs: check_device(device))
-    CondFunc('torch.tensor',
-        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
-        lambda orig_func, *args, device=None, **kwargs: check_device(device))
     CondFunc('torch.linspace',
        lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
        lambda orig_func, *args, device=None, **kwargs: check_device(device))
+    CondFunc('torch.load',
+        lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs:
+        orig_func(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs),
+        lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location))
     CondFunc('torch.Generator',
-        lambda orig_func, device=None: torch.xpu.Generator(device),
+        lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)),
         lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
 
+    # TiledVAE and ControlNet:
     CondFunc('torch.batch_norm',
         lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
             weight if weight is not None else torch.ones(input.size()[1], device=input.device),
@@ -159,38 +173,46 @@ def ipex_hijacks():
             bias if bias is not None else torch.zeros(input.size()[1], device=input.device),
             *args, **kwargs),
         lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
-    #Functions with dtype errors:
+    # Functions with dtype errors:
     CondFunc('torch.nn.modules.GroupNorm.forward',
         lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
         lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
+    # Training:
     CondFunc('torch.nn.modules.linear.Linear.forward',
         lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
         lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
     CondFunc('torch.nn.modules.conv.Conv2d.forward',
         lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
         lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
+    # BF16:
     CondFunc('torch.nn.functional.layer_norm',
         lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
         orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
         lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
         weight is not None and input.dtype != weight.data.dtype)
+    # SwinIR BF16:
+    CondFunc('torch.nn.functional.pad',
+        lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16),
+        lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16)
 
-    #Diffusers Float64 (ARC GPUs doesn't support double or Float64):
+    # Diffusers Float64 (Alchemist GPUs don't support 64 bit):
     if not torch.xpu.has_fp64_dtype():
         CondFunc('torch.from_numpy',
             lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
             lambda orig_func, ndarray: ndarray.dtype == float)
 
-    #Broken functions when torch.cuda.is_available is True:
+    # Broken functions when torch.cuda.is_available is True:
+    # Pin Memory:
     CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
         lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
         lambda orig_func, *args, **kwargs: True)
 
-    #Functions that make compile mad with CondFunc:
+    # Functions that make compile mad with CondFunc:
     torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
     torch.nn.DataParallel = DummyDataParallel
     torch.autocast = ipex_autocast
     torch.cat = torch_cat
     torch.linalg.solve = linalg_solve
+    torch.UntypedStorage.is_cuda = is_cuda
     torch.nn.functional.interpolate = interpolate
     torch.backends.cuda.sdp_kernel = return_null_context
diff --git a/library/original_unet.py b/library/original_unet.py
index 938b0b64..00997e7c 100644
--- a/library/original_unet.py
+++ b/library/original_unet.py
@@ -586,6 +586,9 @@ class CrossAttention(nn.Module):
         self.use_memory_efficient_attention_mem_eff = False
         self.use_sdpa = False
 
+        # Attention processor
+        self.processor = None
+
     def set_use_memory_efficient_attention(self, xformers, mem_eff):
         self.use_memory_efficient_attention_xformers = xformers
         self.use_memory_efficient_attention_mem_eff = mem_eff
@@ -607,7 +610,28 @@ class CrossAttention(nn.Module):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
-    def forward(self, hidden_states, context=None, mask=None):
+    def set_processor(self, processor):
+        self.processor = processor
+
+    def get_processor(self):
+        return self.processor
+
+    def forward(self, hidden_states, context=None, mask=None, **kwargs):
+        if self.processor is not None:
+            (
+                hidden_states,
+                encoder_hidden_states,
+                attention_mask,
+            ) = translate_attention_names_from_diffusers(
+                hidden_states=hidden_states, context=context, mask=mask, **kwargs
+            )
+            return self.processor(
+                attn=self,
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=attention_mask,
+                **kwargs
+            )
         if self.use_memory_efficient_attention_xformers:
             return self.forward_memory_efficient_xformers(hidden_states, context, mask)
         if self.use_memory_efficient_attention_mem_eff:
@@ -720,6 +744,21 @@ class CrossAttention(nn.Module):
         out = self.to_out[0](out)
         return out
 
+def translate_attention_names_from_diffusers(
+    hidden_states: torch.FloatTensor,
+    context: Optional[torch.FloatTensor] = None,
+    mask: Optional[torch.FloatTensor] = None,
+    # HF naming
+    encoder_hidden_states: Optional[torch.FloatTensor] = None,
+    attention_mask: Optional[torch.FloatTensor] = None
+):
+    # translate from hugging face diffusers
+    context = context if context is not None else encoder_hidden_states
+
+    # translate from hugging face diffusers
+    mask = mask if mask is not None else attention_mask
+
+    return hidden_states, context, mask
 
 # feedforward
 class GEGLU(nn.Module):
@@ -1350,7 +1389,7 @@ class UNet2DConditionModel(nn.Module):
         self.out_channels = OUT_CHANNELS
         self.sample_size = sample_size
 
-        self.prepare_config()
+        self.prepare_config(sample_size=sample_size)
 
         # state_dictの書式が変わるのでmoduleの持ち方は変えられない
@@ -1437,8 +1476,8 @@ class UNet2DConditionModel(nn.Module):
         self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
 
     # region diffusers compatibility
-    def prepare_config(self):
-        self.config = SimpleNamespace()
+    def prepare_config(self, *args, **kwargs):
+        self.config = SimpleNamespace(**kwargs)
 
     @property
     def dtype(self) -> torch.dtype:
diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py
index 2f0154ca..a844927c 100644
--- a/library/sdxl_model_util.py
+++ b/library/sdxl_model_util.py
@@ -133,6 +133,12 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
     # logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
     logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
 
+    # temporary workaround for text_projection.weight.weight for Playground-v2
+    if "text_projection.weight.weight" in new_sd:
+        print("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
+        new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
+        del new_sd["text_projection.weight.weight"]
+
     return new_sd, logit_scale
 
 
@@ -258,7 +264,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
             te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
         elif k.startswith("conditioner.embedders.1.model."):
             te2_sd[k] = state_dict.pop(k)
-    
+
     # 一部のposition_idsがないモデルへの対応 / add position_ids for some models
     if "text_model.embeddings.position_ids" not in te1_sd:
         te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py
index f637d993..5ad748d1 100644
--- a/library/sdxl_train_util.py
+++ b/library/sdxl_train_util.py
@@ -51,8 +51,6 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
             torch.cuda.empty_cache()
         accelerator.wait_for_everyone()
 
-    text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
-
     return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
diff --git a/library/train_util.py b/library/train_util.py
index 9fb616ed..2b051e1f 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2665,7 +2665,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         "--optimizer_type",
         type=str,
         default="",
-        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
+        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
     )
 
     # backward compatibility
@@ -2979,6 +2979,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する"
     )
+    parser.add_argument(
+        "--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する"
+    )
     parser.add_argument(
         "--sample_every_n_epochs",
         type=int,
@@ -3384,7 +3387,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):
 
 
 def get_optimizer(args, trainable_params):
-    # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
+    # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
 
     optimizer_type = args.optimizer_type
     if args.use_8bit_adam:
@@ -3488,6 +3491,20 @@ def get_optimizer(args, trainable_params):
 
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
 
+    elif optimizer_type == "PagedAdamW".lower():
+        print(f"use PagedAdamW optimizer | {optimizer_kwargs}")
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
+        try:
+            optimizer_class = bnb.optim.PagedAdamW
+        except AttributeError:
+            raise AttributeError(
+                "No PagedAdamW. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamWが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
+            )
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
     elif optimizer_type == "PagedAdamW32bit".lower():
         print(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}")
         try:
@@ -3897,17 +3914,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
     return text_encoder, vae, unet, load_stable_diffusion_format
 
 
-# TODO remove this function in the future
-def transform_if_model_is_DDP(text_encoder, unet, network=None):
-    # Transform text_encoder, unet and network from DistributedDataParallel
-    return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)
-
-
-def transform_models_if_DDP(models):
-    # Transform text_encoder, unet and network from DistributedDataParallel
-    return [model.module if type(model) == DDP else model for model in models if model is not None]
-
-
 def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
     # load models for each process
     for pi in range(accelerator.state.num_processes):
@@ -3931,8 +3937,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
         torch.cuda.empty_cache()
     accelerator.wait_for_everyone()
 
-    text_encoder, unet = transform_if_model_is_DDP(text_encoder, unet)
-
     return text_encoder, vae, unet, load_stable_diffusion_format
 
 
@@ -4447,11 +4451,118 @@ SCHEDULER_LINEAR_END = 0.0120
 SCHEDULER_TIMESTEPS = 1000
 SCHEDLER_SCHEDULE = "scaled_linear"
 
+
+def get_my_scheduler(
+    *,
+    sample_sampler: str,
+    v_parameterization: bool,
+):
+    sched_init_args = {}
+    if sample_sampler == "ddim":
+        scheduler_cls = DDIMScheduler
+    elif sample_sampler == "ddpm":  # ddpmはおかしくなるのでoptionから外してある
+        scheduler_cls = DDPMScheduler
+    elif sample_sampler == "pndm":
+        scheduler_cls = PNDMScheduler
+    elif sample_sampler == "lms" or sample_sampler == "k_lms":
+        scheduler_cls = LMSDiscreteScheduler
+    elif sample_sampler == "euler" or sample_sampler == "k_euler":
+        scheduler_cls = EulerDiscreteScheduler
+    elif sample_sampler == "euler_a" or sample_sampler == "k_euler_a":
+        scheduler_cls = EulerAncestralDiscreteScheduler
+    elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++":
+        scheduler_cls = DPMSolverMultistepScheduler
+        sched_init_args["algorithm_type"] = sample_sampler
+    elif sample_sampler == "dpmsingle":
+        scheduler_cls = DPMSolverSinglestepScheduler
+    elif sample_sampler == "heun":
+        scheduler_cls = HeunDiscreteScheduler
+    elif sample_sampler == "dpm_2" or sample_sampler == "k_dpm_2":
+        scheduler_cls = KDPM2DiscreteScheduler
+    elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a":
+        scheduler_cls = KDPM2AncestralDiscreteScheduler
+    else:
+        scheduler_cls = DDIMScheduler
+
+    if v_parameterization:
+        sched_init_args["prediction_type"] = "v_prediction"
+
+    scheduler = scheduler_cls(
+        num_train_timesteps=SCHEDULER_TIMESTEPS,
+        beta_start=SCHEDULER_LINEAR_START,
+        beta_end=SCHEDULER_LINEAR_END,
+        beta_schedule=SCHEDLER_SCHEDULE,
+        **sched_init_args,
+    )
+
+    # clip_sample=Trueにする
+    if (
+        hasattr(scheduler.config, "clip_sample")
+        and scheduler.config.clip_sample is False
+    ):
+        # print("set clip_sample to True")
+        scheduler.config.clip_sample = True
+
+    return scheduler
+
 
 def sample_images(*args, **kwargs):
     return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
 
+def line_to_prompt_dict(line: str) -> dict:
+    # subset of gen_img_diffusers
+    prompt_args = line.split(" --")
+    prompt_dict = {}
+    prompt_dict['prompt'] = prompt_args[0]
+
+    for parg in prompt_args:
+        try:
+            m = re.match(r"w (\d+)", parg, re.IGNORECASE)
+            if m:
+                prompt_dict['width'] = int(m.group(1))
+                continue
+
+            m = re.match(r"h (\d+)", parg, re.IGNORECASE)
+            if m:
+                prompt_dict['height'] = int(m.group(1))
+                continue
+
+            m = re.match(r"d (\d+)", parg, re.IGNORECASE)
+            if m:
+                prompt_dict['seed'] = int(m.group(1))
+                continue
+
+            m = re.match(r"s (\d+)", parg, re.IGNORECASE)
+            if m:  # steps
+                prompt_dict['sample_steps'] = max(1, min(1000, int(m.group(1))))
+                continue
+
+            m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
+            if m:  # scale
+                prompt_dict['scale'] = float(m.group(1))
+                continue
+
+            m = re.match(r"n (.+)", parg, re.IGNORECASE)
+            if m:  # negative prompt
+                prompt_dict['negative_prompt'] = m.group(1)
+                continue
+
+            m = re.match(r"ss (.+)", parg, re.IGNORECASE)
+            if m:  # sampler
+                prompt_dict['sample_sampler'] = m.group(1)
+                continue
+
+            m = re.match(r"cn (.+)", parg, re.IGNORECASE)
+            if m:  # controlnet image
+                prompt_dict['controlnet_image'] = m.group(1)
+                continue
+
+        except ValueError as ex:
+            print(f"Exception in parsing / 解析エラー: {parg}")
+            print(ex)
+
+    return prompt_dict
+
 
 def sample_images_common(
     pipe_class,
     accelerator,
@@ -4469,15 +4580,19 @@ def sample_images_common(
     """
     StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
     """
-    if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
-        return
-    if args.sample_every_n_epochs is not None:
-        # sample_every_n_steps は無視する
-        if epoch is None or epoch % args.sample_every_n_epochs != 0:
+    if steps == 0:
+        if not args.sample_at_first:
             return
     else:
-        if steps % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
+        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
             return
+        if args.sample_every_n_epochs is not None:
+            # sample_every_n_steps は無視する
+            if epoch is None or epoch % args.sample_every_n_epochs != 0:
+                return
+        else:
+            if steps % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
+                return
 
     print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
     if not os.path.isfile(args.sample_prompts):
@@ -4504,56 +4619,19 @@ def sample_images_common(
         with open(args.sample_prompts, "r", encoding="utf-8") as f:
             prompts = json.load(f)
 
-    # schedulerを用意する
-    sched_init_args = {}
-    if args.sample_sampler == "ddim":
-        scheduler_cls = DDIMScheduler
-    elif args.sample_sampler == "ddpm":  # ddpmはおかしくなるのでoptionから外してある
-        scheduler_cls = DDPMScheduler
-    elif args.sample_sampler == "pndm":
-        scheduler_cls = PNDMScheduler
-    elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms":
-        scheduler_cls = LMSDiscreteScheduler
-    elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler":
-        scheduler_cls = EulerDiscreteScheduler
-    elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a":
-        scheduler_cls = EulerAncestralDiscreteScheduler
-    elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
-        scheduler_cls = DPMSolverMultistepScheduler
-        sched_init_args["algorithm_type"] = args.sample_sampler
-    elif args.sample_sampler == "dpmsingle":
-        scheduler_cls = DPMSolverSinglestepScheduler
-    elif args.sample_sampler == "heun":
-        scheduler_cls = HeunDiscreteScheduler
-    elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2":
-        scheduler_cls = KDPM2DiscreteScheduler
-    elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a":
-        scheduler_cls = KDPM2AncestralDiscreteScheduler
-    else:
-        scheduler_cls = DDIMScheduler
-
-    if args.v_parameterization:
-        sched_init_args["prediction_type"] = "v_prediction"
-
-    scheduler = scheduler_cls(
-        num_train_timesteps=SCHEDULER_TIMESTEPS,
-        beta_start=SCHEDULER_LINEAR_START,
-        beta_end=SCHEDULER_LINEAR_END,
-        beta_schedule=SCHEDLER_SCHEDULE,
-        **sched_init_args,
+    schedulers: dict = {}
+    default_scheduler = get_my_scheduler(
+        sample_sampler=args.sample_sampler,
+        v_parameterization=args.v_parameterization,
     )
-
-    # clip_sample=Trueにする
-    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
-        # print("set clip_sample to True")
-        scheduler.config.clip_sample = True
+    schedulers[args.sample_sampler] = default_scheduler
 
     pipeline = pipe_class(
         text_encoder=text_encoder,
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        scheduler=scheduler,
+        scheduler=default_scheduler,
         safety_checker=None,
         feature_extractor=None,
         requires_safety_checker=False,
@@ -4569,78 +4647,34 @@ def sample_images_common(
 
     with torch.no_grad():
         # with accelerator.autocast():
-        for i, prompt in enumerate(prompts):
+        for i, prompt_dict in enumerate(prompts):
             if not accelerator.is_main_process:
                 continue
 
-            if isinstance(prompt, dict):
-                negative_prompt = prompt.get("negative_prompt")
-                sample_steps = prompt.get("sample_steps", 30)
-                width = prompt.get("width", 512)
-                height = prompt.get("height", 512)
-                scale = prompt.get("scale", 7.5)
-                seed = prompt.get("seed")
-                controlnet_image = prompt.get("controlnet_image")
-                prompt = prompt.get("prompt")
-            else:
-                # prompt = prompt.strip()
-                # if len(prompt) == 0 or prompt[0] == "#":
-                #     continue
+            if isinstance(prompt_dict, str):
+                prompt_dict = line_to_prompt_dict(prompt_dict)
 
-                # subset of gen_img_diffusers
-                prompt_args = prompt.split(" --")
-                prompt = prompt_args[0]
-                negative_prompt = None
-                sample_steps = 30
-                width = height = 512
-                scale = 7.5
-                seed = None
-                controlnet_image = None
-                for parg in prompt_args:
-                    try:
-                        m = re.match(r"w (\d+)", parg, re.IGNORECASE)
-                        if m:
-                            width = int(m.group(1))
-                            continue
-
-                        m = re.match(r"h (\d+)", parg, re.IGNORECASE)
-                        if m:
-                            height = int(m.group(1))
-                            continue
-
-                        m = re.match(r"d (\d+)", parg, re.IGNORECASE)
-                        if m:
-                            seed = int(m.group(1))
-                            continue
-
-                        m = re.match(r"s (\d+)", parg, re.IGNORECASE)
-                        if m:  # steps
-                            sample_steps = max(1, min(1000, int(m.group(1))))
-                            continue
-
-                        m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
-                        if m:  # scale
-                            scale = float(m.group(1))
-                            continue
-
-                        m = re.match(r"n (.+)", parg, re.IGNORECASE)
-                        if m:  # negative prompt
-                            negative_prompt = m.group(1)
-                            continue
-
-                        m = re.match(r"cn (.+)", parg, re.IGNORECASE)
-                        if m:  # negative prompt
-                            controlnet_image = m.group(1)
-                            continue
-
-                    except ValueError as ex:
-                        print(f"Exception in parsing / 解析エラー: {parg}")
-                        print(ex)
+            assert isinstance(prompt_dict, dict)
+            negative_prompt = prompt_dict.get("negative_prompt")
+            sample_steps = prompt_dict.get("sample_steps", 30)
+            width = prompt_dict.get("width", 512)
+            height = prompt_dict.get("height", 512)
+            scale = prompt_dict.get("scale", 7.5)
+            seed = prompt_dict.get("seed")
+            controlnet_image = prompt_dict.get("controlnet_image")
+            prompt: str = prompt_dict.get("prompt", "")
+            sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
 
             if seed is not None:
                 torch.manual_seed(seed)
                 torch.cuda.manual_seed(seed)
 
+            scheduler = schedulers.get(sampler_name)
+            if scheduler is None:
+                scheduler = get_my_scheduler(sample_sampler=sampler_name, v_parameterization=args.v_parameterization)
+                schedulers[sampler_name] = scheduler
+            pipeline.scheduler = scheduler
+
             if prompt_replacement is not None:
                 prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
                 if negative_prompt is not None:
@@ -4658,6 +4692,9 @@ def sample_images_common(
             print(f"width: {width}")
             print(f"sample_steps: {sample_steps}")
             print(f"scale: {scale}")
+            print(f"sample_sampler: {sampler_name}")
+            if seed is not None:
+                print(f"seed: {seed}")
             with accelerator.autocast():
                 latents = pipeline(
                     prompt=prompt,
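For reference, when `--sample_prompts` points to a `.txt` file, each line is parsed by `line_to_prompt_dict` above using `gen_img_diffusers`-style switches, including the newly supported `--ss` sampler override. A hypothetical prompt line (all values are made up):

```
masterpiece, best quality, 1girl standing in a garden --n lowres, bad anatomy --w 768 --h 768 --d 1 --s 28 --l 7.5 --ss euler_a
```

Switches that match none of the regexes simply fall through, and a `.json` file containing a list of dicts with the same keys (`prompt`, `negative_prompt`, `width`, and so on) is also accepted by `sample_images_common`.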
diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py
index bfe2e512..29726821 100755
--- a/sdxl_gen_img.py
+++ b/sdxl_gen_img.py
@@ -515,7 +515,8 @@ class PipelineLike:
             uncond_embeddings = tes_uncond_embs[0]
             for i in range(1, len(tes_text_embs)):
                 text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2)  # n,77,2048
-                uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2)  # n,77,2048
+                if do_classifier_free_guidance:
+                    uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2)  # n,77,2048
 
         if do_classifier_free_guidance:
             if negative_scale is None:
@@ -578,9 +579,11 @@ class PipelineLike:
                 text_pool = clip_vision_embeddings  # replace: same as ComfyUI (?)
 
         c_vector = torch.cat([text_pool, c_vector], dim=1)
-        uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
-
-        vector_embeddings = torch.cat([uc_vector, c_vector])
+        if do_classifier_free_guidance:
+            uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
+            vector_embeddings = torch.cat([uc_vector, c_vector])
+        else:
+            vector_embeddings = c_vector
 
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps, self.device)
diff --git a/sdxl_train.py b/sdxl_train.py
index fd775624..05ad0878 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -397,13 +397,10 @@ def train(args):
     # acceleratorがなんかよろしくやってくれるらしい
     if train_unet:
         unet = accelerator.prepare(unet)
-        (unet,) = train_util.transform_models_if_DDP([unet])
     if train_text_encoder1:
         text_encoder1 = accelerator.prepare(text_encoder1)
-        (text_encoder1,) = train_util.transform_models_if_DDP([text_encoder1])
     if train_text_encoder2:
         text_encoder2 = accelerator.prepare(text_encoder2)
-        (text_encoder2,) = train_util.transform_models_if_DDP([text_encoder2])
 
     optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
 
@@ -461,6 +458,11 @@ def train(args):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
+    # For --sample_at_first
+    sdxl_train_util.sample_images(
+        accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
+    )
+
     loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 44447d1f..cb97859f 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -283,9 +283,6 @@ def train(args):
     # acceleratorがなんかよろしくやってくれるらしい
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
 
-    # transform DDP after prepare (train_network here only)
-    unet = train_util.transform_models_if_DDP([unet])[0]
-
    if args.gradient_checkpointing:
         unet.train()  # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
     else:
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 91cbacc6..87f30301 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -254,9 +254,6 @@ def train(args):
     )
     network: control_net_lllite.ControlNetLLLite
 
-    # transform DDP after prepare (train_network here only)
-    unet, network = train_util.transform_models_if_DDP([unet, network])
-
     if args.gradient_checkpointing:
         unet.train()  # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
     else:
diff --git a/train_controlnet.py b/train_controlnet.py
index e0118d1c..1f3dbae3 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -11,10 +11,13 @@ import toml
 from tqdm import tqdm
 
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -335,7 +338,9 @@ def train(args):
         init_kwargs = {}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
+        accelerator.init_trackers(
+            "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
+        )
 
     loss_recorder = train_util.LossRecorder()
     del train_dataset_group
@@ -371,6 +376,11 @@ def train(args):
                 accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
                 os.remove(old_ckpt_file)
 
+    # For --sample_at_first
+    train_util.sample_images(
+        accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
+    )
+
     # training loop
     for epoch in range(num_train_epochs):
         if is_main_process:
diff --git a/train_db.py b/train_db.py
index 966999df..5518740f 100644
--- a/train_db.py
+++ b/train_db.py
@@ -112,6 +112,7 @@ def train(args):
 
     # mixed precisionに対応した型を用意しておき適宜castする
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
+    vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
 
     # モデルを読み込む
     text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -136,7 +137,7 @@ def train(args):
 
     # 学習を準備する
     if cache_latents:
-        vae.to(accelerator.device, dtype=weight_dtype)
+        vae.to(accelerator.device, dtype=vae_dtype)
         vae.requires_grad_(False)
         vae.eval()
         with torch.no_grad():
@@ -225,9 +226,6 @@ def train(args):
     else:
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
 
-    # transform DDP after prepare
-    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
-
     if not train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)  # to avoid 'cpu' vs 'cuda' error
 
@@ -274,6 +272,9 @@ def train(args):
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
 
+    # For --sample_at_first
+    train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
     loss_recorder = train_util.LossRecorder()
     for epoch in range(num_train_epochs):
         accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -482,6 +483,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
     )
+    parser.add_argument(
+        "--no_half_vae",
+        action="store_true",
+        help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
+    )
 
     return parser
diff --git a/train_network.py b/train_network.py
index 1cbed2e7..378a3390 100644
--- a/train_network.py
+++ b/train_network.py
@@ -12,6 +12,7 @@ import toml
 from tqdm import tqdm
 
 import torch
+from torch.nn.parallel import DistributedDataParallel as DDP
 
 try:
     import intel_extension_for_pytorch as ipex
@@ -127,6 +128,11 @@ class NetworkTrainer:
         noise_pred = unet(noisy_latents, timesteps, text_conds).sample
         return noise_pred
 
+    def all_reduce_network(self, accelerator, network):
+        for param in network.parameters():
+            if param.grad is not None:
+                param.grad = accelerator.reduce(param.grad, reduction="mean")
+
     def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
         train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
 
@@ -390,46 +396,20 @@ class NetworkTrainer:
 
         # acceleratorがなんかよろしくやってくれるらしい
         # TODO めちゃくちゃ冗長なのでコードを整理する
-        if train_unet and train_text_encoder:
+        if train_unet:
+            unet = accelerator.prepare(unet)
+        else:
+            unet.to(accelerator.device, dtype=weight_dtype)  # move to device because unet is not prepared by accelerator
+        if train_text_encoder:
             if len(text_encoders) > 1:
-                unet, t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                    unet, text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
-                )
-                text_encoder = text_encoders = [t_enc1, t_enc2]
-                del t_enc1, t_enc2
+                text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
             else:
-                unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                    unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler
-                )
+                text_encoder = accelerator.prepare(text_encoder)
                 text_encoders = [text_encoder]
-        elif train_unet:
-            unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                unet, network, optimizer, train_dataloader, lr_scheduler
-            )
+        else:
             for t_enc in text_encoders:
                 t_enc.to(accelerator.device, dtype=weight_dtype)
-        elif train_text_encoder:
-            if len(text_encoders) > 1:
-                t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                    text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
-                )
-                text_encoder = text_encoders = [t_enc1, t_enc2]
-                del t_enc1, t_enc2
-            else:
-                text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                    text_encoder, network, optimizer, train_dataloader, lr_scheduler
-                )
-                text_encoders = [text_encoder]
-
-            unet.to(accelerator.device, dtype=weight_dtype)  # move to device because unet is not prepared by accelerator
-        else:
-            network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                network, optimizer, train_dataloader, lr_scheduler
-            )
-
-        # transform DDP after prepare (train_network here only)
-        text_encoders = train_util.transform_models_if_DDP(text_encoders)
-        unet, network = train_util.transform_models_if_DDP([unet, network])
+
+        network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
 
         if args.gradient_checkpointing:
             # according to TI example in Diffusers, train is required
@@ -451,7 +431,7 @@ class NetworkTrainer:
 
                 del t_enc
 
-            network.prepare_grad_etc(text_encoder, unet)
+            accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
 
         if not cache_latents:  # キャッシュしない場合はVAEを使うのでVAEを準備する
             vae.requires_grad_(False)
@@ -714,8 +694,8 @@ class NetworkTrainer:
         del train_dataset_group
 
         # callback for step start
-        if hasattr(network, "on_step_start"):
-            on_step_start = network.on_step_start
+        if hasattr(accelerator.unwrap_model(network), "on_step_start"):
+            on_step_start = accelerator.unwrap_model(network).on_step_start
         else:
             on_step_start = lambda *args, **kwargs: None
@@ -743,6 +723,9 @@ class NetworkTrainer:
                     accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
                     os.remove(old_ckpt_file)
 
+        # For --sample_at_first
+        self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
+
         # training loop
         for epoch in range(num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -750,7 +733,7 @@ class NetworkTrainer:
 
             metadata["ss_epoch"] = str(epoch + 1)
 
-            network.on_epoch_start(text_encoder, unet)
+            accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
 
             for step, batch in enumerate(train_dataloader):
                 current_step.value = global_step
@@ -823,8 +806,9 @@ class NetworkTrainer:
                         loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
 
                     accelerator.backward(loss)
+                    self.all_reduce_network(accelerator, network)  # sync DDP grad manually
                     if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-                        params_to_clip = network.get_trainable_params()
+                        params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
                         accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                     optimizer.step()
@@ -832,7 +816,7 @@ class NetworkTrainer:
                     optimizer.zero_grad(set_to_none=True)
 
                 if args.scale_weight_norms:
-                    keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization(
+                    keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
                         args.scale_weight_norms, accelerator.device
                     )
                     max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 45a437b9..877ac838 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -7,10 +7,13 @@ import toml
 from tqdm import tqdm
 
 import torch
+
 try:
     import intel_extension_for_pytorch as ipex
+
     if torch.xpu.is_available():
         from library.ipex import ipex_init
+
         ipex_init()
 except Exception:
     pass
@@ -415,15 +418,11 @@ class TextualInversionTrainer:
             text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                 text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
             )
-            # transform DDP after prepare
-            text_encoder_or_list, unet = train_util.transform_if_model_is_DDP(text_encoder_or_list, unet)
 
         elif len(text_encoders) == 2:
             text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                 text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
             )
-            # transform DDP after prepare
-            text_encoder1, text_encoder2, unet = train_util.transform_if_model_is_DDP(text_encoder1, text_encoder2, unet)
 
             text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
 
@@ -529,6 +528,20 @@ class TextualInversionTrainer:
                     accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
                     os.remove(old_ckpt_file)
 
+        # For --sample_at_first
+        self.sample_images(
+            accelerator,
+            args,
+            0,
+            global_step,
+            accelerator.device,
+            vae,
+            tokenizer_or_list,
+            text_encoder_or_list,
+            unet,
+            prompt_replacement,
+        )
+
         # training loop
         for epoch in range(num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index f77ad2eb..42d69d2d 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -333,9 +333,6 @@ def train(args):
         text_encoder, optimizer, train_dataloader, lr_scheduler
     )
 
-    # transform DDP after prepare
-    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
-
     index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
     # print(len(index_no_updates), torch.sum(index_no_updates))
     orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
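Finally, the `PagedAdamW` type registered in `get_optimizer` above is selected like the other bitsandbytes optimizers. A hedged example (the script choice and the omitted training arguments are illustrative):

```
# requires bitsandbytes 0.39.0 or later
accelerate launch sdxl_train.py --optimizer_type PagedAdamW --learning_rate 1e-5
```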