
Tags: qkrclrl701/neural-tangents

v0.3.6

bump version

PiperOrigin-RevId: 353442082

v0.3.5

Empirical NTK speedup.

Allow `vmap`ping over the batch axis in `empirical_ntk_fn`. This follows from the observation that `d(vmap_x(f))/dp (p, x) == vmap_x(df/dp)(p, x)`, and that most common neural networks are effectively `vmap`s over their batch axis. In experiments this gives roughly a 2-260X speedup, notably by allowing larger batches in the direct method. For small batch sizes this should have no effect.
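A minimal JAX sketch (not part of the commit) illustrating this identity on a toy linear model; the model and shapes here are made up for illustration:

```python
import jax
import jax.numpy as jnp

def f(p, x):
  # Toy single-example "network": `p` is a (3, 2) weight matrix,
  # `x` is a single input of shape (3,), the output has shape (2,).
  return x @ p

p = jnp.ones((3, 2))
xs = jnp.arange(12.).reshape((4, 3))  # a batch of 4 inputs

# Jacobian w.r.t. params of the batched (vmapped) function ...
jac_of_vmap = jax.jacobian(jax.vmap(f, in_axes=(None, 0)))(p, xs)
# ... equals the per-example Jacobian, vmapped over the batch axis.
vmap_of_jac = jax.vmap(jax.jacobian(f), in_axes=(None, 0))(p, xs)

assert jnp.allclose(jac_of_vmap, vmap_of_jac)  # both have shape (4, 2, 3, 2)
```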

Further, fuse `nt.empirical_implicit_ntk_fn` and `nt.empirical_direct_ntk_fn` into a single `nt.empirical_ntk_fn` that now accepts an `implementation=1/2` argument. `nt.empirical_kernel_fn` and `nt.monte_carlo_kernel_fn` also accept this argument. This is a breaking change if you were using `nt.empirical_direct_ntk_fn` (now `nt.empirical_ntk_fn(..., implementation=1)`).
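A hedged usage sketch of the fused entry point. The `stax` constructors and the `ntk_fn(x1, x2, params)` calling convention follow the existing neural-tangents API; the network, shapes, and random keys are arbitrary choices for illustration.

```python
import jax
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

key = jax.random.PRNGKey(0)
_, params = init_fn(key, (-1, 32))

x1 = jax.random.normal(jax.random.PRNGKey(1), (8, 32))
x2 = jax.random.normal(jax.random.PRNGKey(2), (4, 32))

# Before this release: nt.empirical_direct_ntk_fn(apply_fn)
# Now: the direct (Jacobian-contraction) method is implementation=1.
ntk_fn = nt.empirical_ntk_fn(apply_fn, implementation=1)
ntk = ntk_fn(x1, x2, params)  # empirical NTK between x1 and x2
```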

Implementation-wise, I believe this gives the following speedups:

1) In `nt.empirical_direct_ntk_fn` (now `nt.empirical_ntk_fn(..., implementation=1)`), an O(batch_size_1) time/memory improvement when constructing the Jacobian (followed by the contraction, which is unchanged). I believe the most notable benefit here is the larger batch size possible when constructing the Jacobian.

2) In `nt.empirical_implicit_ntk_fn` (now `nt.empirical_ntk_fn(..., implementation=2)`), the same O(batch_size_1) time/memory improvement, but in practice it seems to give only about a 2X speedup, since this method gains no memory efficiency and remains O(batch_size_1 * batch_size_2 * #params).
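Continuing the sketch above (reusing `apply_fn`, `params`, `x1`, `x2`), the two implementations compute the same NTK; the choice is purely the speed/memory trade-off described in 1) and 2):

```python
import jax.numpy as jnp

# Direct method: materializes the Jacobian, then contracts it.
ntk_direct = nt.empirical_ntk_fn(apply_fn, implementation=1)(x1, x2, params)
# Implicit method: per the note above, its cost remains
# O(batch_size_1 * batch_size_2 * #params).
ntk_implicit = nt.empirical_ntk_fn(apply_fn, implementation=2)(x1, x2, params)

assert jnp.allclose(ntk_direct, ntk_implicit, atol=1e-3)  # same kernel, up to float error
```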

This is inspired by discussion with schsam@ and google#30, but I'm not entirely sure how it relates to the layer-wise Jacobians idea.

Also:
- make the direct method (`implementation=1`) the default; add suggestions on when to use each.
- make `stax` layers preserve the exact input PyTree types (e.g. tuples vs. lists).
- small fix to `nt.empirical_direct_ntk_fn` so it works with `x2=None`, and activate the respective tests.
- do not raise an error (only warn) if elements of an input pytree have mismatching batch or channel axes, since this case still works in the finite setting.
- fix some typos in stax tests.

Co-authored-by: Sam Schoenholz <[email protected]>
PiperOrigin-RevId: 342982475

v0.3.4

Squeeze tests

PiperOrigin-RevId: 337210676

v0.3.3

add papers, bump jax

PiperOrigin-RevId: 329609690

v0.3.2

version bump

PiperOrigin-RevId: 322897958

v0.3.1

Update readme

PiperOrigin-RevId: 320622244

v0.3.0

Squeeze a flaky test to make it run in 15 minutes.

PiperOrigin-RevId: 318207007