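The provided code for the Hutchinson estimator assumes that the input tensors are 1D. In many network architectures, however, parameters are multi-dimensional tensors, such as weight matrices and convolution kernels. To handle this case, the estimator must compute the dot product and the Hessian-vector product in a way that does not depend on shape. The modified estimator below replaces the 1D dot product with an elementwise product followed by a full sum. This gives the inner product of two tensors of any shape. The Hessian-vector product returned by `torch.autograd.grad` then has the same shape as `p`, and so does the returned estimate `u * hessian_vector_product`.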
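The following short derivation shows why the elementwise-product-and-sum form works for any shape. It introduces the notation $G$, $U$, $u$ and $H$ for illustration only, and it assumes the loss $L$ is twice continuously differentiable, so that $H$ is symmetric. Let $P$ have shape $(d_1, \dots, d_k)$, let $G = \nabla_P L$, and let $U$ be a probe of the same shape with i.i.d. standard normal entries (as `torch.randn_like` produces). Write $u = \operatorname{vec}(U)$ and $H = \nabla^2_{\operatorname{vec}(P)} L$:

$$
\langle G, U \rangle = \sum_{i_1,\dots,i_k} G_{i_1 \dots i_k} U_{i_1 \dots i_k} = \operatorname{vec}(G)^\top u
$$

$$
\nabla_P \langle G, U \rangle = H u \quad \text{(reshaped to the shape of } P\text{)}
$$

$$
\mathbb{E}\big[u \odot H u\big]_i = \sum_j H_{ij}\, \mathbb{E}[u_i u_j] = H_{ii}
$$

The last line holds because $\mathbb{E}[u u^\top] = I$. In expectation, $u \odot Hu$ is therefore an unbiased estimate of $\operatorname{diag}(H)$, reshaped to the shape of $P$.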
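```python
import torch

class HutchinsonEstimator(HessianEstimator):
    def estimate(self, p, grad):
        # grad must come from torch.autograd.grad(loss, p, create_graph=True)
        u = torch.randn_like(grad)  # random probe with the same shape as p
        # Elementwise product + full sum = inner product <grad, u> for any shape
        grad_dot_u = torch.sum(grad * u)
        # Differentiating <grad, u> w.r.t. p gives H u, shaped like p
        hessian_vector_product = torch.autograd.grad(grad_dot_u, p, retain_graph=True)[0]
        # u * (H u) is an unbiased estimate of diag(H), shaped like p
        return u * hessian_vector_product
```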
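The class above depends on a `HessianEstimator` base class that is not shown, so the following is a standalone sketch rather than the original code. The helpers `hessian_vector_product` and `hutchinson_diag` are hypothetical names that mirror the body of `estimate`. The sketch checks the estimator on a 2×3 parameter using the toy loss $L(P) = \sum P^3 + (\sum P)^2$. Its Hessian is $\operatorname{diag}(6P) + 2\,\mathbf{1}\mathbf{1}^\top$, so the exact Hessian-vector product is $6P \odot U + 2\sum U$ and the exact diagonal is $6P + 2$. The parameter values, the sample count and the tolerance are arbitrary choices for illustration.

```python
import torch


def hessian_vector_product(p, grad, u):
    """Return H @ u reshaped like p. `grad` must be built with create_graph=True."""
    grad_dot_u = torch.sum(grad * u)  # inner product <grad, u> for any tensor shape
    # retain_graph=True keeps grad's graph alive so it can be reused for more probes
    return torch.autograd.grad(grad_dot_u, p, retain_graph=True)[0]


def hutchinson_diag(p, grad, num_samples=1):
    """Hypothetical helper: average u * (H u) over Gaussian probes u, shaped like p."""
    estimate = torch.zeros_like(p)  # does not require grad
    for _ in range(num_samples):
        u = torch.randn_like(grad)
        estimate += u * hessian_vector_product(p, grad, u)
    return estimate / num_samples


torch.manual_seed(0)

# A 2x3 parameter (multi-dimensional, like a small weight matrix).
p = torch.tensor([[0.5, -1.0, 0.2], [1.5, 0.0, -0.3]],
                 dtype=torch.float64, requires_grad=True)

# Toy loss L(P) = sum(P**3) + sum(P)**2; Hessian = diag(6P) + 2 * ones * ones^T.
loss = (p ** 3).sum() + p.sum() ** 2
(grad,) = torch.autograd.grad(loss, p, create_graph=True)

# 1) The Hessian-vector product is exact: H u = 6 P * U + 2 * sum(U).
u = torch.randn_like(p)
hv = hessian_vector_product(p, grad, u)
expected_hv = 6 * p.detach() * u + 2 * u.sum()
print(torch.allclose(hv, expected_hv))  # True

# 2) Averaging many probes approaches the exact diagonal 6P + 2.
diag_est = hutchinson_diag(p, grad, num_samples=10000)
exact_diag = 6 * p.detach() + 2
print(diag_est)
print(exact_diag)
```

The first check is deterministic. The second check is stochastic, so with a finite number of probes the estimate only lands near `exact_diag`. Both results keep the 2×3 shape of `p`, with no flattening required.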