Add SDXL fine-tuning and LoRA

This commit is contained in:
Kohya S
2023-06-26 08:07:24 +09:00
parent 9e9df2b501
commit 747af145ed
11 changed files with 2442 additions and 754 deletions

View File

@@ -1069,8 +1069,8 @@ class SdxlUNet2DConditionModel(nn.Module):
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)
assert y.shape[0] == x.shape[0]
assert x.dtype == y.dtype
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
# assert x.dtype == self.dtype
emb = emb + self.label_emb(y)
@@ -1105,6 +1105,8 @@ class SdxlUNet2DConditionModel(nn.Module):
if __name__ == "__main__":
import time
print("create unet")
unet = SdxlUNet2DConditionModel()
@@ -1132,8 +1134,11 @@ if __name__ == "__main__":
print("start training")
steps = 10
batch_size = 1
for step in range(steps):
print(f"step {step}")
if step == 1:
time_start = time.perf_counter()
x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
@@ -1149,3 +1154,6 @@ if __name__ == "__main__":
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
time_end = time.perf_counter()
print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")