mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
feat: add block swap for FLUX.1/SD3 LoRA training
This commit is contained in:
@@ -601,8 +601,10 @@ class NetworkTrainer:
|
||||
# unet.to(accelerator.device) # this makes the `to(dtype)` call below faster, but consumes 23 GB of VRAM
|
||||
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
|
||||
|
||||
logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
|
||||
unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
|
||||
# logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
|
||||
# unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
|
||||
logger.info(f"set U-Net weight dtype to {unet_weight_dtype}")
|
||||
unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator
|
||||
|
||||
unet.requires_grad_(False)
|
||||
unet.to(dtype=unet_weight_dtype)
|
||||
|
||||
Reference in New Issue
Block a user