Merge branch 'dev' into train_resume_step

This commit is contained in:
Kohya S
2024-06-11 19:27:37 +09:00
5 changed files with 11 additions and 10 deletions

View File

@@ -18,4 +18,4 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: typos-action - name: typos-action
uses: crate-ci/typos@v1.19.0 uses: crate-ci/typos@v1.21.0

View File

@@ -2,6 +2,7 @@
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started # Instruction: https://github.com/marketplace/actions/typos-action#getting-started
[default.extend-identifiers] [default.extend-identifiers]
ddPn08="ddPn08"
[default.extend-words] [default.extend-words]
NIN="NIN" NIN="NIN"
@@ -27,6 +28,7 @@ rik="rik"
koo="koo" koo="koo"
yos="yos" yos="yos"
wn="wn" wn="wn"
hime="hime"
[files] [files]

View File

@@ -5,7 +5,7 @@ from functools import cache
# pylint: disable=protected-access, missing-function-docstring, line-too-long # pylint: disable=protected-access, missing-function-docstring, line-too-long
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers # ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))

View File

@@ -7,8 +7,10 @@ from typing import Optional, List, Type
import torch import torch
from library import sdxl_original_unet from library import sdxl_original_unet
from library.utils import setup_logging from library.utils import setup_logging
setup_logging() setup_logging()
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# input_blocksに適用するかどうか / if True, input_blocks are not applied # input_blocksに適用するかどうか / if True, input_blocks are not applied
@@ -103,19 +105,15 @@ class LLLiteLinear(ORIGINAL_LINEAR):
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim) add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
self.cond_image = None self.cond_image = None
self.cond_emb = None
def set_cond_image(self, cond_image): def set_cond_image(self, cond_image):
self.cond_image = cond_image self.cond_image = cond_image
self.cond_emb = None
def forward(self, x): def forward(self, x):
if not self.enabled: if not self.enabled:
return super().forward(x) return super().forward(x)
- if self.cond_emb is None:
- self.cond_emb = self.lllite_conditioning1(self.cond_image)
- cx = self.cond_emb
+ cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible
# reshape / b,c,h,w -> b,h*w,c # reshape / b,c,h,w -> b,h*w,c
n, c, h, w = cx.shape n, c, h, w = cx.shape
@@ -159,9 +157,7 @@ class LLLiteConv2d(ORIGINAL_CONV2D):
if not self.enabled: if not self.enabled:
return super().forward(x) return super().forward(x)
- if self.cond_emb is None:
- self.cond_emb = self.lllite_conditioning1(self.cond_image)
- cx = self.cond_emb
+ cx = self.lllite_conditioning1(self.cond_image)
cx = torch.cat([cx, self.down(x)], dim=1) cx = torch.cat([cx, self.down(x)], dim=1)
cx = self.mid(cx) cx = self.mid(cx)

View File

@@ -289,6 +289,9 @@ def train(args):
# acceleratorがなんかよろしくやってくれるらしい # acceleratorがなんかよろしくやってくれるらしい
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if isinstance(unet, DDP):
unet._set_static_graph() # avoid error for multiple use of the parameter
if args.gradient_checkpointing: if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else: else: