From 6e279730cf476230a79fa0c10568047f1d7753f3 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 22 Jan 2023 10:44:29 +0900
Subject: [PATCH] Fix weights checking script to use float32

---
 networks/check_lora_weights.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py
index 1140e3b3..4ee3f575 100644
--- a/networks/check_lora_weights.py
+++ b/networks/check_lora_weights.py
@@ -15,12 +15,13 @@ def main(file):
 
   keys = list(sd.keys())
   for key in keys:
-    if 'lora_up' in key:
+    if 'lora_up' in key or 'lora_down' in key:
       values.append((key, sd[key]))
 
-  print(f"number of LoRA-up modules: {len(values)}")
+  print(f"number of LoRA modules: {len(values)}")
   for key, value in values:
-    print(f"{key},{torch.mean(torch.abs(value))}")
+    value = value.to(torch.float32)
+    print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
 
 
 if __name__ == '__main__':