update versions of accelerate and diffusers

This commit is contained in:
Kohya S
2023-09-24 17:46:57 +09:00
parent 20e929e27e
commit 7e736da30c
3 changed files with 22 additions and 12 deletions

View File

@@ -22,7 +22,14 @@ __Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__
The feature of SDXL training is now available in sdxl branch as an experimental feature.
Sep 3, 2023: The feature will be merged into the main branch soon. Following are the changes from the previous version.
Sep 24, 2023: The feature will be merged into the main branch very soon. Following are the changes from the previous version.
- `accelerate` is updated to 0.23.0, and `diffusers` is updated to 0.21.2. Please update them with the upgrade instructions below.
- Intel ARC support with IPEX is added. [#825](https://github.com/kohya-ss/sd-scripts/pull/825)
- Other changes and fixes.
- Thanks for contributions from Disty0, sdbds, jvkap, rockerBOO, Symbiomatrix and others!
Sep 3, 2023:
- ControlNet-LLLite is added. See [documentation](./docs/train_lllite_README.md) for details.
- JPEG XL is supported. [#786](https://github.com/kohya-ss/sd-scripts/pull/786)

View File

@@ -117,7 +117,7 @@ class LoRAModule(torch.nn.Module):
super().__init__()
self.lora_name = lora_name
if org_module.__class__.__name__ == "Conv2d":
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
in_dim = org_module.in_channels
out_dim = org_module.out_channels
else:
@@ -126,7 +126,7 @@ class LoRAModule(torch.nn.Module):
self.lora_dim = lora_dim
if org_module.__class__.__name__ == "Conv2d":
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
@@ -166,7 +166,8 @@ class LoRAModule(torch.nn.Module):
self.org_module[0].forward = self.org_forward
# forward with lora
def forward(self, x):
# scale is used by LoRACompatibleConv, but we ignore it because we have multiplier
def forward(self, x, scale=1.0):
if not self.enabled:
return self.org_forward(x)
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
@@ -318,8 +319,12 @@ class LoRANetwork(torch.nn.Module):
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_linear = (
child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
)
is_conv2d = (
child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
)
if is_linear or is_conv2d:
lora_name = prefix + "." + name + "." + child_name
@@ -359,7 +364,7 @@ class LoRANetwork(torch.nn.Module):
skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
if len(skipped_te) > 0:
print(f"skipped {len(skipped_te)} modules because of missing weight.")
print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
# extend U-Net target modules to include Conv2d 3x3
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
@@ -368,7 +373,7 @@ class LoRANetwork(torch.nn.Module):
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
if len(skipped_un) > 0:
print(f"skipped {len(skipped_un)} modules because of missing weight.")
print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
# assertion
names = set()

View File

@@ -1,6 +1,6 @@
accelerate==0.19.0
accelerate==0.23.0
transformers==4.30.2
diffusers[torch]==0.18.2
diffusers[torch]==0.21.2
ftfy==6.1.1
# albumentations==1.3.0
opencv-python==4.7.0.68
@@ -15,8 +15,6 @@ easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.15.1
# for loading Diffusers' SDXL
invisible-watermark==0.2.0
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12