Merge pull request #755 from kohya-ss/dev

add lora_fa
This commit is contained in:
Kohya S
2023-08-13 15:20:49 +09:00
committed by GitHub
3 changed files with 1296 additions and 2 deletions

View File

@@ -22,7 +22,11 @@ __Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__
The feature of SDXL training is now available in sdxl branch as an experimental feature.
Aug 13, 2023: The feature will be merged into the main branch soon. Following are the changes from the previous version.
- LoRA-FA is added experimentally. Specify `--network_module networks.lora_fa` option instead of `--network_module networks.lora`. The trained model can be used as a normal LoRA model.
Aug 12, 2023: Following are the changes from the previous version.
- The default value of noise offset when omitted has been changed to 0 from 0.0357.
- The different learning rates for each U-Net block are now supported. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`.

1241
networks/lora_fa.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -371,6 +371,8 @@ class PipelineLike:
width: int = 1024, width: int = 1024,
original_height: int = None, original_height: int = None,
original_width: int = None, original_width: int = None,
original_height_negative: int = None,
original_width_negative: int = None,
crop_top: int = 0, crop_top: int = 0,
crop_left: int = 0, crop_left: int = 0,
num_inference_steps: int = 50, num_inference_steps: int = 50,
@@ -505,15 +507,22 @@ class PipelineLike:
original_height = height original_height = height
if original_width is None: if original_width is None:
original_width = width original_width = width
if original_height_negative is None:
original_height_negative = original_height
if original_width_negative is None:
original_width_negative = original_width
if crop_top is None: if crop_top is None:
crop_top = 0 crop_top = 0
if crop_left is None: if crop_left is None:
crop_left = 0 crop_left = 0
emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
uc_emb1 = sdxl_train_util.get_timestep_embedding(
torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256
)
emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256)
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1)
uc_vector = c_vector.clone().to(self.device, dtype=text_embeddings.dtype) uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1)
c_vector = torch.cat([text_pool, c_vector], dim=1) c_vector = torch.cat([text_pool, c_vector], dim=1)
uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
@@ -1260,6 +1269,8 @@ class BatchDataExt(NamedTuple):
height: int height: int
original_width: int original_width: int
original_height: int original_height: int
original_width_negative: int
original_height_negative: int
crop_left: int crop_left: int
crop_top: int crop_top: int
steps: int steps: int
@@ -1820,6 +1831,8 @@ def main(args):
original_width_1st = scale_and_round(ext.original_width) original_width_1st = scale_and_round(ext.original_width)
original_height_1st = scale_and_round(ext.original_height) original_height_1st = scale_and_round(ext.original_height)
original_width_negative_1st = scale_and_round(ext.original_width_negative)
original_height_negative_1st = scale_and_round(ext.original_height_negative)
crop_left_1st = scale_and_round(ext.crop_left) crop_left_1st = scale_and_round(ext.crop_left)
crop_top_1st = scale_and_round(ext.crop_top) crop_top_1st = scale_and_round(ext.crop_top)
@@ -1830,6 +1843,8 @@ def main(args):
height_1st, height_1st,
original_width_1st, original_width_1st,
original_height_1st, original_height_1st,
original_width_negative_1st,
original_height_negative_1st,
crop_left_1st, crop_left_1st,
crop_top_1st, crop_top_1st,
args.highres_fix_steps, args.highres_fix_steps,
@@ -1897,6 +1912,8 @@ def main(args):
height, height,
original_width, original_width,
original_height, original_height,
original_width_negative,
original_height_negative,
crop_left, crop_left,
crop_top, crop_top,
steps, steps,
@@ -2020,6 +2037,8 @@ def main(args):
width, width,
original_height, original_height,
original_width, original_width,
original_height_negative,
original_width_negative,
crop_top, crop_top,
crop_left, crop_left,
steps, steps,
@@ -2060,6 +2079,8 @@ def main(args):
metadata.add_text("clip-prompt", clip_prompt) metadata.add_text("clip-prompt", clip_prompt)
metadata.add_text("original-height", str(original_height)) metadata.add_text("original-height", str(original_height))
metadata.add_text("original-width", str(original_width)) metadata.add_text("original-width", str(original_width))
metadata.add_text("original-height-negative", str(original_height_negative))
metadata.add_text("original-width-negative", str(original_width_negative))
metadata.add_text("crop-top", str(crop_top)) metadata.add_text("crop-top", str(crop_top))
metadata.add_text("crop-left", str(crop_left)) metadata.add_text("crop-left", str(crop_left))
@@ -2123,6 +2144,8 @@ def main(args):
height = args.H height = args.H
original_width = args.original_width original_width = args.original_width
original_height = args.original_height original_height = args.original_height
original_width_negative = args.original_width_negative
original_height_negative = args.original_height_negative
crop_top = args.crop_top crop_top = args.crop_top
crop_left = args.crop_left crop_left = args.crop_left
scale = args.scale scale = args.scale
@@ -2165,6 +2188,18 @@ def main(args):
print(f"original height: {original_height}") print(f"original height: {original_height}")
continue continue
m = re.match(r"nw (\d+)", parg, re.IGNORECASE)
if m:
original_width_negative = int(m.group(1))
print(f"original width negative: {original_width_negative}")
continue
m = re.match(r"nh (\d+)", parg, re.IGNORECASE)
if m:
original_height_negative = int(m.group(1))
print(f"original height negative: {original_height_negative}")
continue
m = re.match(r"ct (\d+)", parg, re.IGNORECASE) m = re.match(r"ct (\d+)", parg, re.IGNORECASE)
if m: if m:
crop_top = int(m.group(1)) crop_top = int(m.group(1))
@@ -2301,6 +2336,8 @@ def main(args):
height, height,
original_width, original_width,
original_height, original_height,
original_width_negative,
original_height_negative,
crop_left, crop_left,
crop_top, crop_top,
steps, steps,
@@ -2367,6 +2404,18 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument( parser.add_argument(
"--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値" "--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値"
) )
parser.add_argument(
"--original_height_negative",
type=int,
default=None,
help="original height for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal heightの値",
)
parser.add_argument(
"--original_width_negative",
type=int,
default=None,
help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値",
)
parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値") parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値")
parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値") parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値")
parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ")