Mirror of https://github.com/kohya-ss/sd-scripts.git
feat: implement block swapping for FLUX.1 LoRA (WIP)
@@ -18,6 +18,7 @@ from library.device_utils import init_ipex, clean_memory_on_device
 
 init_ipex()
 
 from accelerate.utils import set_seed
+from accelerate import Accelerator
 from diffusers import DDPMScheduler
 from library import deepspeed_utils, model_util, strategy_base, strategy_sd
@@ -272,6 +273,11 @@ class NetworkTrainer:
     def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
         text_encoder.text_model.embeddings.to(dtype=weight_dtype)
 
+    def prepare_unet_with_accelerator(
+        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
+    ) -> torch.nn.Module:
+        return accelerator.prepare(unet)
+
     def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
         pass
 
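The new prepare_unet_with_accelerator hook exists so a model-specific trainer can replace the plain accelerator.prepare(unet) call. For block swapping, a subclass can wrap the model without letting accelerate place it wholesale on the GPU, then move only the non-swappable parts to the device. A minimal sketch of such an override, assuming a hypothetical FluxNetworkTrainer subclass, a blocks_to_swap argument, and a model-side move_to_device_except_swap_blocks helper (all illustrative names, not confirmed by this commit):

import argparse

import torch
from accelerate import Accelerator

from train_network import NetworkTrainer


# Hypothetical subclass: the class name, the args.blocks_to_swap flag, and the
# move_to_device_except_swap_blocks helper are illustrative assumptions.
class FluxNetworkTrainer(NetworkTrainer):
    def prepare_unet_with_accelerator(
        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
    ) -> torch.nn.Module:
        if not getattr(args, "blocks_to_swap", None):
            # No block swapping requested: keep the default accelerator.prepare path.
            return super().prepare_unet_with_accelerator(args, accelerator, unet)

        # Wrap the model without letting accelerate move it to the GPU;
        # device_placement takes one boolean per prepared object.
        unet = accelerator.prepare(unet, device_placement=[False])
        # Move everything except the swappable blocks; those stay on CPU until
        # the model streams them in around each forward/backward pass.
        accelerator.unwrap_model(unet).move_to_device_except_swap_blocks(accelerator.device)
        return unet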
@@ -627,7 +633,8 @@ class NetworkTrainer:
             training_model = ds_model
         else:
             if train_unet:
-                unet = accelerator.prepare(unet)
+                # default implementation is:  unet = accelerator.prepare(unet)
+                unet = self.prepare_unet_with_accelerator(args, accelerator, unet)  # accelerator does some magic here
             else:
                 unet.to(accelerator.device, dtype=unet_weight_dtype)  # move to device because unet is not prepared by accelerator
                 if train_text_encoder:
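At the call site, the hook replaces the unconditional accelerator.prepare(unet), which would otherwise move the whole model to the device. The trade block swapping makes is VRAM for transfer bandwidth: only a subset of transformer blocks is resident on the GPU at any moment. An inference-only sketch of that core idea (the function name and naive loop are illustrative; a real trainer overlaps copies with compute on separate CUDA streams and must also re-fetch weights for the backward pass):

import torch
import torch.nn as nn


@torch.no_grad()
def forward_with_block_swap(blocks: nn.ModuleList, x: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Illustrative only: stream each transformer block through the GPU one at a
    # time, so at most one block's weights occupy VRAM during this loop.
    for block in blocks:
        block.to(device)   # fetch weights into VRAM
        x = block(x)       # compute on the GPU
        block.to("cpu")    # evict to make room for the next block
    return x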