mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Refactor memory cleaning into a single function
This commit is contained in:
@@ -2,7 +2,6 @@
|
|||||||
# XXX dropped option: hypernetwork training
|
# XXX dropped option: hypernetwork training
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
@@ -11,6 +10,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from library.device_utils import clean_memory
|
||||||
from library.ipex_interop import init_ipex
|
from library.ipex_interop import init_ipex
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -158,9 +158,7 @@ def train(args):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ import diffusers
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from library.device_utils import clean_memory
|
||||||
from library.ipex_interop import init_ipex
|
from library.ipex_interop import init_ipex
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -888,8 +889,7 @@ class PipelineLike:
|
|||||||
init_latent_dist = self.vae.encode(init_image).latent_dist
|
init_latent_dist = self.vae.encode(init_image).latent_dist
|
||||||
init_latents = init_latent_dist.sample(generator=generator)
|
init_latents = init_latent_dist.sample(generator=generator)
|
||||||
else:
|
else:
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
init_latents = []
|
init_latents = []
|
||||||
for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
|
for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
|
||||||
init_latent_dist = self.vae.encode(
|
init_latent_dist = self.vae.encode(
|
||||||
@@ -1047,8 +1047,7 @@ class PipelineLike:
|
|||||||
if vae_batch_size >= batch_size:
|
if vae_batch_size >= batch_size:
|
||||||
image = self.vae.decode(latents).sample
|
image = self.vae.decode(latents).sample
|
||||||
else:
|
else:
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
images = []
|
images = []
|
||||||
for i in tqdm(range(0, batch_size, vae_batch_size)):
|
for i in tqdm(range(0, batch_size, vae_batch_size)):
|
||||||
images.append(
|
images.append(
|
||||||
|
|||||||
9
library/device_utils.py
Normal file
9
library/device_utils.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
import gc
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def clean_memory():
    """Free unreferenced Python objects, then release cached CUDA memory.

    Runs the Python garbage collector and, when CUDA is available, asks
    PyTorch's caching allocator to release its unoccupied cached blocks.

    Returns:
        None
    """
    # Collect first: tensors kept alive only by reference cycles are freed
    # here, so their blocks are already unoccupied when the CUDA cache is
    # emptied below. Emptying the cache before collecting would leave that
    # memory reserved by the allocator.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import gc
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -8,6 +7,7 @@ from accelerate import init_empty_weights
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
||||||
|
from library.device_utils import clean_memory
|
||||||
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
||||||
|
|
||||||
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
|
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
|
||||||
@@ -47,8 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
|||||||
unet.to(accelerator.device)
|
unet.to(accelerator.device)
|
||||||
vae.to(accelerator.device)
|
vae.to(accelerator.device)
|
||||||
|
|
||||||
gc.collect()
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
|
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
|
||||||
import gc
|
|
||||||
import glob
|
import glob
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -67,6 +66,7 @@ import library.sai_model_spec as sai_model_spec
|
|||||||
|
|
||||||
# from library.attention_processors import FlashAttnProcessor
|
# from library.attention_processors import FlashAttnProcessor
|
||||||
# from library.hypernetwork import replace_attentions_for_hypernetwork
|
# from library.hypernetwork import replace_attentions_for_hypernetwork
|
||||||
|
from library.device_utils import clean_memory
|
||||||
from library.original_unet import UNet2DConditionModel
|
from library.original_unet import UNet2DConditionModel
|
||||||
|
|
||||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||||
@@ -2278,8 +2278,7 @@ def cache_batch_latents(
|
|||||||
info.latents_flipped = flipped_latent
|
info.latents_flipped = flipped_latent
|
||||||
|
|
||||||
# FIXME this slows down caching a lot, specify this as an option
|
# FIXME this slows down caching a lot, specify this as an option
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
|
|
||||||
def cache_batch_text_encoder_outputs(
|
def cache_batch_text_encoder_outputs(
|
||||||
@@ -4006,8 +4005,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
|||||||
unet.to(accelerator.device)
|
unet.to(accelerator.device)
|
||||||
vae.to(accelerator.device)
|
vae.to(accelerator.device)
|
||||||
|
|
||||||
gc.collect()
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||||
@@ -4816,7 +4814,7 @@ def sample_images_common(
|
|||||||
|
|
||||||
# clear pipeline and cache to reduce vram usage
|
# clear pipeline and cache to reduce vram usage
|
||||||
del pipeline
|
del pipeline
|
||||||
torch.cuda.empty_cache()
|
clean_memory()
|
||||||
|
|
||||||
torch.set_rng_state(rng_state)
|
torch.set_rng_state(rng_state)
|
||||||
if cuda_rng_state is not None:
|
if cuda_rng_state is not None:
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import diffusers
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from library.device_utils import clean_memory
|
||||||
from library.ipex_interop import init_ipex
|
from library.ipex_interop import init_ipex
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -640,8 +641,7 @@ class PipelineLike:
|
|||||||
init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist
|
init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist
|
||||||
init_latents = init_latent_dist.sample(generator=generator)
|
init_latents = init_latent_dist.sample(generator=generator)
|
||||||
else:
|
else:
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
init_latents = []
|
init_latents = []
|
||||||
for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
|
for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
|
||||||
init_latent_dist = self.vae.encode(
|
init_latent_dist = self.vae.encode(
|
||||||
@@ -780,8 +780,7 @@ class PipelineLike:
|
|||||||
if vae_batch_size >= batch_size:
|
if vae_batch_size >= batch_size:
|
||||||
image = self.vae.decode(latents.to(self.vae.dtype)).sample
|
image = self.vae.decode(latents.to(self.vae.dtype)).sample
|
||||||
else:
|
else:
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
images = []
|
images = []
|
||||||
for i in tqdm(range(0, batch_size, vae_batch_size)):
|
for i in tqdm(range(0, batch_size, vae_batch_size)):
|
||||||
images.append(
|
images.append(
|
||||||
@@ -796,8 +795,7 @@ class PipelineLike:
|
|||||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
if output_type == "pil":
|
if output_type == "pil":
|
||||||
# image = self.numpy_to_pil(image)
|
# image = self.numpy_to_pil(image)
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
# training with captions
|
# training with captions
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
@@ -11,6 +10,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from library.device_utils import clean_memory
|
||||||
from library.ipex_interop import init_ipex
|
from library.ipex_interop import init_ipex
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -252,9 +252,7 @@ def train(args):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
@@ -407,8 +405,7 @@ def train(args):
|
|||||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||||
text_encoder1.to("cpu", dtype=torch.float32)
|
text_encoder1.to("cpu", dtype=torch.float32)
|
||||||
text_encoder2.to("cpu", dtype=torch.float32)
|
text_encoder2.to("cpu", dtype=torch.float32)
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
else:
|
else:
|
||||||
# make sure Text Encoders are on GPU
|
# make sure Text Encoders are on GPU
|
||||||
text_encoder1.to(accelerator.device)
|
text_encoder1.to(accelerator.device)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward
|
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -15,6 +14,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from library.device_utils import clean_memory
|
||||||
from library.ipex_interop import init_ipex
|
from library.ipex_interop import init_ipex
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -164,9 +164,7 @@ def train(args):
|
|||||||
accelerator.is_main_process,
|
accelerator.is_main_process,
|
||||||
)
|
)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
@@ -291,8 +289,7 @@ def train(args):
|
|||||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||||
text_encoder1.to("cpu", dtype=torch.float32)
|
text_encoder1.to("cpu", dtype=torch.float32)
|
||||||
text_encoder2.to("cpu", dtype=torch.float32)
|
text_encoder2.to("cpu", dtype=torch.float32)
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
else:
|
else:
|
||||||
# make sure Text Encoders are on GPU
|
# make sure Text Encoders are on GPU
|
||||||
text_encoder1.to(accelerator.device)
|
text_encoder1.to(accelerator.device)
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import gc
|
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -12,6 +11,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from library.device_utils import clean_memory
|
||||||
from library.ipex_interop import init_ipex
|
from library.ipex_interop import init_ipex
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -163,9 +163,7 @@ def train(args):
|
|||||||
accelerator.is_main_process,
|
accelerator.is_main_process,
|
||||||
)
|
)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
@@ -264,8 +262,7 @@ def train(args):
|
|||||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||||
text_encoder1.to("cpu", dtype=torch.float32)
|
text_encoder1.to("cpu", dtype=torch.float32)
|
||||||
text_encoder2.to("cpu", dtype=torch.float32)
|
text_encoder2.to("cpu", dtype=torch.float32)
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
else:
|
else:
|
||||||
# make sure Text Encoders are on GPU
|
# make sure Text Encoders are on GPU
|
||||||
text_encoder1.to(accelerator.device)
|
text_encoder1.to(accelerator.device)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from library.device_utils import clean_memory
|
||||||
from library.ipex_interop import init_ipex
|
from library.ipex_interop import init_ipex
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -65,8 +66,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
org_unet_device = unet.device
|
org_unet_device = unet.device
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
unet.to("cpu")
|
unet.to("cpu")
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# When TE is not being trained, it will not be prepared, so we need to use explicit autocast
|
# When TE is not being trained, it will not be prepared, so we need to use explicit autocast
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
@@ -81,8 +81,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
|
|
||||||
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||||
text_encoders[1].to("cpu", dtype=torch.float32)
|
text_encoders[1].to("cpu", dtype=torch.float32)
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
if not args.lowram:
|
if not args.lowram:
|
||||||
print("move vae and unet back to original device")
|
print("move vae and unet back to original device")
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import gc
|
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -12,6 +11,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from library.device_utils import clean_memory
|
||||||
from library.ipex_interop import init_ipex
|
from library.ipex_interop import init_ipex
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -219,9 +219,7 @@ def train(args):
|
|||||||
accelerator.is_main_process,
|
accelerator.is_main_process,
|
||||||
)
|
)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
# DreamBooth training
|
# DreamBooth training
|
||||||
# XXX dropped option: fine_tune
|
# XXX dropped option: fine_tune
|
||||||
|
|
||||||
import gc
|
|
||||||
import argparse
|
import argparse
|
||||||
import itertools
|
import itertools
|
||||||
import math
|
import math
|
||||||
@@ -12,6 +11,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from library.device_utils import clean_memory
|
||||||
from library.ipex_interop import init_ipex
|
from library.ipex_interop import init_ipex
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -138,9 +138,7 @@ def train(args):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -14,6 +13,7 @@ from tqdm import tqdm
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
|
from library.device_utils import clean_memory
|
||||||
from library.ipex_interop import init_ipex
|
from library.ipex_interop import init_ipex
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -266,9 +266,7 @@ class NetworkTrainer:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import gc
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
@@ -8,6 +7,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from library.device_utils import clean_memory
|
||||||
from library.ipex_interop import init_ipex
|
from library.ipex_interop import init_ipex
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -363,9 +363,7 @@ class TextualInversionTrainer:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import toml
|
import toml
|
||||||
@@ -9,6 +8,7 @@ from multiprocessing import Value
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from library.device_utils import clean_memory
|
||||||
from library.ipex_interop import init_ipex
|
from library.ipex_interop import init_ipex
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -286,9 +286,7 @@ def train(args):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
if torch.cuda.is_available():
|
clean_memory()
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user