From 52f344a37e266b03ea4bb3b7ce919d9be6af041b Mon Sep 17 00:00:00 2001 From: Alex Nikulkov Date: Wed, 13 Apr 2022 18:47:14 -0700 Subject: [PATCH] Register LinearRegressionUCB attribute tensors as buffers (#629) Summary: Pull Request resolved: https://github.com/facebookresearch/ReAgent/pull/629 The attributes weren't registered properly, so they weren't pushed to the device when `model.to(device)` was called Reviewed By: soudia Differential Revision: D35560710 fbshipit-source-id: 67492e7f64829750e395bdec85e04b7fb6fff04c --- reagent/models/linear_regression.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/reagent/models/linear_regression.py b/reagent/models/linear_regression.py index 6725757db..9213421e4 100644 --- a/reagent/models/linear_regression.py +++ b/reagent/models/linear_regression.py @@ -54,15 +54,13 @@ def __init__( self.input_dim = input_dim self.predict_ucb = predict_ucb self.ucb_alpha = ucb_alpha - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - self.A = l2_reg_lambda * torch.eye(self.input_dim, device=device) - self.b = torch.zeros(self.input_dim, device=device) - self.coefs = torch.zeros(self.input_dim, device=device) - self.inv_A = torch.zeros(self.input_dim, self.input_dim, device=device) - self.coefs_valid_for_A = -torch.ones_like( - self.A, device=device + # pyre-ignore + self.register_buffer("A", l2_reg_lambda * torch.eye(self.input_dim)) + self.register_buffer("b", torch.zeros(self.input_dim)) + self.register_buffer("coefs", torch.zeros(self.input_dim)) + self.register_buffer("inv_A", torch.zeros(self.input_dim, self.input_dim)) + self.register_buffer( + "coefs_valid_for_A", -torch.ones((self.input_dim, self.input_dim)) ) # value of A matrix for which self.coefs were estimated def input_prototype(self) -> torch.Tensor: