e8

@kyegomez released this 25 May 03:27 · 34 commits to main since this release

The provided code for the Hutchinson estimator assumes that the input tensors are 1D. In many neural network architectures, however, parameters are multi-dimensional tensors, such as 2D weight matrices. To handle this case, the estimator computes the dot product as an element-wise product summed over all entries and obtains the Hessian-vector product from torch.autograd.grad. Both operations work for tensors of any shape, and the returned estimate has the same shape as the parameter.

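import torch

class HutchinsonEstimator(HessianEstimator):
    def estimate(self, p, grad):
        # grad must have been computed with create_graph=True so it can be differentiated again.
        u = torch.randn_like(grad)
        # Element-wise product summed over all entries: a dot product for any tensor shape.
        grad_dot_u = torch.sum(grad * u)
        # Differentiating grad . u with respect to p yields the Hessian-vector product H u.
        hessian_vector_product = torch.autograd.grad(grad_dot_u, p, retain_graph=True)[0]
        # u * (H u) is a single-sample estimate of the Hessian diagonal, shaped like p.
        return u * hessian_vector_product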
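
As a rough sanity check, the same computation can be run on a 2D parameter whose Hessian diagonal is known in closed form. The sketch below is illustrative only and is not part of the release: the helper `hutchinson_diag_estimate`, the toy `loss_fn`, the `coeffs` tensor, and the sample count are all assumptions made for this example. It averages many single-sample estimates, since one probe vector alone is very noisy.

```python
import torch


def hutchinson_diag_estimate(loss_fn, p, num_samples=2000):
    """Average u * (H u) over Gaussian probes u to approximate diag(H).

    Hypothetical helper for illustration; it mirrors the estimate() logic above.
    """
    loss = loss_fn(p)
    # create_graph=True keeps the gradient differentiable for the second pass.
    (grad,) = torch.autograd.grad(loss, p, create_graph=True)
    estimate = torch.zeros_like(p)
    for _ in range(num_samples):
        u = torch.randn_like(grad)
        # Works for any shape: element-wise product, then sum over all entries.
        grad_dot_u = torch.sum(grad * u)
        (hvp,) = torch.autograd.grad(grad_dot_u, p, retain_graph=True)
        estimate += u * hvp
    return estimate / num_samples


torch.manual_seed(0)

# A 2D parameter, standing in for a weight matrix.
p = torch.randn(3, 4, requires_grad=True)
coeffs = torch.arange(1.0, 13.0).reshape(3, 4)


def loss_fn(w):
    # Toy loss: the second term couples all entries, so the Hessian is not diagonal.
    # Its Hessian is diag(coeffs) + ones, so the true diagonal is coeffs + 1.
    return 0.5 * torch.sum(coeffs * w ** 2) + 0.5 * torch.sum(w) ** 2


estimate = hutchinson_diag_estimate(loss_fn, p)
expected = coeffs + 1.0
print("max relative error:", ((estimate - expected).abs() / expected).max().item())
```

The estimate has the same shape as `p`, and its error shrinks roughly with the square root of the number of samples, so a single call to `estimate` as in the class above should be treated as a noisy sample rather than an exact diagonal.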