mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
Compare commits
1 Commits
no-grad-bl
...
vae_batch_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dbd835ee4b |
3
.github/FUNDING.yml
vendored
3
.github/FUNDING.yml
vendored
@@ -1,3 +0,0 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: kohya-ss
|
||||
5
.github/workflows/tests.yml
vendored
5
.github/workflows/tests.yml
vendored
@@ -12,9 +12,6 @@ on:
|
||||
- dev
|
||||
- sd3
|
||||
|
||||
# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all"
|
||||
permissions: read-all
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ${{ matrix.os }}
|
||||
@@ -43,7 +40,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
# Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
|
||||
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4
|
||||
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision==0.19.0 pytest==8.3.4
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Test with pytest
|
||||
|
||||
3
.github/workflows/typos.yml
vendored
3
.github/workflows/typos.yml
vendored
@@ -12,9 +12,6 @@ on:
|
||||
- synchronize
|
||||
- reopened
|
||||
|
||||
# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all"
|
||||
permissions: read-all
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
23
README.md
23
README.md
@@ -9,22 +9,11 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
||||
The command to install PyTorch is as follows:
|
||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||
|
||||
If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
|
||||
|
||||
- [FLUX.1 training](#flux1-training)
|
||||
- [SD3 training](#sd3-training)
|
||||
|
||||
### Recent Updates
|
||||
|
||||
May 1, 2025:
|
||||
- The error when training FLUX.1 with mixed precision in flux_train.py with DeepSpeed enabled has been resolved. Thanks to sharlynxy for PR [#2060](https://github.com/kohya-ss/sd-scripts/pull/2060). Please refer to the PR for details.
|
||||
- If you enable DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
|
||||
|
||||
Apr 27, 2025:
|
||||
- FLUX.1 training now supports CFG scale in the sample generation during training. Please use `--g` option, to specify the CFG scale (note that `--l` is used as the embedded guidance scale.) PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064).
|
||||
- See [here](#sample-image-generation-during-training) for details.
|
||||
- If you have any issues with this, please let us know.
|
||||
|
||||
Apr 6, 2025:
|
||||
- IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details.
|
||||
- `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available.
|
||||
@@ -881,14 +870,6 @@ Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
|
||||
|
||||
(Single GPU with id `0` will be used.)
|
||||
|
||||
## DeepSpeed installation (experimental, Linux or WSL2 only)
|
||||
|
||||
To install DeepSpeed, run the following command in your activated virtual environment:
|
||||
|
||||
```bash
|
||||
pip install deepspeed==0.16.7
|
||||
```
|
||||
|
||||
## Upgrade
|
||||
|
||||
When a new release comes out you can upgrade your repo with the following command:
|
||||
@@ -1363,13 +1344,11 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.
|
||||
|
||||
* `--n` Negative prompt up to the next option. Ignored when CFG scale is `1.0`.
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* In guidance distillation models like FLUX.1, this value is used as the embedded guidance scale for backward compatibility.
|
||||
* `--g` Specifies the CFG scale for the models with embedded guidance scale. The default is `1.0`, `1.0` means no CFG. In general, should not be changed unless you train the un-distilled FLUX.1 models.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
|
||||
The prompt weighting such as `( )` and `[ ]` are working.
|
||||
|
||||
@@ -97,19 +97,15 @@ def main(args):
|
||||
else:
|
||||
for file in SUB_DIR_FILES:
|
||||
hf_hub_download(
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
args.repo_id,
|
||||
file,
|
||||
subfolder=SUB_DIR,
|
||||
local_dir=os.path.join(model_location, SUB_DIR),
|
||||
cache_dir=os.path.join(model_location, SUB_DIR),
|
||||
force_download=True,
|
||||
force_filename=file,
|
||||
)
|
||||
for file in files:
|
||||
hf_hub_download(
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
local_dir=model_location,
|
||||
force_download=True,
|
||||
)
|
||||
hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file)
|
||||
else:
|
||||
logger.info("using existing wd14 tagger model")
|
||||
|
||||
|
||||
@@ -42,20 +42,19 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
|
||||
|
||||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||
|
||||
with torch.no_grad():
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
# cuda to cpu
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.record_stream(stream)
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
# cuda to cpu
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.record_stream(stream)
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||
|
||||
stream.synchronize()
|
||||
stream.synchronize()
|
||||
|
||||
# cpu to cuda
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
# cpu to cuda
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
|
||||
stream.synchronize()
|
||||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||
@@ -76,14 +75,14 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||
|
||||
synchronize_device(device)
|
||||
synchronize_device()
|
||||
|
||||
# cpu to device
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
|
||||
synchronize_device(device)
|
||||
synchronize_device()
|
||||
|
||||
|
||||
def weighs_to_device(layer: nn.Module, device: torch.device):
|
||||
|
||||
@@ -5,8 +5,6 @@ from accelerate import DeepSpeedPlugin, Accelerator
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
from .device_utils import get_preferred_device
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
@@ -96,7 +94,6 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
|
||||
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
|
||||
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
|
||||
)
|
||||
|
||||
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
||||
if args.mixed_precision.lower() == "fp16":
|
||||
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
|
||||
@@ -125,56 +122,18 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
||||
class DeepSpeedWrapper(torch.nn.Module):
|
||||
def __init__(self, **kw_models) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.models = torch.nn.ModuleDict()
|
||||
|
||||
wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no"
|
||||
|
||||
for key, model in kw_models.items():
|
||||
if isinstance(model, list):
|
||||
model = torch.nn.ModuleList(model)
|
||||
|
||||
if wrap_model_forward_with_torch_autocast:
|
||||
model = self.__wrap_model_with_torch_autocast(model)
|
||||
|
||||
assert isinstance(
|
||||
model, torch.nn.Module
|
||||
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
||||
|
||||
self.models.update(torch.nn.ModuleDict({key: model}))
|
||||
|
||||
def __wrap_model_with_torch_autocast(self, model):
|
||||
if isinstance(model, torch.nn.ModuleList):
|
||||
model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model])
|
||||
else:
|
||||
model = self.__wrap_model_forward_with_torch_autocast(model)
|
||||
return model
|
||||
|
||||
def __wrap_model_forward_with_torch_autocast(self, model):
|
||||
|
||||
assert hasattr(model, "forward"), f"model must have a forward method."
|
||||
|
||||
forward_fn = model.forward
|
||||
|
||||
def forward(*args, **kwargs):
|
||||
try:
|
||||
device_type = model.device.type
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"[DeepSpeed] model.device is not available. Using get_preferred_device() "
|
||||
"to determine the device_type for torch.autocast()."
|
||||
)
|
||||
device_type = get_preferred_device().type
|
||||
|
||||
with torch.autocast(device_type = device_type):
|
||||
return forward_fn(*args, **kwargs)
|
||||
|
||||
model.forward = forward
|
||||
return model
|
||||
|
||||
def get_models(self):
|
||||
return self.models
|
||||
|
||||
|
||||
ds_model = DeepSpeedWrapper(**models)
|
||||
return ds_model
|
||||
|
||||
@@ -40,7 +40,7 @@ def sample_images(
|
||||
text_encoders,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement=None,
|
||||
controlnet=None,
|
||||
controlnet=None
|
||||
):
|
||||
if steps == 0:
|
||||
if not args.sample_at_first:
|
||||
@@ -101,7 +101,7 @@ def sample_images(
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
controlnet,
|
||||
controlnet
|
||||
)
|
||||
else:
|
||||
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
||||
@@ -125,7 +125,7 @@ def sample_images(
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
controlnet,
|
||||
controlnet
|
||||
)
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
@@ -147,16 +147,14 @@ def sample_image_inference(
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
controlnet,
|
||||
controlnet
|
||||
):
|
||||
assert isinstance(prompt_dict, dict)
|
||||
negative_prompt = prompt_dict.get("negative_prompt")
|
||||
# negative_prompt = prompt_dict.get("negative_prompt")
|
||||
sample_steps = prompt_dict.get("sample_steps", 20)
|
||||
width = prompt_dict.get("width", 512)
|
||||
height = prompt_dict.get("height", 512)
|
||||
# TODO refactor variable names
|
||||
cfg_scale = prompt_dict.get("guidance_scale", 1.0)
|
||||
emb_guidance_scale = prompt_dict.get("scale", 3.5)
|
||||
scale = prompt_dict.get("scale", 3.5)
|
||||
seed = prompt_dict.get("seed")
|
||||
controlnet_image = prompt_dict.get("controlnet_image")
|
||||
prompt: str = prompt_dict.get("prompt", "")
|
||||
@@ -164,8 +162,8 @@ def sample_image_inference(
|
||||
|
||||
if prompt_replacement is not None:
|
||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
if negative_prompt is not None:
|
||||
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
# if negative_prompt is not None:
|
||||
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
@@ -175,21 +173,16 @@ def sample_image_inference(
|
||||
torch.seed()
|
||||
torch.cuda.seed()
|
||||
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ""
|
||||
# if negative_prompt is None:
|
||||
# negative_prompt = ""
|
||||
height = max(64, height - height % 16) # round to divisible by 16
|
||||
width = max(64, width - width % 16) # round to divisible by 16
|
||||
logger.info(f"prompt: {prompt}")
|
||||
if cfg_scale != 1.0:
|
||||
logger.info(f"negative_prompt: {negative_prompt}")
|
||||
elif negative_prompt != "":
|
||||
logger.info(f"negative prompt is ignored because scale is 1.0")
|
||||
# logger.info(f"negative_prompt: {negative_prompt}")
|
||||
logger.info(f"height: {height}")
|
||||
logger.info(f"width: {width}")
|
||||
logger.info(f"sample_steps: {sample_steps}")
|
||||
logger.info(f"embedded guidance scale: {emb_guidance_scale}")
|
||||
if cfg_scale != 1.0:
|
||||
logger.info(f"CFG scale: {cfg_scale}")
|
||||
logger.info(f"scale: {scale}")
|
||||
# logger.info(f"sample_sampler: {sampler_name}")
|
||||
if seed is not None:
|
||||
logger.info(f"seed: {seed}")
|
||||
@@ -198,37 +191,26 @@ def sample_image_inference(
|
||||
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
def encode_prompt(prpt):
|
||||
text_encoder_conds = []
|
||||
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
|
||||
text_encoder_conds = sample_prompts_te_outputs[prpt]
|
||||
print(f"Using cached text encoder outputs for prompt: {prpt}")
|
||||
if text_encoders is not None:
|
||||
print(f"Encoding prompt: {prpt}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prpt)
|
||||
# strategy has apply_t5_attn_mask option
|
||||
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||
text_encoder_conds = []
|
||||
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
|
||||
text_encoder_conds = sample_prompts_te_outputs[prompt]
|
||||
print(f"Using cached text encoder outputs for prompt: {prompt}")
|
||||
if text_encoders is not None:
|
||||
print(f"Encoding prompt: {prompt}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
# strategy has apply_t5_attn_mask option
|
||||
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
text_encoder_conds = encoded_text_encoder_conds
|
||||
else:
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
return text_encoder_conds
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
text_encoder_conds = encoded_text_encoder_conds
|
||||
else:
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt)
|
||||
# encode negative prompts
|
||||
if cfg_scale != 1.0:
|
||||
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt)
|
||||
neg_t5_attn_mask = (
|
||||
neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None
|
||||
)
|
||||
neg_cond = (cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask)
|
||||
else:
|
||||
neg_cond = None
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
|
||||
# sample image
|
||||
weight_dtype = ae.dtype # TOFO give dtype as argument
|
||||
@@ -253,20 +235,7 @@ def sample_image_inference(
|
||||
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
|
||||
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
x = denoise(
|
||||
flux,
|
||||
noise,
|
||||
img_ids,
|
||||
t5_out,
|
||||
txt_ids,
|
||||
l_pooled,
|
||||
timesteps=timesteps,
|
||||
guidance=emb_guidance_scale,
|
||||
t5_attn_mask=t5_attn_mask,
|
||||
controlnet=controlnet,
|
||||
controlnet_img=controlnet_image,
|
||||
neg_cond=neg_cond,
|
||||
)
|
||||
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
|
||||
|
||||
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
||||
|
||||
@@ -336,24 +305,22 @@ def denoise(
|
||||
model: flux_models.Flux,
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
txt: torch.Tensor, # t5_out
|
||||
txt: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
vec: torch.Tensor, # l_pooled
|
||||
vec: torch.Tensor,
|
||||
timesteps: list[float],
|
||||
guidance: float = 4.0,
|
||||
t5_attn_mask: Optional[torch.Tensor] = None,
|
||||
controlnet: Optional[flux_models.ControlNetFlux] = None,
|
||||
controlnet_img: Optional[torch.Tensor] = None,
|
||||
neg_cond: Optional[Tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
do_cfg = neg_cond is not None
|
||||
|
||||
|
||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
model.prepare_block_swap_before_forward()
|
||||
|
||||
if controlnet is not None:
|
||||
block_samples, block_single_samples = controlnet(
|
||||
img=img,
|
||||
@@ -369,48 +336,20 @@ def denoise(
|
||||
else:
|
||||
block_samples = None
|
||||
block_single_samples = None
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
block_controlnet_hidden_states=block_samples,
|
||||
block_controlnet_single_hidden_states=block_single_samples,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
if not do_cfg:
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
block_controlnet_hidden_states=block_samples,
|
||||
block_controlnet_single_hidden_states=block_single_samples,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
else:
|
||||
cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask = neg_cond
|
||||
nc_c_t5_attn_mask = None if t5_attn_mask is None else torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
|
||||
|
||||
# TODO is it ok to use the same block samples for both cond and uncond?
|
||||
block_samples = None if block_samples is None else torch.cat([block_samples, block_samples], dim=0)
|
||||
block_single_samples = (
|
||||
None if block_single_samples is None else torch.cat([block_single_samples, block_single_samples], dim=0)
|
||||
)
|
||||
|
||||
nc_c_pred = model(
|
||||
img=torch.cat([img, img], dim=0),
|
||||
img_ids=torch.cat([img_ids, img_ids], dim=0),
|
||||
txt=torch.cat([neg_t5_out, txt], dim=0),
|
||||
txt_ids=torch.cat([txt_ids, txt_ids], dim=0),
|
||||
y=torch.cat([neg_l_pooled, vec], dim=0),
|
||||
block_controlnet_hidden_states=block_samples,
|
||||
block_controlnet_single_hidden_states=block_single_samples,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=nc_c_t5_attn_mask,
|
||||
)
|
||||
neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0)
|
||||
pred = neg_pred + (pred - neg_pred) * cfg_scale
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
model.prepare_block_swap_before_forward()
|
||||
return img
|
||||
@@ -494,7 +433,7 @@ def get_noisy_model_input_and_timesteps(
|
||||
sigmas = torch.randn(bsz, device=device)
|
||||
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigmas = sigmas.sigmoid()
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
||||
sigmas = time_shift(mu, 1.0, sigmas)
|
||||
timesteps = sigmas * num_timesteps
|
||||
else:
|
||||
@@ -519,7 +458,7 @@ def get_noisy_model_input_and_timesteps(
|
||||
if args.ip_noise_gamma:
|
||||
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
|
||||
if args.ip_noise_gamma_random_strength:
|
||||
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
|
||||
ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma)
|
||||
else:
|
||||
ip_noise_gamma = args.ip_noise_gamma
|
||||
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
|
||||
@@ -630,7 +569,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
"--controlnet_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)",
|
||||
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t5xxl_max_token_length",
|
||||
|
||||
@@ -1060,11 +1060,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
|
||||
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
|
||||
|
||||
if len(img_ar_errors) == 0:
|
||||
mean_img_ar_error = 0 # avoid NaN
|
||||
else:
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
||||
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
||||
logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")
|
||||
|
||||
@@ -5498,11 +5495,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
|
||||
|
||||
def patch_accelerator_for_fp16_training(accelerator):
|
||||
|
||||
from accelerate import DistributedType
|
||||
if accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
return
|
||||
|
||||
org_unscale_grads = accelerator.scaler._unscale_grads_
|
||||
|
||||
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
||||
@@ -6186,11 +6178,6 @@ def line_to_prompt_dict(line: str) -> dict:
|
||||
prompt_dict["scale"] = float(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"g ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # guidance scale
|
||||
prompt_dict["guidance_scale"] = float(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
prompt_dict["negative_prompt"] = m.group(1)
|
||||
|
||||
@@ -955,26 +955,26 @@ class LoRANetwork(torch.nn.Module):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.update_grad_norms()
|
||||
|
||||
def grad_norms(self) -> Tensor | None:
|
||||
def grad_norms(self) -> Tensor:
|
||||
grad_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
|
||||
grad_norms.append(lora.grad_norms.mean(dim=0))
|
||||
return torch.stack(grad_norms) if len(grad_norms) > 0 else None
|
||||
return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([])
|
||||
|
||||
def weight_norms(self) -> Tensor | None:
|
||||
def weight_norms(self) -> Tensor:
|
||||
weight_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
|
||||
weight_norms.append(lora.weight_norms.mean(dim=0))
|
||||
return torch.stack(weight_norms) if len(weight_norms) > 0 else None
|
||||
return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([])
|
||||
|
||||
def combined_weight_norms(self) -> Tensor | None:
|
||||
def combined_weight_norms(self) -> Tensor:
|
||||
combined_weight_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
|
||||
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
|
||||
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None
|
||||
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
|
||||
|
||||
|
||||
def load_weights(self, file):
|
||||
|
||||
@@ -6,4 +6,3 @@ filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::UserWarning
|
||||
ignore::FutureWarning
|
||||
pythonpath = .
|
||||
|
||||
@@ -640,14 +640,23 @@ def train(args):
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
|
||||
else:
|
||||
chunks = [
|
||||
batch["images"][i : i + args.vae_batch_size]
|
||||
for i in range(0, len(batch["images"]), args.vae_batch_size)
|
||||
]
|
||||
list_latents = []
|
||||
for chunk in chunks:
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
list_latents.append(
|
||||
vae.encode(chunk.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
|
||||
)
|
||||
latents = torch.cat(list_latents, dim=0)
|
||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
|
||||
@@ -1443,13 +1443,11 @@ class NetworkTrainer:
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
else:
|
||||
if hasattr(network, "weight_norms"):
|
||||
mean_norm = network.weight_norms().mean().item()
|
||||
mean_grad_norm = network.grad_norms().mean().item()
|
||||
mean_combined_norm = network.combined_weight_norms().mean().item()
|
||||
weight_norms = network.weight_norms()
|
||||
mean_norm = weight_norms.mean().item() if weight_norms is not None else None
|
||||
grad_norms = network.grad_norms()
|
||||
mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
|
||||
combined_weight_norms = network.combined_weight_norms()
|
||||
mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
|
||||
maximum_norm = weight_norms.max().item() if weight_norms is not None else None
|
||||
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
|
||||
keys_scaled = None
|
||||
max_mean_logs = {}
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user