diff --git a/flux_train.py b/flux_train.py index ecb3c7dd..b294ce42 100644 --- a/flux_train.py +++ b/flux_train.py @@ -251,7 +251,6 @@ def train(args): clean_memory_on_device(accelerator.device) # load FLUX - # if we load to cpu, flux.to(fp8) takes a long time flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu") if args.gradient_checkpointing: @@ -259,7 +258,8 @@ def train(args): flux.requires_grad_(True) - if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None: + is_swapping_blocks = args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None + if is_swapping_blocks: # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes. # This idea is based on 2kpr's great work. Thank you! logger.info( @@ -412,8 +412,11 @@ def train(args): training_models = [ds_model] else: - # acceleratorがなんかよろしくやってくれるらしい - flux = accelerator.prepare(flux) + # accelerator does some magic + # if we doesn't swap blocks, we can move the model to device + flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks]) + if is_swapping_blocks: + flux.move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする @@ -539,7 +542,7 @@ def train(args): init_kwargs=init_kwargs, ) - if args.double_blocks_to_swap is not None or args.single_blocks_to_swap is not None: + if is_swapping_blocks: flux.prepare_block_swap_before_forward() # For --sample_at_first @@ -595,7 +598,7 @@ def train(args): # get noisy model input and timesteps noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, accelerator.device, weight_dtype + args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype ) # pack latents and get img_ids diff --git a/library/flux_models.py b/library/flux_models.py index 3f44068f..11ef647a 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -953,6 +953,22 @@ class Flux(nn.Module): self.double_blocks_to_swap = double_blocks self.single_blocks_to_swap = single_blocks + def move_to_device_except_swap_blocks(self, device: torch.device): + # assume model is on cpu + if self.double_blocks_to_swap: + save_double_blocks = self.double_blocks + self.double_blocks = None + if self.single_blocks_to_swap: + save_single_blocks = self.single_blocks + self.single_blocks = None + + self.to(device) + + if self.double_blocks_to_swap: + self.double_blocks = save_double_blocks + if self.single_blocks_to_swap: + self.single_blocks = save_single_blocks + def prepare_block_swap_before_forward(self): # move last n blocks to cpu: they are on cuda if self.double_blocks_to_swap: