gesen2egee
cdb2d9c516
Update train_network.py
2024-08-04 17:36:34 +08:00
gesen2egee
aa850aa531
Update train_network.py
2024-08-04 17:34:20 +08:00
gesen2egee
f6dbf7c419
Update train_network.py
2024-08-04 15:18:53 +08:00
gesen2egee
a593e837f3
Update train_network.py
2024-08-04 15:17:30 +08:00
gesen2egee
b9bdd10129
Update train_network.py
2024-08-04 15:11:26 +08:00
gesen2egee
31507b9901
Remove unnecessary is_train changes and use apply_debiased_estimation to calculate validation loss. Balances the influence of different time steps on training performance (without affecting actual training results)
2024-08-02 13:15:21 +08:00
gesen2egee
086f6000f2
Merge branch 'main' into val
2024-04-11 01:14:46 +08:00
Kohya S
d30ebb205c
update readme, add metadata for network module
2024-04-07 14:58:17 +09:00
kabachuha
90b18795fc
Add option to use Scheduled Huber Loss in all training pipelines to improve resilience to data corruption ( #1228 )
...
* add huber loss and huber_c compute to train_util
* add reduction modes
* add huber_c retrieval from timestep getter
* move get timesteps and huber to own function
* add conditional loss to all training scripts
* add cond loss to train network
* add (scheduled) huber_loss to args
* fixup twice timesteps getting
* PHL-schedule should depend on noise scheduler's num timesteps
* *2 multiplier to huber loss cause of 1/2 a^2 conv.
The Taylor expansion of sqrt near zero gives 1/2 a^2, which differs from a^2 of the standard MSE loss. This change scales them better against one another
* add option for smooth l1 (huber / delta)
* unify huber scheduling
* add snr huber scheduler
---------
Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com >
2024-04-07 13:54:21 +09:00
ykume
cd587ce62c
verify command line args if wandb is enabled
2024-04-05 08:23:03 +09:00
Kohya S
2258a1b753
add save/load hook to remove U-Net/TEs from state
2024-03-31 15:50:35 +09:00
Kohya S
c86e356013
Merge branch 'dev' into dataset-cache
2024-03-26 19:43:40 +09:00
Kohya S
ab1e389347
Merge branch 'dev' into masked-loss
2024-03-26 19:39:30 +09:00
Kohya S
a2b8531627
make each script consistent, fix to work w/o DeepSpeed
2024-03-25 22:28:46 +09:00
Kohya S
025347214d
refactor metadata caching for DreamBooth dataset
2024-03-24 18:09:32 +09:00
Kohaku-Blueleaf
ae97c8bfd1
[Experimental] Add cache mechanism for dataset groups to avoid long waiting time for initilization ( #1178 )
...
* support meta cached dataset
* add cache meta scripts
* random ip_noise_gamma strength
* random noise_offset strength
* use correct settings for parser
* cache path/caption/size only
* revert mess up commit
* revert mess up commit
* Update requirements.txt
* Add arguments for meta cache.
* remove pickle implementation
* Return sizes when enable cache
---------
Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com >
2024-03-24 15:40:18 +09:00
Kohya S
fbb98f144e
Merge branch 'dev' into deep-speed
2024-03-20 18:15:26 +09:00
Kohya S
9b6b39f204
Merge branch 'dev' into masked-loss
2024-03-20 18:14:36 +09:00
Kohya S
bf6cd4b9da
Merge pull request #1168 from gesen2egee/save_state_on_train_end
...
Save state on train end
2024-03-20 18:02:13 +09:00
Kohya S
3419c3de0d
common masked loss func, apply to all training script
2024-03-17 19:30:20 +09:00
gesen2egee
d05965dbad
Update train_network.py
2024-03-13 18:33:51 +08:00
gesen2egee
5d7ed0dff0
Merge remote-tracking branch 'kohya-ss/dev' into val
2024-03-13 18:00:49 +08:00
gesen2egee
bd7e2295b7
fix
2024-03-13 17:54:21 +08:00
gesen2egee
d282c45002
Update train_network.py
2024-03-11 23:56:09 +08:00
gesen2egee
a6c41c6bea
Update train_network.py
2024-03-11 19:23:48 +08:00
gesen2egee
63e58f78e3
Update train_network.py
2024-03-11 19:15:55 +08:00
gesen2egee
befbec5335
Update train_network.py
2024-03-11 18:47:04 +08:00
gesen2egee
a51723cc2a
fix timesteps
2024-03-11 09:42:58 +08:00
gesen2egee
095b8035e6
save state on train end
2024-03-10 23:33:38 +08:00
gesen2egee
47359b8fac
Update train_network.py
2024-03-10 20:17:40 +08:00
gesen2egee
923b761ce3
Update train_network.py
2024-03-10 20:01:40 +08:00
gesen2egee
78cfb01922
improve
2024-03-10 18:55:48 +08:00
gesen2egee
b558a5b73d
val
2024-03-10 04:37:16 +08:00
Kohya S
e3ccf8fbf7
make deepspeed_utils
2024-02-27 21:30:46 +09:00
Kohya S
eefb3cc1e7
Merge branch 'deep-speed' into deepspeed
2024-02-27 18:57:42 +09:00
Kohya S
4a5546d40e
fix typo
2024-02-26 23:39:56 +09:00
Kohya S
f2c727fc8c
add minimal impl for masked loss
2024-02-26 23:19:58 +09:00
Kohya S
577e9913ca
add some new dataset settings
2024-02-26 20:01:25 +09:00
Kohya S
f4132018c5
fix to work with cpu_count() == 1 closes #1134
2024-02-24 19:25:31 +09:00
BootsofLagrangian
4d5186d1cf
refactored codes, some function moved into train_utils.py
2024-02-22 16:20:53 +09:00
Kohya S
baa0e97ced
Merge branch 'dev' into dev_device_support
2024-02-17 11:54:07 +09:00
Kohya S
93bed60762
fix to work --console_log_xxx options
2024-02-12 14:49:29 +09:00
Kohya S
358ca205a3
Merge branch 'dev' into dev_device_support
2024-02-12 13:01:54 +09:00
Kohya S
e24d9606a2
add clean_memory_on_device and use it from training
2024-02-12 11:10:52 +09:00
Kohya S
055f02e1e1
add logging args for training scripts
2024-02-08 21:16:42 +09:00
BootsofLagrangian
62556619bd
fix full_fp16 compatible and train_step
2024-02-07 16:42:05 +09:00
BootsofLagrangian
7d2a9268b9
apply offloading method runable for all trainer
2024-02-05 22:42:06 +09:00
BootsofLagrangian
4295f91dcd
fix all trainer about vae
2024-02-05 20:19:56 +09:00
Yuta Hayashibe
5f6bf29e52
Replace print with logger if they are logs ( #905 )
...
* Add get_my_logger()
* Use logger instead of print
* Fix log level
* Removed line-breaks for readability
* Use setup_logging()
* Add rich to requirements.txt
* Make simple
* Use logger instead of print
---------
Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com >
2024-02-04 18:14:34 +09:00
BootsofLagrangian
dfe08f395f
support deepspeed
2024-02-04 03:12:42 +09:00