mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Refactored code; some functions moved into train_util.py
This commit is contained in:
29
fine_tune.py
29
fine_tune.py
@@ -243,24 +243,19 @@ def train(args):
|
|||||||
text_encoder.to(weight_dtype)
|
text_encoder.to(weight_dtype)
|
||||||
|
|
||||||
if args.deepspeed:
|
if args.deepspeed:
|
||||||
# wrapping model
|
training_models_dict = {}
|
||||||
import deepspeed
|
training_models_dict["unet"] = unet
|
||||||
if args.offload_optimizer_device is not None:
|
if args.train_text_encoder: training_models_dict["text_encoder"] = text_encoder
|
||||||
accelerator.print('[DeepSpeed] start to manually build cpu_adam.')
|
|
||||||
deepspeed.ops.op_builder.CPUAdamBuilder().load()
|
ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
|
||||||
accelerator.print('[DeepSpeed] building cpu_adam done.')
|
|
||||||
class DeepSpeedModel(torch.nn.Module):
|
|
||||||
def __init__(self, unet, text_encoder) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.unet = unet
|
|
||||||
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
|
|
||||||
def get_models(self):
|
|
||||||
return self.unet, self.text_encoders
|
|
||||||
ds_model = DeepSpeedModel(unet, text_encoders)
|
|
||||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
||||||
# Now, ds_model is an instance of DeepSpeedEngine.
|
|
||||||
unet, text_encoders = ds_model.get_models() # for compatiblility
|
training_models = []
|
||||||
text_encoder = text_encoders
|
unet = ds_model.models["unet"]
|
||||||
|
training_models.append(unet)
|
||||||
|
if args.train_text_encoder:
|
||||||
|
text_encoder = ds_model.models["text_encoder"]
|
||||||
|
training_models.append(text_encoder)
|
||||||
|
|
||||||
else: # acceleratorがなんかよろしくやってくれるらしい
|
else: # acceleratorがなんかよろしくやってくれるらしい
|
||||||
if args.train_text_encoder:
|
if args.train_text_encoder:
|
||||||
|
|||||||
@@ -3959,27 +3959,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
|||||||
else None,
|
else None,
|
||||||
)
|
)
|
||||||
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
|
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
|
||||||
deepspeed_plugin = None
|
deepspeed_plugin = prepare_deepspeed_plugin(args)
|
||||||
if args.deepspeed:
|
|
||||||
deepspeed_plugin = DeepSpeedPlugin(
|
|
||||||
zero_stage=args.zero_stage,
|
|
||||||
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_clipping=args.max_grad_norm,
|
|
||||||
offload_optimizer_device=args.offload_optimizer_device, offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
|
|
||||||
offload_param_device=args.offload_param_device, offload_param_nvme_path=args.offload_param_nvme_path,
|
|
||||||
zero3_init_flag=args.zero3_init_flag, zero3_save_16bit_model=args.zero3_save_16bit_model,
|
|
||||||
)
|
|
||||||
deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size
|
|
||||||
deepspeed_plugin.deepspeed_config['train_batch_size'] = \
|
|
||||||
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ['WORLD_SIZE'])
|
|
||||||
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
|
||||||
if args.mixed_precision.lower() == "fp16":
|
|
||||||
deepspeed_plugin.deepspeed_config['fp16']['initial_scale_power'] = 0
|
|
||||||
if args.full_fp16 or args.fp16_master_weights_and_gradients:
|
|
||||||
if args.offload_optimizer_device == "cpu":
|
|
||||||
deepspeed_plugin.deepspeed_config['fp16']['fp16_master_weights_and_grads'] = True
|
|
||||||
print("[DeepSpeed] full fp16 enable.")
|
|
||||||
else:
|
|
||||||
print("full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam.")
|
|
||||||
|
|
||||||
accelerator = Accelerator(
|
accelerator = Accelerator(
|
||||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
@@ -3992,6 +3972,62 @@ def prepare_accelerator(args: argparse.Namespace):
|
|||||||
)
|
)
|
||||||
return accelerator
|
return accelerator
|
||||||
|
|
||||||
|
def prepare_deepspeed_plugin(args: argparse.Namespace):
    """Build the accelerate ``DeepSpeedPlugin`` described by the parsed arguments.

    Args:
        args: Parsed command-line arguments. Reads ``deepspeed``, ``zero_stage``,
            ``gradient_accumulation_steps``, ``max_grad_norm``, the
            ``offload_optimizer_*`` / ``offload_param_*`` options,
            ``zero3_init_flag``, ``zero3_save_16bit_model``, ``train_batch_size``,
            ``mixed_precision``, ``full_fp16`` and
            ``fp16_master_weights_and_gradients``.

    Returns:
        The configured ``DeepSpeedPlugin``, or ``None`` when ``args.deepspeed``
        is falsy (``False`` or ``None``).

    Raises:
        SystemExit: With status 1 when DeepSpeed is requested but cannot be imported.
    """
    # The training scripts test this flag by truthiness (`if args.deepspeed:`).
    # An `is None` check would let a disabled flag (False) fall through and
    # build a plugin -- or abort on the import -- when DeepSpeed was not requested.
    if not args.deepspeed:
        return None

    try:
        import deepspeed
    except ImportError:
        print("deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed")
        # SystemExit does not depend on the `exit` helper injected by the site module.
        raise SystemExit(1)

    deepspeed_plugin = DeepSpeedPlugin(
        zero_stage=args.zero_stage,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_clipping=args.max_grad_norm,
        offload_optimizer_device=args.offload_optimizer_device,
        offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
        offload_param_device=args.offload_param_device,
        offload_param_nvme_path=args.offload_param_nvme_path,
        zero3_init_flag=args.zero3_init_flag,
        zero3_save_16bit_model=args.zero3_save_16bit_model,
    )

    # NOTE(review): assumes an unset WORLD_SIZE means a single-process run; the
    # previous code raised KeyError in that case -- confirm against the launcher.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ds_config = deepspeed_plugin.deepspeed_config
    ds_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
    ds_config["train_batch_size"] = args.train_batch_size * args.gradient_accumulation_steps * world_size

    deepspeed_plugin.set_mixed_precision(args.mixed_precision)
    if args.mixed_precision.lower() == "fp16":
        ds_config["fp16"]["initial_scale_power"] = 0  # preventing overflow.

    if args.full_fp16 or args.fp16_master_weights_and_gradients:
        if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
            ds_config["fp16"]["fp16_master_weights_and_grads"] = True
            print("[DeepSpeed] full fp16 enable.")
        else:
            # Unsupported combination: warn and continue without fp16 master weights.
            print("[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage.")

    if args.offload_optimizer_device is not None:
        # Build the cpu_adam op up front, before the optimizer is prepared.
        print("[DeepSpeed] start to manually build cpu_adam.")
        deepspeed.ops.op_builder.CPUAdamBuilder().load()
        print("[DeepSpeed] building cpu_adam done.")

    return deepspeed_plugin
|
||||||
|
|
||||||
|
def prepare_deepspeed_model(args: argparse.Namespace, **models):
    """Bundle several models into one ``torch.nn.Module`` keyed by name.

    The training scripts pass the returned wrapper to ``accelerator.prepare``
    as a single model and read the individual models back through
    ``wrapper.models[name]``.

    Args:
        args: Parsed command-line arguments. Not read here; kept so the call
            signature stays the same for existing callers.
        **models: Name -> model. Each value is a ``torch.nn.Module`` or a
            list/tuple of modules, which is wrapped in a ``torch.nn.ModuleList``.

    Returns:
        A module whose ``models`` attribute is a ``torch.nn.ModuleDict`` holding
        every supplied model under its keyword name.

    Raises:
        TypeError: If a value is neither a module nor a list/tuple of modules.
    """

    class DeepSpeedWrapper(torch.nn.Module):
        """Container that exposes all wrapped models as sub-modules."""

        def __init__(self, **kw_models) -> None:
            super().__init__()
            self.models = torch.nn.ModuleDict()

            for key, model in kw_models.items():
                if isinstance(model, (list, tuple)):
                    model = torch.nn.ModuleList(model)
                # Raise explicitly: an `assert` disappears under `python -O`.
                if not isinstance(model, torch.nn.Module):
                    raise TypeError(
                        f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
                    )
                self.models[key] = model

        def get_models(self):
            """Return the ``ModuleDict`` of wrapped models."""
            return self.models

    ds_model = DeepSpeedWrapper(**models)
    return ds_model
|
||||||
|
|
||||||
def prepare_dtype(args: argparse.Namespace):
|
def prepare_dtype(args: argparse.Namespace):
|
||||||
weight_dtype = torch.float32
|
weight_dtype = torch.float32
|
||||||
|
|||||||
@@ -391,28 +391,29 @@ def train(args):
|
|||||||
text_encoder2.to(weight_dtype)
|
text_encoder2.to(weight_dtype)
|
||||||
|
|
||||||
if args.deepspeed:
|
if args.deepspeed:
|
||||||
# Wrapping model for DeepSpeed
|
training_models_dict = {}
|
||||||
import deepspeed
|
if train_unet:
|
||||||
if args.offload_optimizer_device is not None:
|
training_models_dict["unet"] = unet
|
||||||
accelerator.print('[DeepSpeed] start to manually build cpu_adam.')
|
if train_text_encoder1:
|
||||||
deepspeed.ops.op_builder.CPUAdamBuilder().load()
|
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
||||||
accelerator.print('[DeepSpeed] building cpu_adam done.')
|
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
||||||
|
training_models_dict["text_encoder1"] = text_encoder1
|
||||||
class DeepSpeedModel(torch.nn.Module):
|
if train_text_encoder2:
|
||||||
def __init__(self, unet, text_encoder) -> None:
|
training_models_dict["text_encoder2"] = text_encoder2
|
||||||
super().__init__()
|
ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
|
||||||
self.unet = unet
|
|
||||||
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
|
|
||||||
|
|
||||||
def get_models(self):
|
|
||||||
return self.unet, self.text_encoders
|
|
||||||
text_encoders = [text_encoder1, text_encoder2]
|
|
||||||
ds_model = DeepSpeedModel(unet, text_encoders)
|
|
||||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
||||||
# Now, ds_model is an instance of DeepSpeedEngine.
|
|
||||||
unet, text_encoders = ds_model.get_models() # for compatiblility
|
training_models = [] # override training_models
|
||||||
text_encoder1, text_encoder2 = text_encoder = text_encoders
|
if train_unet:
|
||||||
training_models = [unet, text_encoder1, text_encoder2]
|
unet = ds_model.models["unet"]
|
||||||
|
training_models.append(unet)
|
||||||
|
if train_text_encoder1:
|
||||||
|
text_encoder1 = ds_model.models["text_encoder1"]
|
||||||
|
training_models.append(text_encoder1)
|
||||||
|
if train_text_encoder2:
|
||||||
|
text_encoder2 = ds_model.models["text_encoder2"]
|
||||||
|
training_models.append(text_encoder2)
|
||||||
|
|
||||||
else: # acceleratorがなんかよろしくやってくれるらしい
|
else: # acceleratorがなんかよろしくやってくれるらしい
|
||||||
if train_unet:
|
if train_unet:
|
||||||
unet = accelerator.prepare(unet)
|
unet = accelerator.prepare(unet)
|
||||||
|
|||||||
31
train_db.py
31
train_db.py
@@ -216,25 +216,20 @@ def train(args):
|
|||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
if args.deepspeed:
|
if args.deepspeed:
|
||||||
# wrapping model
|
training_models_dict = {}
|
||||||
import deepspeed
|
training_models_dict["unet"] = unet
|
||||||
if args.offload_optimizer_device is not None:
|
if train_text_encoder: training_models_dict["text_encoder"] = text_encoder
|
||||||
accelerator.print('[DeepSpeed] start to manually build cpu_adam.')
|
|
||||||
deepspeed.ops.op_builder.CPUAdamBuilder().load()
|
ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
|
||||||
accelerator.print('[DeepSpeed] building cpu_adam done.')
|
|
||||||
class DeepSpeedModel(torch.nn.Module):
|
|
||||||
def __init__(self, unet, text_encoder) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.unet = unet
|
|
||||||
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
|
|
||||||
|
|
||||||
def get_models(self):
|
|
||||||
return self.unet, self.text_encoders
|
|
||||||
ds_model = DeepSpeedModel(unet, text_encoders)
|
|
||||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
||||||
# Now, ds_model is an instance of DeepSpeedEngine.
|
|
||||||
unet, text_encoders = ds_model.get_models() # for compatiblility
|
training_models = []
|
||||||
text_encoder = text_encoders
|
unet = ds_model.models["unet"]
|
||||||
|
training_models.append(unet)
|
||||||
|
if train_text_encoder:
|
||||||
|
text_encoder = ds_model.models["text_encoder"]
|
||||||
|
training_models.append(text_encoder)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if train_text_encoder:
|
if train_text_encoder:
|
||||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
|||||||
@@ -410,26 +410,22 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
||||||
if args.deepspeed:
|
if args.deepspeed:
|
||||||
# wrapping model
|
training_models_dict = {}
|
||||||
import deepspeed
|
if train_unet: training_models_dict["unet"] = unet
|
||||||
if args.offload_optimizer_device is not None:
|
if train_text_encoder: training_models_dict["text_encoder"] = text_encoders
|
||||||
accelerator.print('[DeepSpeed] start to manually build cpu_adam.')
|
training_models_dict["network"] = network
|
||||||
deepspeed.ops.op_builder.CPUAdamBuilder().load()
|
|
||||||
accelerator.print('[DeepSpeed] building cpu_adam done.')
|
ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
|
||||||
class DeepSpeedModel(torch.nn.Module):
|
|
||||||
def __init__(self, unet, text_encoder, network) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.unet = unet
|
|
||||||
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
|
|
||||||
self.network = network
|
|
||||||
|
|
||||||
def get_models(self):
|
|
||||||
return self.unet, self.text_encoders, self.network
|
|
||||||
ds_model = DeepSpeedModel(unet, text_encoders, network)
|
|
||||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
||||||
# Now, ds_model is an instance of DeepSpeedEngine.
|
|
||||||
unet, text_encoders, network = ds_model.get_models() # for compatiblility
|
if train_unet: unet = ds_model.models["unet"]
|
||||||
text_encoder = text_encoders
|
if train_text_encoder:
|
||||||
|
text_encoder = ds_model.models["text_encoder"]
|
||||||
|
if len(ds_model.models["text_encoder"]) > 1:
|
||||||
|
text_encoders = text_encoder
|
||||||
|
else:
|
||||||
|
text_encoders = [text_encoder]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if train_unet:
|
if train_unet:
|
||||||
unet = accelerator.prepare(unet)
|
unet = accelerator.prepare(unet)
|
||||||
|
|||||||
Reference in New Issue
Block a user