mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add sdxl fine-tuning and LoRA
This commit is contained in:
@@ -1069,8 +1069,8 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
t_emb = t_emb.to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
assert y.shape[0] == x.shape[0]
|
||||
assert x.dtype == y.dtype
|
||||
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
|
||||
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
|
||||
# assert x.dtype == self.dtype
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
@@ -1105,6 +1105,8 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
print("create unet")
|
||||
unet = SdxlUNet2DConditionModel()
|
||||
|
||||
@@ -1132,8 +1134,11 @@ if __name__ == "__main__":
|
||||
print("start training")
|
||||
steps = 10
|
||||
batch_size = 1
|
||||
|
||||
for step in range(steps):
|
||||
print(f"step {step}")
|
||||
if step == 1:
|
||||
time_start = time.perf_counter()
|
||||
|
||||
x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
|
||||
t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
|
||||
@@ -1149,3 +1154,6 @@ if __name__ == "__main__":
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
time_end = time.perf_counter()
|
||||
print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
|
||||
|
||||
Reference in New Issue
Block a user