Mirror of https://github.com/kohya-ss/sd-scripts.git
Fix bug in FLUX multi GPU training
@@ -174,7 +174,7 @@ def train(args):
     # load VAE for caching latents
     ae = None
     if cache_latents:
-        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")
+        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
         ae.to(accelerator.device, dtype=weight_dtype)
         ae.requires_grad_(False)
         ae.eval()
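
Note: the only functional change in this hunk is that args.disable_mmap_load_safetensors is now passed through to the AE loader; the CLIP-L, T5-XXL and FLUX loads below get the same treatment. As a rough illustration of what such a flag usually controls, here is a minimal sketch of an mmap-toggling safetensors loader. It is illustrative only, not the flux_utils implementation; only the safetensors calls are real API.

from safetensors.torch import load, load_file


def load_state_dict(path: str, disable_mmap: bool = False):
    # Hypothetical helper for illustration, not part of sd-scripts.
    if disable_mmap:
        # Read the whole checkpoint into memory and deserialize from bytes,
        # bypassing memory-mapped loading (can help on some filesystems and
        # in multi-process runs).
        with open(path, "rb") as f:
            return load(f.read())
    # Default path: memory-mapped loading.
    return load_file(path)
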
@@ -199,8 +199,8 @@ def train(args):
     strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy)

     # load clip_l, t5xxl for caching text encoder outputs
-    clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu")
-    t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu")
+    clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
+    t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
     clip_l.eval()
     t5xxl.eval()
     clip_l.requires_grad_(False)
@@ -228,7 +228,6 @@ def train(args):
         if args.sample_prompts is not None:
             logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")

-            tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
             text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()

             prompts = load_prompts(args.sample_prompts)
@@ -238,9 +237,9 @@ def train(args):
                     for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                         if p not in sample_prompts_te_outputs:
                             logger.info(f"cache Text Encoder outputs for prompt: {p}")
-                            tokens_and_masks = tokenize_strategy.tokenize(p)
+                            tokens_and_masks = flux_tokenize_strategy.tokenize(p)
                             sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
-                                tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
+                                flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
                             )

         accelerator.wait_for_everyone()
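
Note: together with the previous hunk, this drops the tokenize_strategy local (re-fetched from the strategy registry and only bound inside the sample-prompt branch) and instead uses the flux_tokenize_strategy instance created and registered near the top of train(), so the tokenizer reference is the same object on every code path; the training-loop hunk further down makes the matching change. A minimal sketch of the registry pattern assumed here, with simplified stand-in classes rather than the real strategy_base / strategy_flux code:

class TokenizeStrategy:
    # Process-wide singleton holder, as in strategy_base.
    _strategy = None

    @classmethod
    def set_strategy(cls, strategy):
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls):
        return cls._strategy


class FluxTokenizeStrategy(TokenizeStrategy):
    def tokenize(self, prompt: str):
        # Stand-in: the real strategy returns CLIP-L / T5-XXL tokens and masks.
        return {"prompt": prompt}


flux_tokenize_strategy = FluxTokenizeStrategy()
TokenizeStrategy.set_strategy(flux_tokenize_strategy)

# After the fix the script keeps using the instance it already holds instead
# of re-binding a new local from get_strategy() in only one branch.
tokens_and_masks = flux_tokenize_strategy.tokenize("a photo of a cat")
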
@@ -251,7 +250,9 @@ def train(args):
         clean_memory_on_device(accelerator.device)

     # load FLUX
-    flux = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
+    flux = flux_utils.load_flow_model(
+        name, args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
+    )

     if args.gradient_checkpointing:
         flux.enable_gradient_checkpointing(args.cpu_offload_checkpointing)
@@ -419,7 +420,7 @@ def train(args):
         # if we doesn't swap blocks, we can move the model to device
         flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks])
         if is_swapping_blocks:
-            flux.move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
+            accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
         optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

     # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
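
Note: this hunk (and the prepare_block_swap_before_forward hunk below) is the multi-GPU part of the fix. After accelerator.prepare(), flux may be a DistributedDataParallel wrapper, and custom methods such as move_to_device_except_swap_blocks() exist only on the underlying model, not on the wrapper, so they have to be reached through accelerator.unwrap_model(), which is a no-op in single-GPU runs. (The trailing Japanese context comment reads roughly: "Experimental feature: fp16 training including gradients; patch PyTorch to enable grad scaling in fp16.") A minimal self-contained sketch of the behaviour, using a toy module rather than the real FLUX model:

import torch
from accelerate import Accelerator


class ToyFlux(torch.nn.Module):
    # Toy stand-in for the FLUX model with one custom method.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def move_to_device_except_swap_blocks(self, device):
        # Stand-in for the real block-swapping placement logic.
        self.proj.to(device)


accelerator = Accelerator()
model = accelerator.prepare(ToyFlux())

# In a multi-process launch `model` is a DistributedDataParallel wrapper and
# calling the custom method on it directly raises AttributeError; going
# through unwrap_model() works in both single- and multi-GPU runs.
accelerator.unwrap_model(model).move_to_device_except_swap_blocks(accelerator.device)
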
@@ -439,8 +440,8 @@ def train(args):

         double_blocks_to_swap = args.double_blocks_to_swap
         single_blocks_to_swap = args.single_blocks_to_swap
-        num_double_blocks = len(flux.double_blocks)
-        num_single_blocks = len(flux.single_blocks)
+        num_double_blocks = 19  # len(flux.double_blocks)
+        num_single_blocks = 38  # len(flux.single_blocks)
         handled_double_block_indices = set()
         handled_single_block_indices = set()

@@ -537,8 +538,8 @@ def train(args):

         double_blocks_to_swap = args.double_blocks_to_swap
         single_blocks_to_swap = args.single_blocks_to_swap
-        num_double_blocks = len(flux.double_blocks)
-        num_single_blocks = len(flux.single_blocks)
+        num_double_blocks = 19  # len(flux.double_blocks)
+        num_single_blocks = 38  # len(flux.single_blocks)

         for opt_idx, optimizer in enumerate(optimizers):
             for param_group in optimizer.param_groups:
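
Note: both block-swap setup paths now hardcode the FLUX.1 architecture constants (19 double-stream blocks, 38 single-stream blocks) instead of calling len() on flux.double_blocks / flux.single_blocks; the original expressions are kept as inline comments. Presumably the direct attribute access is avoided because flux can be a DDP wrapper at this point (see the unwrap_model change above), so the submodule lists are not reachable as plain attributes on it. If one wanted to keep the counts model-derived, a hedged alternative, assuming the script's own accelerator and flux variables, could be:

# Alternative sketch, not what the commit does: derive the counts from the
# unwrapped model instead of hardcoding FLUX.1's 19 / 38.
unwrapped_flux = accelerator.unwrap_model(flux)
num_double_blocks = len(unwrapped_flux.double_blocks)  # 19 for FLUX.1
num_single_blocks = len(unwrapped_flux.single_blocks)  # 38 for FLUX.1
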
@@ -618,7 +619,7 @@ def train(args):
         )

     if is_swapping_blocks:
-        flux.prepare_block_swap_before_forward()
+        accelerator.unwrap_model(flux).prepare_block_swap_before_forward()

     # For --sample_at_first
     flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
@@ -660,7 +661,7 @@ def train(args):
                 with torch.no_grad():
                     input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
                     text_encoder_conds = text_encoding_strategy.encode_tokens(
-                        tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
+                        flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
                     )
                 if args.full_fp16:
                     text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]