mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
make forward/backward pathes same ref #1363
This commit is contained in:
@@ -7,8 +7,10 @@ from typing import Optional, List, Type
|
|||||||
import torch
|
import torch
|
||||||
from library import sdxl_original_unet
|
from library import sdxl_original_unet
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
||||||
@@ -103,19 +105,15 @@ class LLLiteLinear(ORIGINAL_LINEAR):
|
|||||||
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
|
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
|
||||||
|
|
||||||
self.cond_image = None
|
self.cond_image = None
|
||||||
self.cond_emb = None
|
|
||||||
|
|
||||||
def set_cond_image(self, cond_image):
|
def set_cond_image(self, cond_image):
|
||||||
self.cond_image = cond_image
|
self.cond_image = cond_image
|
||||||
self.cond_emb = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return super().forward(x)
|
return super().forward(x)
|
||||||
|
|
||||||
if self.cond_emb is None:
|
cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible
|
||||||
self.cond_emb = self.lllite_conditioning1(self.cond_image)
|
|
||||||
cx = self.cond_emb
|
|
||||||
|
|
||||||
# reshape / b,c,h,w -> b,h*w,c
|
# reshape / b,c,h,w -> b,h*w,c
|
||||||
n, c, h, w = cx.shape
|
n, c, h, w = cx.shape
|
||||||
@@ -159,9 +157,7 @@ class LLLiteConv2d(ORIGINAL_CONV2D):
|
|||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return super().forward(x)
|
return super().forward(x)
|
||||||
|
|
||||||
if self.cond_emb is None:
|
cx = self.lllite_conditioning1(self.cond_image)
|
||||||
self.cond_emb = self.lllite_conditioning1(self.cond_image)
|
|
||||||
cx = self.cond_emb
|
|
||||||
|
|
||||||
cx = torch.cat([cx, self.down(x)], dim=1)
|
cx = torch.cat([cx, self.down(x)], dim=1)
|
||||||
cx = self.mid(cx)
|
cx = self.mid(cx)
|
||||||
|
|||||||
Reference in New Issue
Block a user