Distributed PSGD Kron

For the original PSGD repo and some great resources, see psgd_torch.

Background: Implementation of PSGD Kron in JAX (optax-style) for distributed training. PSGD is a second-order optimizer originally created by Xi-Lin Li and further developed by Omead Pooladzandi that uses either a Hessian-based or whitening-based (gg^T) preconditioner, Lie groups, and online preconditioner updating to improve training convergence, generalization, and efficiency. I highly suggest taking a look at Xi-Lin's PSGD repo linked above for interesting details on how PSGD works and for experiments using PSGD. There are also resources listed near the bottom of this README.

distributed_kron:

The most versatile and easy-to-use PSGD optimizer is kron, which uses Kronecker-factored preconditioners. It has fewer hyperparameters that need tuning than adam, and can generally act as a drop-in replacement.

Distributed kron is a version of kron meant for large scale distributed training in JAX. It uses merging of dimensions, vmapping of layers, partitioning of grads, and sharding constraints to allow for easy and efficient second-order training of large models.

Installation

pip install distributed-kron

Basic Usage

FYI: Kron schedules the preconditioner update probability by default to start at 1.0 and anneal to 0.03 during the first 4k steps, so training will be slightly slower at the start but will speed up by around 4k steps.

Learning Rate: Kron usually likes a learning rate around 3x smaller than adam's. Kron does not share adam's implicit warmup, so a longer warmup schedule may be beneficial if divergence is seen early in training.

For basic usage, use distributed_kron.kron like any other optax optimizer:

import optax
from distributed_kron import kron

optimizer = kron()
opt_state = optimizer.init(params)

updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
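
As with any optax optimizer, kron's update can live inside a jitted train step. Below is a minimal self-contained sketch; the toy loss_fn, params, and batch are placeholders, not part of this library:

import jax
import jax.numpy as jnp
import optax
from distributed_kron import kron

def loss_fn(params, batch):
    # toy linear regression loss, purely for illustration
    preds = batch["x"] @ params["w"] + params["b"]
    return jnp.mean((preds - batch["y"]) ** 2)

optimizer = kron(learning_rate=0.0003)

@jax.jit
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # passing params is optional in optax's update signature (used e.g. for weight decay)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

params = {"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))}
opt_state = optimizer.init(params)
batch = {"x": jnp.ones((8, 4)), "y": jnp.ones((8, 1))}
params, opt_state, loss = train_step(params, opt_state, batch)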

Advanced Usage

For advanced usage, see the kron_example.py file.

The main thing to note is that your workflow should include passing your params' partition specs into kron through params_partition_specs, which are used for internal sharding constraints. It is also best to explicitly set preconditioner partition specs using preconditioner_partition_spec (see the hyperparameters section below).

get_opt_state_partition_specs:

If you need it, there is a helper function, get_opt_state_partition_specs, that returns partition specs for the optimizer state:

from jax.sharding import PartitionSpec as P
from distributed_kron import get_opt_state_partition_specs

kron_kwargs = dict(
    learning_rate=0.0003,
    weight_decay=0.01,
    scanned_layers=scanned_layers_pytree,
    params_partition_specs=params_partition_specs,
    preconditioner_partition_spec=P("fsdp", None),
)

optimizer = kron(**kron_kwargs)

opt_state_partition_specs = get_opt_state_partition_specs(
    params=train_state_shapes["params"], scale_by_kron_only=False, **kron_kwargs  # pass in kwargs
)
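
The returned pytree of PartitionSpecs can then be wrapped in shardings and used, for example, as out_shardings when jit-compiling the optimizer init. A sketch, assuming you already have a jax.sharding.Mesh called mesh and real params to initialize from (both are your own objects, not provided by this library):

import jax
from jax.sharding import NamedSharding, PartitionSpec

# wrap each spec in a NamedSharding on your mesh
# (assumes the helper returns a PartitionSpec for every leaf)
opt_state_shardings = jax.tree.map(
    lambda spec: NamedSharding(mesh, spec),
    opt_state_partition_specs,
    is_leaf=lambda x: isinstance(x, PartitionSpec),
)

# initialize the optimizer state directly into the desired sharding
init_fn = jax.jit(optimizer.init, out_shardings=opt_state_shardings)
opt_state = init_fn(params)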

Hyperparameters

learning_rate: Kron usually likes a learning rate around 3x smaller than adam's. Kron does not share adam's implicit warmup, so a longer warmup schedule may be beneficial if divergence is seen early in training.

Preconditioner settings:

Preconditioner structure: For a layer with shape (256, 128), default triangular preconditioners would be shapes (256, 256) and (128, 128). However, with the following options we can also choose to make some or all of these preconditioners diagonal, which would be shapes (256,) and (128,).

Depending on how the following settings are chosen, kron can balance between memory/speed and effectiveness. The defaults lead to most preconditioners being triangular except for 1-dimensional layers and very large dimensions.

max_size_triangular: Any dimension with size above this value will have a diagonal preconditioner.

min_ndim_triangular: Any tensor with fewer than this number of dimensions will have all-diagonal preconditioners. Default is 2, so single-dim layers like bias and scale use diagonal preconditioners.

memory_save_mode: Can be None, 'one_diag', or 'all_diag'. None is default and lets all preconditioners be triangular. 'one_diag' sets the largest or last dim per layer as diagonal using np.argsort(shape)[::-1][0]. 'all_diag' sets all preconditioners to be diagonal.
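
For example, the three settings above can be combined to trade preconditioner strength for memory; a short sketch using the hyperparameter names described above (the values shown are illustrative, not the defaults):

from distributed_kron import kron

optimizer = kron(
    learning_rate=0.0003,
    # any dim larger than this (e.g. a vocab dim) gets a diagonal preconditioner
    max_size_triangular=8192,
    # 1-D tensors like biases and scales get all-diagonal preconditioners
    min_ndim_triangular=2,
    # None keeps triangular where allowed; 'one_diag' or 'all_diag' save more memory
    memory_save_mode=None,
)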

Preconditioner update frequency:

PSGD generally benefits from more preconditioner updates at the start of training, but once the preconditioner is learned it's okay to do them less often.

preconditioner_update_probability: Kron schedules preconditioner update probability by default using a schedule that works well for most cases. It anneals from 1 to 0.03 at the beginning of training, so training will be slightly slower at the start but will speed up by around 4k steps.

An easy way to adjust update frequency is to pass in your own precond_update_prob_schedule function to kron's preconditioner_update_probability hyperparameter:

from distributed_kron import kron, precond_update_prob_schedule

optimizer = kron(
    preconditioner_update_probability=precond_update_prob_schedule(
        # anneal to updating the preconditioner every ~20 steps on average (1/0.05)
        min_prob=0.05,  # default is 0.03
        # update the preconditioner every step for the first 1000 steps before annealing
        flat_start=1000,  # default is 500
    )
)

This is the default schedule defined by precond_update_prob_schedule:

[Plot: the default preconditioner update probability schedule]
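
Roughly, the default stays at 1.0 for flat_start steps and then decays exponentially to min_prob, finishing around 4k steps. The sketch below mimics that shape; the decay constant is an assumption chosen to match the ~4k-step anneal, not necessarily the library's exact value:

import jax.numpy as jnp

def sketch_update_prob(step, max_prob=1.0, min_prob=0.03, decay=1e-3, flat_start=500):
    # illustrative only: flat at max_prob, then exponential decay down to min_prob
    prob = max_prob * jnp.exp(-decay * jnp.maximum(step - flat_start, 0))
    return jnp.clip(prob, min_prob, max_prob)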


Sharding:

If you are sharding your params, pass your params' PartitionSpecs into kron through the params_partition_specs hyperparameter. This will be used for internal sharding constraints.

To shard preconditioners, pass a PartitionSpec into the preconditioner_partition_spec hyperparameter. Best practice is to set this to something like P('fsdp', None) or P('fsdp', 'tp'). If params_partition_specs is set but preconditioner_partition_spec is None, a so-so preconditioner sharding strategy will be inferred from params_partition_specs.
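
Putting the two together, one way to obtain params_partition_specs is to read the specs off of params that are already placed with NamedSharding; a sketch (how you shard your params in the first place is up to you, and the axis names are just the examples from above):

import jax
from jax.sharding import PartitionSpec as P
from distributed_kron import kron

# read each param's PartitionSpec from its NamedSharding
params_partition_specs = jax.tree.map(lambda p: p.sharding.spec, params)

optimizer = kron(
    params_partition_specs=params_partition_specs,
    preconditioner_partition_spec=P("fsdp", None),  # or P("fsdp", "tp")
)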

Scanned layers:

If you are scanning layers in your network, kron can also scan over those arrays internally. Pass in a pytree with the same structure as your params, with True values indicating scanned arrays and False values indicating non-scanned arrays, through the scanned_layers hyperparameter. PSGD will vmap over the first dims of those layers. If you need a more advanced scanning setup, please open an issue.

Scan instead of vmap: For very large models, the preconditioner update may use too much memory all at once when vmapping over layers, in which case you can set lax_map_scanned_layers to True and set lax_map_batch_size to a reasonable batch size for your setup (lax.map scans over batches of vmap, see the JAX docs). For example, if your net has 32 scanned layers and you're hitting OOM during the optimizer step, setting lax_map_batch_size to 16 or 8 breaks the update into 2 or 4 batches, respectively.
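
A sketch of wiring this up, assuming the stacked block params live under a key containing "blocks" (the key name and boolean rule here are placeholders for your own model structure):

import jax
from distributed_kron import kron

# scanned_layers mirrors params with booleans: True for arrays that carry a
# leading scan/stack dimension, False for everything else
scanned_layers = jax.tree_util.tree_map_with_path(
    lambda path, _: "blocks" in jax.tree_util.keystr(path), params
)

optimizer = kron(
    scanned_layers=scanned_layers,
    # if vmapping the update over all stacked layers at once OOMs, fall back to
    # lax.map over batches of the vmapped update
    lax_map_scanned_layers=True,
    lax_map_batch_size=8,
)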


For more hyperparameter info, please see kron's docstring.

Resources

PSGD papers and resources, as listed in Xi-Lin's repo:

  1. Xi-Lin Li. Preconditioned stochastic gradient descent, arXiv:1512.04202, 2015. (General ideas of PSGD, preconditioner fitting losses and Kronecker product preconditioners.)

  2. Xi-Lin Li. Preconditioner on matrix Lie group for SGD, arXiv:1809.10232, 2018. (Focus on preconditioners with the affine Lie group.)

  3. Xi-Lin Li. Black box Lie group preconditioners for SGD, arXiv:2211.04422, 2022. (Mainly about the LRA preconditioner. See these supplementary materials for detailed math derivations.)

  4. Xi-Lin Li. Stochastic Hessian fittings on Lie groups, arXiv:2402.11858, 2024. (Some theoretical works on the efficiency of PSGD. The Hessian fitting problem is shown to be strongly convex on set ${\rm GL}(n, \mathbb{R})/R_{\rm polar}$.)

  5. Omead Pooladzandi, Xi-Lin Li. Curvature-informed SGD via general purpose Lie-group preconditioners, arXiv:2402.04553, 2024. (Plenty of benchmark results and analyses for PSGD vs. other optimizers.)

License

CC BY 4.0

This work is licensed under a Creative Commons Attribution 4.0 International License.

2024 Evan Walters, Omead Pooladzandi, Xi-Lin Li
