mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
maybe fix branch to run offloading
This commit is contained in:
@@ -3964,6 +3964,8 @@ def prepare_accelerator(args: argparse.Namespace):
|
|||||||
zero3_init_flag=args.zero3_init_flag, zero3_save_16bit_model=args.zero3_save_16bit_model,
|
zero3_init_flag=args.zero3_init_flag, zero3_save_16bit_model=args.zero3_save_16bit_model,
|
||||||
)
|
)
|
||||||
deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size
|
deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size
|
||||||
|
deepspeed_plugin.deepspeed_config['train_batch_size'] = \
|
||||||
|
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ['WORLD_SIZE'])
|
||||||
|
|
||||||
accelerator = Accelerator(
|
accelerator = Accelerator(
|
||||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
|
|||||||
@@ -391,6 +391,12 @@ def train(args):
|
|||||||
|
|
||||||
if args.deepspeed:
|
if args.deepspeed:
|
||||||
# Wrapping model for DeepSpeed
|
# Wrapping model for DeepSpeed
|
||||||
|
import deepspeed
|
||||||
|
if args.offload_optimizer_device is not None:
|
||||||
|
accelerator.print('[DeepSpeed] start to manually build cpu_adam.')
|
||||||
|
deepspeed.ops.op_builder.CPUAdamBuilder().load()
|
||||||
|
accelerator.print('[DeepSpeed] building cpu_adam done.')
|
||||||
|
|
||||||
class DeepSpeedModel(torch.nn.Module):
|
class DeepSpeedModel(torch.nn.Module):
|
||||||
def __init__(self, unet, text_encoder) -> None:
|
def __init__(self, unet, text_encoder) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
Reference in New Issue
Block a user