Commit by Kohya S, 2023-07-09 14:12:05 +09:00
4 changed files with 2562 additions and 4 deletions


@@ -33,6 +33,7 @@ Summary of the feature:
- Image generation during training is now available. However, the SDXL VAE seems to produce NaNs in some cases when using `fp16`, and the resulting images will be black. Currently the NaNs cannot be avoided even with the `--no_half_vae` option. It works with `bf16` or without mixed precision.
- `--weighted_captions` option is not supported yet.
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train the U-Net on a restricted range of timesteps. The default values are 0 and 1000.
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA. See the help message for the usage.
`requirements.txt` is updated to support SDXL training.
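As a rough sketch (not the training scripts' actual implementation), `--min_timestep` and `--max_timestep` would bound the range from which noise timesteps are drawn at each training step:

```python
import torch

# Hypothetical illustration of how --min_timestep / --max_timestep restrict
# the sampled noise timesteps (defaults: 0 and 1000). Variable names are
# illustrative, not taken from the repository.
min_timestep, max_timestep = 200, 800
batch_size = 4
timesteps = torch.randint(min_timestep, max_timestep, (batch_size,), dtype=torch.long)
```

With the defaults (0 and 1000), this reduces to the usual full-range timestep sampling.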
@@ -47,6 +48,7 @@ Summary of the feature:
- The LoRA training can be done with 12GB GPU memory.
- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, training them is likely to produce unexpected results.
- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1.
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Values smaller than 32 will not work for SDXL training.
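To illustrate why `--bucket_reso_steps` matters, here is a hypothetical sketch (not the repository's bucketing code) of how the step size controls how many candidate bucket widths exist between a minimum and maximum resolution:

```python
# Hypothetical helper, for illustration only: candidate bucket widths between
# min_w and max_w, quantized to multiples of `step`.
def bucket_widths(min_w=512, max_w=1536, step=64):
    return list(range(min_w, max_w + 1, step))

# Halving the step from 64 to 32 roughly doubles the number of buckets,
# so images fit a nearby bucket with less cropping/resizing.
coarse = bucket_widths(step=64)  # 17 candidate widths
fine = bucket_widths(step=32)    # 33 candidate widths
```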
Example of the optimizer settings for Adafactor with the fixed learning rate:
```
@@ -57,6 +59,12 @@ lr_warmup_steps = 100
learning_rate = 4e-7 # SDXL original learning rate
```
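Assuming the warmup scheduler implied by `lr_warmup_steps` above (a constant-with-warmup schedule; the function below is illustrative, not code from the repository), the learning rate ramps linearly from 0 to `learning_rate` over the warmup steps and then stays constant:

```python
# Illustrative only: a linear-warmup-then-constant schedule consistent with
# lr_warmup_steps = 100 and learning_rate = 4e-7 in the settings above.
def lr_at(step, base_lr=4e-7, warmup_steps=100):
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr
```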
### TODO
- [ ] Support Textual Inversion training.
- [ ] Support `--weighted_captions` option.
- [ ] Change `--output_config` option to continue the training.
## About requirements.txt
These files do not include the requirements for PyTorch, because the appropriate versions depend on your environment. Please install PyTorch first (see the installation guide below).


@@ -182,7 +182,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
info2 = text_model2.load_state_dict(converted_sd)
-print("text encoder2:", info2)
+print("text encoder 2:", info2)
# prepare vae
print("building VAE")


@@ -98,13 +98,16 @@ class WrapperTokenizer:
return SimpleNamespace(**{"input_ids": input_ids})
# for weighted prompt
-input_ids = open_clip.tokenize(text, context_length=self.model_max_length)
+assert isinstance(text, str), f"input must be str: {text}"
+input_ids = open_clip.tokenize(text, context_length=self.model_max_length)[0]  # tokenizer returns list
# find eos
-eos_index = (input_ids == self.eos_token_id).nonzero()[0].max()  # max index of each batch
-input_ids = input_ids[:, : eos_index + 1]  # include eos
+eos_index = (input_ids == self.eos_token_id).nonzero().max()
+input_ids = input_ids[: eos_index + 1]  # include eos
return SimpleNamespace(**{"input_ids": input_ids})
def load_tokenizers(args: argparse.Namespace):
print("prepare tokenizers")
original_path = TOKENIZER_PATH
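The EOS-trimming logic changed in the hunk above can be reproduced in isolation. A minimal sketch with a hand-made token sequence (the IDs follow CLIP's conventions but are used here only for illustration):

```python
import torch

eos_token_id = 49407  # CLIP/open_clip EOT id, for illustration
# BOS, two tokens, EOS, then zero padding:
input_ids = torch.tensor([49406, 320, 1125, 49407, 0, 0])

# Same pattern as the new single-sequence code: find the EOS position,
# then keep everything up to and including it.
eos_index = (input_ids == eos_token_id).nonzero().max()
input_ids = input_ids[: eos_index + 1]
```

Because the tokenizer now returns a single 1-D sequence (`[0]` after `open_clip.tokenize`), the slice no longer needs the batch dimension the old `input_ids[:, : eos_index + 1]` assumed.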

sdxl_gen_img.py: new file, 2547 additions (diff suppressed because it is too large).