mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'dev' into train_resume_step
This commit is contained in:
2
.github/workflows/typos.yml
vendored
2
.github/workflows/typos.yml
vendored
@@ -18,4 +18,4 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: typos-action
|
- name: typos-action
|
||||||
uses: crate-ci/typos@v1.19.0
|
uses: crate-ci/typos@v1.21.0
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
|
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
|
||||||
|
|
||||||
[default.extend-identifiers]
|
[default.extend-identifiers]
|
||||||
|
ddPn08="ddPn08"
|
||||||
|
|
||||||
[default.extend-words]
|
[default.extend-words]
|
||||||
NIN="NIN"
|
NIN="NIN"
|
||||||
@@ -27,6 +28,7 @@ rik="rik"
|
|||||||
koo="koo"
|
koo="koo"
|
||||||
yos="yos"
|
yos="yos"
|
||||||
wn="wn"
|
wn="wn"
|
||||||
|
hime="hime"
|
||||||
|
|
||||||
|
|
||||||
[files]
|
[files]
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from functools import cache
|
|||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||||
|
|
||||||
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers
|
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
|
||||||
|
|
||||||
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
|
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
|
||||||
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
||||||
|
|||||||
@@ -7,8 +7,10 @@ from typing import Optional, List, Type
|
|||||||
import torch
|
import torch
|
||||||
from library import sdxl_original_unet
|
from library import sdxl_original_unet
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
||||||
@@ -103,19 +105,15 @@ class LLLiteLinear(ORIGINAL_LINEAR):
|
|||||||
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
|
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
|
||||||
|
|
||||||
self.cond_image = None
|
self.cond_image = None
|
||||||
self.cond_emb = None
|
|
||||||
|
|
||||||
def set_cond_image(self, cond_image):
|
def set_cond_image(self, cond_image):
|
||||||
self.cond_image = cond_image
|
self.cond_image = cond_image
|
||||||
self.cond_emb = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return super().forward(x)
|
return super().forward(x)
|
||||||
|
|
||||||
if self.cond_emb is None:
|
cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible
|
||||||
self.cond_emb = self.lllite_conditioning1(self.cond_image)
|
|
||||||
cx = self.cond_emb
|
|
||||||
|
|
||||||
# reshape / b,c,h,w -> b,h*w,c
|
# reshape / b,c,h,w -> b,h*w,c
|
||||||
n, c, h, w = cx.shape
|
n, c, h, w = cx.shape
|
||||||
@@ -159,9 +157,7 @@ class LLLiteConv2d(ORIGINAL_CONV2D):
|
|||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return super().forward(x)
|
return super().forward(x)
|
||||||
|
|
||||||
if self.cond_emb is None:
|
cx = self.lllite_conditioning1(self.cond_image)
|
||||||
self.cond_emb = self.lllite_conditioning1(self.cond_image)
|
|
||||||
cx = self.cond_emb
|
|
||||||
|
|
||||||
cx = torch.cat([cx, self.down(x)], dim=1)
|
cx = torch.cat([cx, self.down(x)], dim=1)
|
||||||
cx = self.mid(cx)
|
cx = self.mid(cx)
|
||||||
|
|||||||
@@ -289,6 +289,9 @@ def train(args):
|
|||||||
# acceleratorがなんかよろしくやってくれるらしい
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
|
if isinstance(unet, DDP):
|
||||||
|
unet._set_static_graph() # avoid error for multiple use of the parameter
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user