mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
feat: add block swap for FLUX.1/SD3 LoRA training
This commit is contained in:
@@ -51,6 +51,10 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
if args.max_token_length is not None:
|
||||
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
|
||||
|
||||
assert (
|
||||
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||
|
||||
# enumerate resolutions from dataset for positional embeddings
|
||||
@@ -83,6 +87,17 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
raise ValueError(f"Unsupported fp8 model dtype: {mmdit.dtype}")
|
||||
elif mmdit.dtype == torch.float8_e4m3fn:
|
||||
logger.info("Loaded fp8 SD3 model")
|
||||
else:
|
||||
logger.info(
|
||||
"Cast SD3 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
|
||||
" / SD3モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
|
||||
)
|
||||
mmdit.to(torch.float8_e4m3fn)
|
||||
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||
if self.is_swapping_blocks:
|
||||
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
|
||||
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||
mmdit.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||
|
||||
clip_l = sd3_utils.load_clip_l(
|
||||
args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
|
||||
@@ -432,9 +447,24 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
|
||||
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
|
||||
|
||||
def prepare_unet_with_accelerator(
    self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
    """Prepare the model with the accelerator, honoring block swap.

    When block swap is disabled this defers entirely to the parent
    implementation. When it is enabled, the accelerator's automatic device
    placement is switched off and the model moves its own modules instead.

    Args:
        args: Parsed training arguments; only forwarded to the parent
            implementation here.
        accelerator: Accelerator used to prepare and unwrap the model.
        unet: Model to prepare; treated as an ``sd3_models.MMDiT`` on the
            block swap path.

    Returns:
        The model returned by ``accelerator.prepare`` on the block swap path,
        otherwise whatever the parent implementation returns.
    """
    if self.is_swapping_blocks:
        # Block swap is active on this path, so automatic device placement is
        # always turned off for the single prepared object.
        prepared: sd3_models.MMDiT = accelerator.prepare(unet, device_placement=[False])

        # Move everything except the swap blocks to the device (reduces peak
        # memory usage), then get the swap blocks ready for the first forward.
        accelerator.unwrap_model(prepared).move_to_device_except_swap_blocks(accelerator.device)
        accelerator.unwrap_model(prepared).prepare_block_swap_before_forward()

        return prepared

    # No block swap: the default preparation is sufficient.
    return super().prepare_unet_with_accelerator(args, accelerator, unet)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this training script.

    Starts from the parser created by ``train_network.setup_parser()`` and then
    passes it to ``train_util.add_dit_training_arguments`` and
    ``sd3_train_utils.add_sd3_training_arguments``, in that order.

    Returns:
        The parser object after both argument groups have been registered.
    """
    sd3_parser = train_network.setup_parser()

    # Registration order is preserved: DiT arguments first, then SD3 arguments.
    train_util.add_dit_training_arguments(sd3_parser)
    sd3_train_utils.add_sd3_training_arguments(sd3_parser)

    return sd3_parser
|
||||
|
||||
|
||||
Reference in New Issue
Block a user