Skip to content

Commit

Permalink
Register LinearRegressionUCB attribute tensors as buffers (facebookre…
Browse files Browse the repository at this point in the history
…search#629)

Summary:
Pull Request resolved: facebookresearch#629

The attributes weren't registered properly, so they weren't pushed to the device when `model.to(device)` was called

Reviewed By: soudia

Differential Revision: D35560710

fbshipit-source-id: 67492e7f64829750e395bdec85e04b7fb6fff04c
  • Loading branch information
alexnikulkov authored and facebook-github-bot committed Apr 14, 2022
1 parent abc08f7 commit 52f344a
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions reagent/models/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,13 @@ def __init__(
self.input_dim = input_dim
self.predict_ucb = predict_ucb
self.ucb_alpha = ucb_alpha

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

self.A = l2_reg_lambda * torch.eye(self.input_dim, device=device)
self.b = torch.zeros(self.input_dim, device=device)
self.coefs = torch.zeros(self.input_dim, device=device)
self.inv_A = torch.zeros(self.input_dim, self.input_dim, device=device)
self.coefs_valid_for_A = -torch.ones_like(
self.A, device=device
# pyre-ignore
self.register_buffer("A", l2_reg_lambda * torch.eye(self.input_dim))
self.register_buffer("b", torch.zeros(self.input_dim))
self.register_buffer("coefs", torch.zeros(self.input_dim))
self.register_buffer("inv_A", torch.zeros(self.input_dim, self.input_dim))
self.register_buffer(
"coefs_valid_for_A", -torch.ones((self.input_dim, self.input_dim))
) # value of A matrix for which self.coefs were estimated

def input_prototype(self) -> torch.Tensor:
Expand Down

0 comments on commit 52f344a

Please sign in to comment.