mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support SD3 LoRA
This commit is contained in:
@@ -129,6 +129,7 @@ class NetworkTrainer:
|
||||
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
||||
"""
|
||||
Returns a list of models that will be used for text encoding. SDXL uses wrapped and unwrapped models.
|
||||
FLUX.1 and SD3 may cache some outputs of the text encoder, so return the models that will be used for encoding (not cached).
|
||||
"""
|
||||
return text_encoders
|
||||
|
||||
@@ -591,6 +592,7 @@ class NetworkTrainer:
|
||||
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
|
||||
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
|
||||
|
||||
logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
|
||||
unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
|
||||
|
||||
unet.requires_grad_(False)
|
||||
|
||||
Reference in New Issue
Block a user