update dataset to return size, refactor ctrlnet ds

commit 9e9df2b501
parent f7f762c676
Author: Kohya S
Date: 2023-06-24 17:56:02 +09:00

3 changed files with 333 additions and 304 deletions
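Per the commit title, dataset items now carry their size. As an illustration only (the class and key names below are hypothetical, not taken from this repository), a PyTorch dataset that returns each sample's original dimensions alongside the image tensor might look like:

```python
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class ImageDatasetWithSize(Dataset):
    # Hypothetical sketch: yields the pre-resize (width, height) with each
    # tensor, e.g. for SDXL-style size conditioning.
    def __init__(self, image_paths, resolution=1024):
        self.image_paths = image_paths
        self.resolution = resolution

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        original_size = image.size  # (width, height) before any resizing
        image = image.resize((self.resolution, self.resolution))
        # HWC uint8 -> CHW float in [-1, 1]
        tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 127.5 - 1.0
        return {"image": tensor, "original_size": original_size}
```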


@@ -1116,13 +1116,15 @@ if __name__ == "__main__":
     # pseudo training loop for checking memory usage
     print("preparing optimizer")
-    import bitsandbytes
-    import transformers
+    # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
+    # import bitsandbytes
     # optimizer = bitsandbytes.adam.Adam8bit(unet.parameters(), lr=1e-3) # not working
     # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
     # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
-    # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
-    import transformers
+    import transformers
     optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2
     scaler = torch.cuda.amp.GradScaler(enabled=True)
@@ -1133,7 +1135,7 @@ if __name__ == "__main__":
     for step in range(steps):
         print(f"step {step}")
-        x = torch.randn(batch_size, 4, 128, 128).cuda() # 512x512
+        x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
         t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
         ctx = torch.randn(batch_size, 77, 2048).cuda()
         y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda()
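The second hunk only corrects a stale comment: with the 8x VAE downscaling used by Stable Diffusion models, a 4x128x128 latent corresponds to a 1024x1024 image (128 × 8 = 1024), SDXL's native resolution, not 512x512. For context, a mixed-precision step over these tensors could look like the sketch below; the `unet(x, t, ctx, y)` forward signature mirrors the tensors prepared in the hunk but is an assumption, and the loss target is illustrative:

```python
# One pseudo training step, assuming unet/optimizer/scaler from the hunks above.
with torch.cuda.amp.autocast(enabled=True):
    output = unet(x, t, ctx, y)        # assumed forward signature
    target = torch.randn_like(output)  # dummy target; real training would use sampled noise
    loss = torch.nn.functional.mse_loss(output, target)

scaler.scale(loss).backward()  # scaled backward pass for fp16 stability
scaler.step(optimizer)         # unscales gradients, then takes the optimizer step
scaler.update()                # adjusts the loss scale for the next step
optimizer.zero_grad(set_to_none=True)
```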
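The memory figures quoted in the optimizer comments (23.5 GB for the 8-bit bitsandbytes optimizers, 22.2 GB for Adafactor, which keeps factored second-moment statistics instead of Adam's full per-parameter state) read as peak-usage observations. One plausible way to collect such numbers, not necessarily how they were measured here, is PyTorch's peak-memory counter:

```python
import torch

torch.cuda.reset_peak_memory_stats()

# ... run a few training steps as above ...

peak_gib = torch.cuda.max_memory_allocated() / (1024 ** 3)
print(f"peak CUDA memory: {peak_gib:.1f} GiB")
```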