add FLUX.1 LoRA training

Kohya S
2024-08-09 22:56:48 +09:00
parent da4d0fe016
commit 36b2e6fc28
10 changed files with 2992 additions and 55 deletions


@@ -15,6 +15,12 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from transformers import CLIPTokenizer, T5TokenizerFast
from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)
memory_efficient_attention = None
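
The added block wires this module into the repo-wide logger before any log calls are made. The setup_logging() implementation lives in .utils and is not shown in this diff, so the following is only a minimal sketch of what such a helper typically does; the handler and format details are assumptions:

import logging

def setup_logging(level=logging.INFO):
    root = logging.getLogger()
    if root.handlers:
        # already configured; avoid stacking duplicate handlers on repeated calls
        return
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
    root.addHandler(handler)
    root.setLevel(level)

setup_logging()
logger = logging.getLogger(__name__)  # module-level logger, as in the hunk above

Guarding on root.handlers keeps repeated imports from duplicating log output.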
@@ -95,7 +101,9 @@ class SDTokenizer:
            batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
        # truncate to max_length
        print(
            f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}"
        )
        if truncate_to_max_length and len(batch) > self.max_length:
            batch = batch[: self.max_length]
        if truncate_length is not None and len(batch) > truncate_length:
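
For reference, here is the padding/truncation logic of this hunk extracted into a standalone function. pad_and_truncate is a hypothetical helper, not part of the commit; it mirrors the min_length padding and the max_length/truncate_length capping shown above:

def pad_and_truncate(batch, pad_token, min_length, max_length, truncate_to_max_length=True, truncate_length=None):
    # pad with (token, weight) pairs up to min_length
    if len(batch) < min_length:
        batch.extend([(pad_token, 1.0)] * (min_length - len(batch)))
    # hard-cap at max_length, then at an optional shorter truncate_length
    if truncate_to_max_length and len(batch) > max_length:
        batch = batch[:max_length]
    if truncate_length is not None and len(batch) > truncate_length:
        batch = batch[:truncate_length]
    return batch

# pads three (token, weight) pairs out to five entries, appending (0, 1.0)
tokens = pad_and_truncate([(320, 1.0)] * 3, pad_token=0, min_length=5, max_length=77)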
@@ -1554,6 +1562,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
        self.set_clip_options({"layer": layer_idx})
        self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)

    @property
    def device(self):
        return next(self.parameters()).device
    @property
    def dtype(self):
        return next(self.parameters()).dtype
    def gradient_checkpointing_enable(self):
        logger.warning("Gradient checkpointing is not supported for this model")
    def set_attn_mode(self, mode):
        raise NotImplementedError("This model does not support setting the attention mode")
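
The device and dtype properties added here use the common nn.Module idiom of inferring both from the first parameter, so they stay correct after .to(...) moves or casts the model. A minimal self-contained example (TinyEncoder is illustrative, not from the commit):

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @property
    def device(self) -> torch.device:
        # all parameters share a device in the single-device case
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

model = TinyEncoder().to(torch.bfloat16)
print(model.device, model.dtype)  # cpu torch.bfloat16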
@@ -1925,6 +1944,7 @@ def create_clip_l(device="cpu", dtype=torch.float32, state_dict: Optional[Dict[s
        return_projected_pooled=False,
        textmodel_json_config=CLIPL_CONFIG,
    )
    clip_l.gradient_checkpointing_enable()
    if state_dict is not None:
        # update state_dict if provided to include logit_scale and text_projection.weight to avoid errors
        if "logit_scale" not in state_dict:
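
A sketch of the state_dict patching this hunk begins: checkpoints saved without logit_scale or text_projection.weight get those keys filled in before load_state_dict() so strict loading does not fail. The key names come from the diff; the default values below are assumptions (log(100) is the usual trained-CLIP logit_scale, and an identity matrix is a neutral text projection):

import torch

def patch_clip_l_state_dict(state_dict, embed_dim=768):
    # assumed defaults; fill only the keys the checkpoint is missing
    if "logit_scale" not in state_dict:
        state_dict["logit_scale"] = torch.tensor(4.6052)  # ~log(100)
    if "text_projection.weight" not in state_dict:
        state_dict["text_projection.weight"] = torch.eye(embed_dim)
    return state_dict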