More (randomized) testing; bug fixes.
danpovey committed Jul 30, 2021
1 parent 9ebcf9d commit c39d5fe
Showing 4 changed files with 60 additions and 35 deletions.
2 changes: 1 addition & 1 deletion torch_mutual_information/mutual_information.py
@@ -106,7 +106,7 @@ def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundary: Optional[torch.Te

         ans = _mutual_information_forward_dispatcher(px, py, boundary, p)

-        print(f"p = {p}, boundary = {boundary}")
+        #print(f"p = {p}, boundary = {boundary}")

         if px.requires_grad or py.requires_grad:
             ctx.save_for_backward(px, py, boundary, p)
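Aside from silencing the debug print, the surrounding forward() follows the usual custom-autograd pattern: compute the answer via the dispatcher, then stash whatever backward() will need with ctx.save_for_backward, guarded by a requires_grad check. A minimal, self-contained sketch of that pattern (the toy ScaleByTwo function below is purely illustrative and not part of this repository):

import torch

class ScaleByTwo(torch.autograd.Function):
    # Toy function illustrating the save_for_backward pattern used in
    # mutual_information.py; the real forward calls the C++/CUDA dispatcher.

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        ans = x * 2
        # Only save tensors when a backward pass can actually occur,
        # mirroring the `if px.requires_grad or py.requires_grad:` guard above.
        if x.requires_grad:
            ctx.save_for_backward(x)
        return ans

    @staticmethod
    def backward(ctx, ans_grad: torch.Tensor) -> torch.Tensor:
        (x,) = ctx.saved_tensors  # same order as saved in forward()
        return ans_grad * 2       # d(2x)/dx = 2

x = torch.randn(3, requires_grad=True)
ScaleByTwo.apply(x).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])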
6 changes: 3 additions & 3 deletions torch_mutual_information/mutual_information_cpu.cpp
@@ -231,16 +231,16 @@ std::vector<torch::Tensor> mutual_information_backward_cpu(
         // .. but we can use this for a check, that the grad at the beginning
         // of the sequence is equal to the grad at the end of the sequence.
         if (ans_grad_a[b] != 0.0) {
-          float grad_ratio = p_a[b][s_begin][t_begin] / ans_grad_a[b];
+          float grad_ratio = p_grad_a[b][s_begin][t_begin] / ans_grad_a[b];
           if (fabs(grad_ratio - 1.0) > 0.01) {
             printf("Warning: mutual_information backprop: expected these numbers to be the same: %f vs. %f\n",
-                    (float)p_a[b][s_begin][t_begin], (float)ans_grad_a[b]);
+                    (float)p_grad_a[b][s_begin][t_begin], (float)ans_grad_a[b]);
           }
         }
       }
     }));

-  std::cout << "p_grad = " << p_grad;
+  // std::cout << "p_grad = " << p_grad;
   return std::vector<torch::Tensor>({px_grad, py_grad});
 }

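The substantive fix here replaces p_a with p_grad_a: the quantity that should match ans_grad[b] is the gradient that reaches the start cell of the recursion, because the final score depends on p[b][s_begin][t_begin] with derivative exactly 1, so grad_ratio should be about 1.0. A small PyTorch sketch of that invariant on a toy log-space recursion (shapes, names, and the logaddexp recurrence are illustrative only, not the library's actual layout):

import torch

# Toy log-space DP on an S x T grid: p[s][t] = logaddexp(p[s-1][t] + px[...], p[s][t-1] + py[...]).
# Every path through the grid starts at p[0][0], so d(ans)/d(p[0][0]) == 1 and the gradient
# arriving at the start cell must equal the incoming ans_grad -- the invariant the
# warning in mutual_information_backward_cpu() checks via grad_ratio.
S, T = 4, 5
px = torch.randn(S, T + 1)                 # arbitrary scores for the illustration
py = torch.randn(S + 1, T)
p00 = torch.zeros((), requires_grad=True)  # stands in for p[b][s_begin][t_begin]

p = [[None] * (T + 1) for _ in range(S + 1)]
p[0][0] = p00
for s in range(S + 1):
    for t in range(T + 1):
        if s == 0 and t == 0:
            continue
        terms = []
        if s > 0:
            terms.append(p[s - 1][t] + px[s - 1][t])
        if t > 0:
            terms.append(p[s][t - 1] + py[s][t - 1])
        p[s][t] = terms[0] if len(terms) == 1 else torch.logaddexp(terms[0], terms[1])

ans = p[S][T]                    # "end of the sequence"
ans_grad = torch.tensor(3.0)
ans.backward(gradient=ans_grad)
print(p00.grad, ans_grad)        # should match (grad_ratio ~= 1.0), up to float rounding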
2 changes: 1 addition & 1 deletion torch_mutual_information/mutual_information_cuda_kernel.cu
@@ -915,6 +915,6 @@ mutual_information_backward_cuda(torch::Tensor px,
                                  overwrite_ans_grad);
         }
       }));
-  std::cout << "p_grad = " << p_grad;
+  // std::cout << "p_grad = " << p_grad;
   return std::vector<torch::Tensor>({px_grad, py_grad});
 }
85 changes: 55 additions & 30 deletions torch_mutual_information/mutual_information_test.py
@@ -8,36 +8,61 @@

 def test_mutual_information_basic():
     print("Running test_mutual_information_basic()")
-    for dtype in [torch.float32, torch.float64]:
-        px_grads = []
-        py_grads = []
-        for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
-            print("dtype = ", dtype, ", device = ", device)
-            B = 2
-            S = 14
-            T = 14
-            boundary = torch.tensor([ 0, 0, S, T ], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
-            px = torch.zeros(B, S, T + 1, dtype=dtype).to(device) # log of an odds ratio
-            py = torch.zeros(B, S + 1, T, dtype=dtype).to(device) # log of an odds ratio
-            px.requires_grad = True
-            py.requires_grad = True
-
-            #m = mutual_information_recursion(px, py, None)
-            m = mutual_information_recursion(px, py, boundary)
-
-            print("m = ", m, ", size = ", m.shape)
-            print("exp(m) = ", m.exp())
-            (m.sum() * 3).backward()
-            print("px_grad = ", px.grad)
-            print("py_grad = ", py.grad)
-            px_grads.append(px.grad.to('cpu'))
-            py_grads.append(py.grad.to('cpu'))
-        if not torch.allclose(px_grads[0], px_grads[1]):
-            print(f"px_grads differed CPU vs CUDA: {px_grads[0]} vs. {px_grads[1]}")
-            assert 0
-        if not torch.allclose(py_grads[0], py_grads[1]):
-            print(f"py_grads differed CPU vs CUDA: {py_grads[0]} vs. {py_grads[1]}")
-            assert 0

+    for _iter in range(100):
+        (B, S, T) = (random.randint(1, 10),
+                     random.randint(1, 200),
+                     random.randint(1, 200))
+        random_px = (random.random() < 0.1)
+        random_py = (random.random() < 0.1)
+
+        print(f"B, S, T = {B}, {S}, {T}, random_px={random_px}, random_py={random_py}")
+        for dtype in [torch.float32, torch.float64]:
+            px_grads = []
+            py_grads = []
+            m_vals = []
+            for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
+                print("dtype = ", dtype, ", device = ", device)
+                B = 2
+                S = 14
+                T = 14
+                boundary = torch.tensor([ 0, 0, S, T ], dtype=torch.int64).unsqueeze(0).expand(B, 4).to(device)
+
+                if device == torch.device('cpu'):
+                    if random_px:
+                        px = torch.randn(B, S, T + 1, dtype=dtype).to(device) # log of an odds ratio
+                    else:
+                        px = torch.zeros(B, S, T + 1, dtype=dtype).to(device) # log of an odds ratio
+                    if random_py:
+                        py = torch.randn(B, S + 1, T, dtype=dtype).to(device) # log of an odds ratio
+                    else:
+                        py = torch.zeros(B, S + 1, T, dtype=dtype).to(device) # log of an odds ratio
+                else:
+                    px = px.to(device).detach()
+                    py = py.to(device).detach()
+                px.requires_grad = True
+                py.requires_grad = True
+
+                #m = mutual_information_recursion(px, py, None)
+                m = mutual_information_recursion(px, py, boundary)
+
+                #print("m = ", m, ", size = ", m.shape)
+                #print("exp(m) = ", m.exp())
+                (m.sum() * 3).backward()
+                #print("px_grad = ", px.grad)
+                #print("py_grad = ", py.grad)
+                px_grads.append(px.grad.to('cpu'))
+                py_grads.append(py.grad.to('cpu'))
+                m_vals.append(m.to('cpu'))
+            if not torch.allclose(m_vals[0], m_vals[1], atol=1.0e-05, rtol=1.0e-04):
+                print(f"m_vals differed CPU vs CUDA: {m_vals[0]} vs. {m_vals[1]}")
+                assert 0
+            if not torch.allclose(px_grads[0], px_grads[1], atol=1.0e-05, rtol=1.0e-04):
+                print(f"px_grads differed CPU vs CUDA: {px_grads[0]} vs. {px_grads[1]}")
+                assert 0
+            if not torch.allclose(py_grads[0], py_grads[1], atol=1.0e-05, rtol=1.0e-04):
+                print(f"py_grads differed CPU vs CUDA: {py_grads[0]} vs. {py_grads[1]}")
+                assert 0



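The randomized CPU-vs-CUDA comparison above checks that the two implementations agree with each other; an independent way to exercise the backward pass is a finite-difference check. A hypothetical sketch, assuming mutual_information_recursion(px, py, boundary) is importable exactly as the test above uses it (double precision and tiny shapes keep gradcheck tractable):

import torch
from torch_mutual_information import mutual_information_recursion  # assumed import path

# Finite-difference gradient check on a tiny problem; gradcheck only perturbs the
# floating-point inputs that require grad, so the int64 boundary passes through untouched.
B, S, T = 1, 3, 4
boundary = torch.tensor([[0, 0, S, T]], dtype=torch.int64)
px = torch.randn(B, S, T + 1, dtype=torch.float64, requires_grad=True)
py = torch.randn(B, S + 1, T, dtype=torch.float64, requires_grad=True)

torch.autograd.gradcheck(mutual_information_recursion, (px, py, boundary),
                         eps=1e-6, atol=1e-4)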
