Fix bug in FLUX multi-GPU training

commit 98c91a7625 (parent e1cd19c0c0)
Author: kohya-ss
Date: 2024-08-22 12:37:41 +09:00
8 changed files with 156 additions and 38 deletions


@@ -174,7 +174,7 @@ def train(args):
     # load VAE for caching latents
     ae = None
     if cache_latents:
-        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")
+        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
         ae.to(accelerator.device, dtype=weight_dtype)
         ae.requires_grad_(False)
         ae.eval()
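
The new trailing argument threads args.disable_mmap_load_safetensors through each model loader. A minimal sketch of how a loader can honor such a flag, assuming a hypothetical helper named load_safetensors (the repo's actual helper may differ):

    from safetensors.torch import load, load_file

    def load_safetensors(path: str, device: str, disable_mmap: bool = False):
        if disable_mmap:
            # Read the whole file into memory and parse the bytes directly,
            # avoiding the memory-mapped read that load_file() performs.
            # Tensors land on CPU here; move them to `device` afterwards.
            with open(path, "rb") as f:
                return load(f.read())
        return load_file(path, device=device)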
@@ -199,8 +199,8 @@ def train(args):
     strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy)

     # load clip_l, t5xxl for caching text encoder outputs
-    clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu")
-    t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu")
+    clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
+    t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
     clip_l.eval()
     t5xxl.eval()
     clip_l.requires_grad_(False)
@@ -228,7 +228,6 @@ def train(args):
         if args.sample_prompts is not None:
             logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
-            tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
             text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
             prompts = load_prompts(args.sample_prompts)
@@ -238,9 +237,9 @@ def train(args):
                 for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                     if p not in sample_prompts_te_outputs:
                         logger.info(f"cache Text Encoder outputs for prompt: {p}")
-                        tokens_and_masks = tokenize_strategy.tokenize(p)
+                        tokens_and_masks = flux_tokenize_strategy.tokenize(p)
                         sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
-                            tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
+                            flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
                         )

         accelerator.wait_for_everyone()
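
The tokenize_strategy -> flux_tokenize_strategy renames swap lookups through the class-level strategy registry for the local variable assigned a few lines earlier. Judging from the set_strategy/get_strategy calls visible in this diff, the registry is per-process class state, roughly like the sketch below; with accelerate launching one process per GPU, using the local instance removes any dependence on that state being populated in the current process:

    class TokenizeStrategy:
        _strategy = None  # class-level slot; each spawned process has its own copy

        @classmethod
        def set_strategy(cls, strategy):
            cls._strategy = strategy

        @classmethod
        def get_strategy(cls):
            return cls._strategy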
@@ -251,7 +250,9 @@ def train(args):
         clean_memory_on_device(accelerator.device)

     # load FLUX
-    flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
+    flux = flux_utils.load_flow_model(
+        name, args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
+    )

     if args.gradient_checkpointing:
         flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing)
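
For reference, args.disable_mmap_load_safetensors is presumably a plain boolean switch; a sketch of the argparse definition (the help text is an assumption, not copied from the repo):

    parser.add_argument(
        "--disable_mmap_load_safetensors",
        action="store_true",
        help="disable memory-mapped loading of .safetensors files",
    )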
@@ -419,7 +420,7 @@ def train(args):
     # if we don't swap blocks, we can move the model to device
     flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks])
     if is_swapping_blocks:
-        flux.move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
+        accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage

     optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

     # experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scale in fp16
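
Why accelerator.unwrap_model is needed here: with more than one GPU, accelerator.prepare(flux) returns the model wrapped in DistributedDataParallel, and custom methods such as move_to_device_except_swap_blocks exist only on the inner module, not on the wrapper. A sketch of roughly what the unwrap does:

    from torch.nn.parallel import DistributedDataParallel as DDP

    def unwrap(model):
        # DDP keeps the original model in .module; accelerator.unwrap_model()
        # peels off this (and similar) wrappers before attribute access
        return model.module if isinstance(model, DDP) else model

    unwrap(flux).move_to_device_except_swap_blocks(accelerator.device)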
@@ -439,8 +440,8 @@ def train(args):
         double_blocks_to_swap = args.double_blocks_to_swap
         single_blocks_to_swap = args.single_blocks_to_swap
-        num_double_blocks = len(flux.double_blocks)
-        num_single_blocks = len(flux.single_blocks)
+        num_double_blocks = 19  # len(flux.double_blocks)
+        num_single_blocks = 38  # len(flux.single_blocks)
         handled_double_block_indices = set()
         handled_single_block_indices = set()
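
Hardcoding 19 and 38 sidesteps the same wrapper problem: after accelerator.prepare, flux.double_blocks is not reachable through the DDP wrapper. The numbers match the FLUX.1 architecture (19 double-stream and 38 single-stream blocks). A non-hardcoded alternative, assuming the unwrap is acceptable at this point:

    num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
    num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)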
@@ -537,8 +538,8 @@ def train(args):
         double_blocks_to_swap = args.double_blocks_to_swap
         single_blocks_to_swap = args.single_blocks_to_swap
-        num_double_blocks = len(flux.double_blocks)
-        num_single_blocks = len(flux.single_blocks)
+        num_double_blocks = 19  # len(flux.double_blocks)
+        num_single_blocks = 38  # len(flux.single_blocks)

         for opt_idx, optimizer in enumerate(optimizers):
             for param_group in optimizer.param_groups:
@@ -618,7 +619,7 @@ def train(args):
     )

     if is_swapping_blocks:
-        flux.prepare_block_swap_before_forward()
+        accelerator.unwrap_model(flux).prepare_block_swap_before_forward()

     # For --sample_at_first
     flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
@@ -660,7 +661,7 @@ def train(args):
                 with torch.no_grad():
                     input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
                     text_encoder_conds = text_encoding_strategy.encode_tokens(
-                        tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
+                        flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
                     )
                     if args.full_fp16:
                         text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]