add FLUX.1 LoRA training
@@ -15,6 +15,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 from transformers import CLIPTokenizer, T5TokenizerFast
+from .utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 memory_efficient_attention = None
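The hunk above wires module-level logging into the file. For context, here is a minimal sketch of what a setup_logging() helper along these lines might do; the real helper lives in sd-scripts' library/utils.py (imported here as .utils), so its exact behavior is an assumption, not quoted from the repo:

    import logging

    def setup_logging(level=logging.INFO):
        # configure the root logger once; repeated calls are no-ops, so every
        # module-level logging.getLogger(__name__) inherits the same handler
        if not logging.getLogger().handlers:
            logging.basicConfig(
                level=level,
                format="%(asctime)s %(levelname)s %(name)s | %(message)s",
            )

With this in place, calls such as the logger.warning(...) added further down go through a shared handler instead of bare print().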
@@ -95,7 +101,9 @@ class SDTokenizer:
         batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
 
         # truncate to max_length
-        print(f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}")
+        print(
+            f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}"
+        )
         if truncate_to_max_length and len(batch) > self.max_length:
             batch = batch[: self.max_length]
         if truncate_length is not None and len(batch) > truncate_length:
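This hunk only reflows a debug print, but the surrounding pad-then-truncate logic is worth reading as a unit: the batch of (token, weight) pairs is first padded up to min_length, then clamped to max_length and, optionally, to an explicit truncate_length. A standalone sketch of that logic (pad_and_truncate and its parameters are illustrative names, not sd-scripts API):

    def pad_and_truncate(batch, pad_token, min_length, max_length,
                         truncate_to_max_length=True, truncate_length=None):
        # pad with (pad_token, weight 1.0) pairs up to the minimum length
        if len(batch) < min_length:
            batch.extend([(pad_token, 1.0)] * (min_length - len(batch)))
        # then clamp to the hard maximum and/or an explicit truncate length
        if truncate_to_max_length and len(batch) > max_length:
            batch = batch[:max_length]
        if truncate_length is not None and len(batch) > truncate_length:
            batch = batch[:truncate_length]
        return batch

    print(pad_and_truncate([(320, 1.0)], pad_token=0, min_length=4, max_length=77))
    # [(320, 1.0), (0, 1.0), (0, 1.0), (0, 1.0)]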
@@ -1554,6 +1562,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             self.set_clip_options({"layer": layer_idx})
         self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)
 
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
+    def gradient_checkpointing_enable(self):
+        logger.warning("Gradient checkpointing is not supported for this model")
+
     def set_attn_mode(self, mode):
         raise NotImplementedError("This model does not support setting the attention mode")
 
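The added device and dtype properties use a common PyTorch idiom: rather than caching the device or dtype, derive both from the module's own parameters so they stay correct after .to(...) moves or casts the model. A self-contained sketch of the pattern (TinyTextModel is an illustrative stand-in for SDClipModel):

    import torch
    import torch.nn as nn

    class TinyTextModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

        @property
        def device(self):
            # assumes all parameters share one device, as in SDClipModel
            return next(self.parameters()).device

        @property
        def dtype(self):
            return next(self.parameters()).dtype

    m = TinyTextModel().to(dtype=torch.float16)
    print(m.device, m.dtype)  # e.g. cpu torch.float16

gradient_checkpointing_enable() is added as a warn-only stub, which lets callers (such as create_clip_l in the next hunk) invoke it unconditionally on every text encoder without special-casing the ones that do not support checkpointing.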
@@ -1925,6 +1944,7 @@ def create_clip_l(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[s
         return_projected_pooled=False,
         textmodel_json_config=CLIPL_CONFIG,
     )
+    clip_l.gradient_checkpointing_enable()
     if state_dict is not None:
         # update state_dict if provided to include logit_scale and text_projection.weight to avoid errors
         if "logit_scale" not in state_dict:
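The trailing if-block patches an incoming state_dict so that keys some checkpoints omit (logit_scale and, per the comment, text_projection.weight) do not break strict loading. A hedged sketch of the pattern; fill_missing_clip_keys is a hypothetical helper, and copying defaults from the freshly initialized model is an assumption about the fix-up, not the commit's exact values:

    def fill_missing_clip_keys(state_dict, model):
        # hypothetical helper: copy any expected-but-missing key (e.g.
        # logit_scale) from the freshly initialized model's own weights,
        # so model.load_state_dict(state_dict, strict=True) does not
        # raise on those keys
        model_sd = model.state_dict()
        for key in ("logit_scale", "text_projection.weight"):
            if key in model_sd and key not in state_dict:
                state_dict[key] = model_sd[key].clone()
        return state_dict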