feat: implement block swapping for FLUX.1 LoRA (WIP)

Kohya S
2024-11-12 08:49:05 +09:00
parent 7feaae5f06
commit cde90b8903
5 changed files with 87 additions and 5 deletions


@@ -18,6 +18,7 @@ from library.device_utils import init_ipex, clean_memory_on_device

 init_ipex()

 from accelerate.utils import set_seed
+from accelerate import Accelerator
 from diffusers import DDPMScheduler
 from library import deepspeed_utils, model_util, strategy_base, strategy_sd
@@ -272,6 +273,11 @@ class NetworkTrainer:
     def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
         text_encoder.text_model.embeddings.to(dtype=weight_dtype)

+    def prepare_unet_with_accelerator(
+        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
+    ) -> torch.nn.Module:
+        return accelerator.prepare(unet)
+
     def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
         pass
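
The new prepare_unet_with_accelerator hook gives subclasses a single place to decide how the model passed as unet (for FLUX.1, the transformer) is handed to Accelerate; the base class keeps the previous behavior by simply calling accelerator.prepare. Below is a minimal sketch of how a FLUX-specific trainer could override it to support block swapping. The is_swapping_blocks flag and the move_to_device_except_swap_blocks / prepare_block_swap_before_forward helpers are hypothetical stand-ins, not APIs confirmed by this diff; only accelerator.prepare(..., device_placement=[False]) and accelerator.unwrap_model are real Accelerate calls.

    import argparse
    import torch
    from accelerate import Accelerator

    class FluxNetworkTrainerSketch(NetworkTrainer):  # NetworkTrainer as defined in this file
        def prepare_unet_with_accelerator(
            self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
        ) -> torch.nn.Module:
            if not getattr(self, "is_swapping_blocks", False):
                # no block swapping requested: keep the default behavior
                return super().prepare_unet_with_accelerator(args, accelerator, unet)

            # let Accelerate wrap the model but skip its automatic device placement,
            # so blocks selected for swapping can stay on CPU until they are needed
            model = accelerator.prepare(unet, device_placement=[False])
            unwrapped = accelerator.unwrap_model(model)
            unwrapped.move_to_device_except_swap_blocks(accelerator.device)  # hypothetical helper
            unwrapped.prepare_block_swap_before_forward()  # hypothetical helper
            return model

Skipping device placement here is the point of the hook: only the blocks that are actually resident occupy VRAM, while the rest of the weights stay in host memory.
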
@@ -627,7 +633,8 @@ class NetworkTrainer:
             training_model = ds_model
         else:
             if train_unet:
-                unet = accelerator.prepare(unet)
+                # default implementation is: unet = accelerator.prepare(unet)
+                unet = self.prepare_unet_with_accelerator(args, accelerator, unet)  # accelerator does some magic here
             else:
                 unet.to(accelerator.device, dtype=unet_weight_dtype)  # move to device because unet is not prepared by accelerator
             if train_text_encoder:
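
For context on the commit title, the sketch below illustrates the general block-swapping idea in isolation: keep most transformer blocks on CPU and move each block to the GPU just before it runs, evicting an older one to bound peak VRAM. All names here are illustrative and this is not the repository's implementation, which would also need to overlap transfers with compute and keep blocks available for the backward pass.

    import torch

    def install_block_swap_hooks(blocks: torch.nn.ModuleList, device: torch.device, max_resident: int = 2):
        """Naive block swapping via forward pre-hooks (illustration only)."""
        resident = []  # blocks currently kept on the GPU, oldest first

        def make_pre_hook(block: torch.nn.Module):
            def pre_hook(module, inputs):
                if block not in resident:
                    if len(resident) >= max_resident:
                        resident.pop(0).to("cpu")  # evict the oldest resident block
                    block.to(device)
                    resident.append(block)
                return None  # leave the forward inputs unchanged
            return pre_hook

        for blk in blocks:
            blk.to("cpu")  # start with every block off the GPU
            blk.register_forward_pre_hook(make_pre_hook(blk))
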