support SD3 LoRA

This commit is contained in:
kohya-ss
2024-10-25 21:58:31 +09:00
parent f52fb66e8f
commit d2c549d7b2
7 changed files with 1334 additions and 67 deletions

View File

@@ -129,6 +129,7 @@ class NetworkTrainer:
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
"""
Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models.
FLUX.1 and SD3 may cache some outputs of the text encoder, so return the models that will be used for encoding (not cached).
"""
return text_encoders
@@ -591,6 +592,7 @@ class NetworkTrainer:
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
unet.requires_grad_(False)