mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Merge branch 'sdxl' of https://github.com/kohya-ss/sd-scripts into sdxl
This commit is contained in:
@@ -33,6 +33,7 @@ Summary of the feature:
|
||||
- The image generation during training is now available. However, the VAE for SDXL seems to produce NaNs in some cases when using `fp16`. The images will be black. Currently, the NaNs cannot be avoided even with `--no_half_vae` option. It works with `bf16` or without mixed precision.
|
||||
- `--weighted_captions` option is not supported yet.
|
||||
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
|
||||
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA. See the help message for the usage.
|
||||
|
||||
`requirements.txt` is updated to support SDXL training.
|
||||
|
||||
@@ -47,6 +48,7 @@ Summary of the feature:
|
||||
- The LoRA training can be done with 12GB GPU memory.
|
||||
- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of training them will be unexpected.
|
||||
- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1.
|
||||
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
|
||||
|
||||
Example of the optimizer settings for Adafactor with a fixed learning rate:
|
||||
```
|
||||
@@ -57,6 +59,12 @@ lr_warmup_steps = 100
|
||||
learning_rate = 4e-7 # SDXL original learning rate
|
||||
```
|
||||
|
||||
### TODO
|
||||
|
||||
- [ ] Support Textual Inversion training.
|
||||
- [ ] Support `--weighted_captions` option.
|
||||
- [ ] Change `--output_config` option to continue the training.
|
||||
|
||||
## About requirements.txt
|
||||
|
||||
These files do not contain requirements for PyTorch, because the required version depends on your environment. Please install PyTorch first (see the installation guide below).
|
||||
|
||||
@@ -182,7 +182,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
||||
|
||||
converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
|
||||
info2 = text_model2.load_state_dict(converted_sd)
|
||||
print("text encoder2:", info2)
|
||||
print("text encoder 2:", info2)
|
||||
|
||||
# prepare vae
|
||||
print("building VAE")
|
||||
|
||||
@@ -98,13 +98,16 @@ class WrapperTokenizer:
|
||||
return SimpleNamespace(**{"input_ids": input_ids})
|
||||
|
||||
# for weighted prompt
|
||||
input_ids = open_clip.tokenize(text, context_length=self.model_max_length)
|
||||
assert isinstance(text, str), f"input must be str: {text}"
|
||||
|
||||
input_ids = open_clip.tokenize(text, context_length=self.model_max_length)[0] # tokenizer returns list
|
||||
|
||||
# find eos
|
||||
eos_index = (input_ids == self.eos_token_id).nonzero()[0].max() # max index of each batch
|
||||
input_ids = input_ids[:, : eos_index + 1] # include eos
|
||||
eos_index = (input_ids == self.eos_token_id).nonzero().max()
|
||||
input_ids = input_ids[: eos_index + 1] # include eos
|
||||
return SimpleNamespace(**{"input_ids": input_ids})
|
||||
|
||||
|
||||
def load_tokenizers(args: argparse.Namespace):
|
||||
print("prepare tokenizers")
|
||||
original_path = TOKENIZER_PATH
|
||||
|
||||
2547
sdxl_gen_img.py
Normal file
2547
sdxl_gen_img.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user