
Commit

Fix some bugs..
danpovey committed Jul 29, 2021
1 parent 52ae49e commit 1ad556d
Showing 4 changed files with 40 additions and 81 deletions.
2 changes: 1 addition & 1 deletion torch_mutual_information/__init__.py
@@ -1 +1 @@
from .mutual_information import mutual_information
from .mutual_information import mutual_information_recursion
49 changes: 28 additions & 21 deletions torch_mutual_information/mutual_information.py
@@ -1,7 +1,8 @@
import os

import torch
from typing import Tuple
from torch import Tensor
from typing import Tuple, Optional
from torch.utils.cpp_extension import load

VERBOSE = False
@@ -44,18 +45,18 @@ def _resolve(name):


def _mutual_information_forward_dispatcher(px: torch.Tensor, py: torch.Tensor,
boundaries: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
if input.is_cuda:
boundary: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
if px.is_cuda:
if torch_mutual_information_cuda is None:
raise EnvironmentError(f'Failed to load native CUDA module')
return torch_mutual_information_cuda.mutual_information_cuda(
px, py, boundaries, p)
px, py, boundary, p)
else:
return torch_mutual_information_cpu.mutual_information_cpu(
px, py, boundaries, p)
px, py, boundary, p)

def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
boundaries: torch.Tensor, p: torch.Tensor,
boundary: torch.Tensor, p: torch.Tensor,
ans_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if px.is_cuda:
if torch_mutual_information_cuda is None:
@@ -64,26 +65,28 @@ def _mutual_information_backward_dispatcher(px: torch.Tensor, py: torch.Tensor,
if overwrite_ans_grad:
ans_grad_copy = ans_grad.clone()
ans = tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
px, py, boundaries, p, ans_grad_copy, overwrite_ans_grad))
px, py, boundary, p, ans_grad_copy, overwrite_ans_grad))
if overwrite_ans_grad:
if not torch.allclose(ans_grad, ans_grad_copy, rtol=1.0e-02):
print(f"Warning: possible excsssive roundoff in mutual information backward "
"recursion: {ans_grad} vs. {ans_grad_copy}");
return ans
else:
return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
px, py, boundaries, p, ans_grad))
px, py, boundary, p, ans_grad))



class MutualInformationRecursionFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: Optional[torch.Tensor]) -> torch.Tensor:
def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundary: Optional[torch.Tensor]) -> torch.Tensor:
(B, S, T1) = px.shape
T = T1 - 1;
assert py.shape == (B, S + 1, T)
if boundaries is not None:
assert boundaries.shape == (B, 4)
if boundary is not None:
assert boundary.shape == (B, 4)
else:
boundary = torch.zeros(0, 0, dtype=torch.int64, device=px.device)


# p is a tensor of shape (B, S + 1, T + 1) where p[s][t] is the
@@ -101,20 +104,23 @@ def forward(ctx, px: torch.Tensor, py: torch.Tensor, boundaries: Optional[torch.

p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype)

ans = _mutual_information_forward_dispatcher(px, py, boundaries, p)
ans = _mutual_information_forward_dispatcher(px, py, boundary, p)

print(f"p = {p}, boundary = {boundary}")

if px.requires_grad or py.requires_grad:
ctx.save_for_backward(px, py, boundaries, p)
ctx.save_for_backward(px, py, boundary, p)
return ans

@staticmethod
def backward(ctx, ans_grad: Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
(px, py, boundaries, p) = ctx.saved_tensors
(px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundaries, p, ans_grad)
(px, py, boundary, p) = ctx.saved_tensors
(px_grad, py_grad) = _mutual_information_backward_dispatcher(px, py, boundary, p, ans_grad)
return (px_grad, py_grad, None)
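
For orientation while reading this diff: the comment in forward() describes p as a tensor of shape (B, S + 1, T + 1) whose entries are filled in by the native kernels, but the rest of that explanation is collapsed in this view. Below is a naive pure-Python sketch of the recursion such code is commonly understood to compute; the function name is hypothetical and the exact update rule is an assumption, not a statement of what the C++/CUDA kernels in this repository do.

import torch

def naive_mutual_information_reference(px, py):
    # Naive O(S * T) reference (boundary handling omitted); assumed recursion:
    #   p[b, 0, 0] = 0
    #   p[b, s, t] = logaddexp(p[b, s - 1, t] + px[b, s - 1, t],
    #                          p[b, s, t - 1] + py[b, s, t - 1])
    # and the returned value is assumed to be p[:, S, T].
    (B, S, T1) = px.shape
    T = T1 - 1
    assert py.shape == (B, S + 1, T)
    neg_inf = torch.full((B,), float('-inf'), dtype=px.dtype, device=px.device)
    p = torch.full((B, S + 1, T + 1), float('-inf'), dtype=px.dtype, device=px.device)
    p[:, 0, 0] = 0.0
    for s in range(S + 1):
        for t in range(T + 1):
            if s == 0 and t == 0:
                continue
            term_s = p[:, s - 1, t] + px[:, s - 1, t] if s > 0 else neg_inf
            term_t = p[:, s, t - 1] + py[:, s, t - 1] if t > 0 else neg_inf
            p[:, s, t] = torch.logaddexp(term_s, term_t)
    return p[:, S, T]

Under this assumed recursion, all-zero px and py give the log of the number of monotonic paths from (0, 0) to (S, T), i.e. log(binomial(S + T, S)), which is one way to sanity-check the simplified test further down.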



def mutual_information_recursion(input, px, py, boundaries=None):
def mutual_information_recursion(px, py, boundary=None):
"""A recursion that is useful in computing mutual information between two sequences of
real vectors, but may be useful more generally in sequence-to-sequence tasks where
monotonic alignment between pairs of sequences is desired. The definitions of
@@ -154,7 +160,7 @@ def mutual_information_recursion(input, px, py, boundaries=None):
is that for optimization purposes we assume the last axis (the t axis)
has stride of 1; this is true if px and py are contiguous.
boundaries: If supplied, a torch.LongTensor of shape [B][4], where each row contains
boundary: If supplied, a torch.LongTensor of shape [B][4], where each row contains
[s_begin, t_begin, s_end, t_end]. If not supplied, the values
[0, 0, S, T] will be assumed. These are the beginning and
one-past-the-last positions in the x and y sequences
@@ -182,8 +188,9 @@ def mutual_information_recursion(input, px, py, boundaries=None):
assert py.shape == (B, S + 1, T)
assert px.dtype == py.dtype
(B, S, T) = px.shape
if boundaries is not None:
assert boundaries.dtype == torch.LongTensor
assert boundaries.shape == (B, 4)
if boundary is not None:
assert boundary.dtype == torch.LongTensor
assert boundary.shape == (B, 4)


return MutualInformationRecursion.apply(px, py, boundaries)
return MutualInformationRecursionFunction.apply(px, py, boundary)
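
A short usage sketch of the renamed entry point, following the shapes the docstring above gives (px of shape (B, S, T + 1), py of shape (B, S + 1, T), optional boundary of shape (B, 4)). The random inputs are illustrative only; boundary is left at its default of None, which the docstring says is treated as [0, 0, S, T] for every sequence, and the result is assumed to be one value per sequence.

import torch
from torch_mutual_information import mutual_information_recursion

B, S, T = 2, 4, 5
px = torch.randn(B, S, T + 1, requires_grad=True)  # terms consumed when s advances
py = torch.randn(B, S + 1, T, requires_grad=True)  # terms consumed when t advances

# boundary (rows of [s_begin, t_begin, s_end, t_end]) is optional; omitted here.
m = mutual_information_recursion(px, py)

m.sum().backward()  # populates px.grad and py.grad via the autograd Function above
print(m, px.grad.shape, py.grad.shape)

Note the docstring's requirement that the last (t) axis have stride 1; freshly created tensors like these are contiguous, so that holds here.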
6 changes: 3 additions & 3 deletions torch_mutual_information/mutual_information_cuda_kernel.cu
@@ -519,7 +519,7 @@ void mutual_information_backward_kernel(
// comments. We'll focus, in the comments, on differences from the forward
// pass.
const int num_s_blocks = S / BLOCK_SIZE + 1,
num_t_blocks = T / BLOCK_SIZE + 1,
// num_t_blocks = T / BLOCK_SIZE + 1,
num_blocks_this_iter = min(iter + 1, num_s_blocks);


@@ -668,7 +668,7 @@ void mutual_information_backward_kernel(
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
p_buf[s_in_block][t_in_block] = (
s <= s_end && t <= t_end ? p_grad[s][t] : 0.0);
s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
} else if (static_cast<unsigned int>((int)threadIdx.x - 64) <
static_cast<unsigned int>(block_T)) {
// casting to unsigned before the comparison tests for both negative and
@@ -678,7 +678,7 @@ void mutual_information_backward_kernel(
s = s_in_block + s_block_begin,
t = t_in_block + t_block_begin;
p_buf[s_in_block][t_in_block] = (
s <= s_end && t <= t_end ? p_grad[s][t] : 0.0);
s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
}

// The highest-numbered value in p_buf that we need (corresponding,
64 changes: 8 additions & 56 deletions torch_mutual_information/mutual_information_test.py
@@ -3,72 +3,24 @@

import random
import torch
from torch_mutual_information import mutual_information
from torch_mutual_information import mutual_information_recursion


def test_mutual_information_basic():
print("Running test_mutual_information_basic()")
for dtype in [torch.float32, torch.float64]:
B = 2
C = 4
T = 10
x = -2.0 + 0.4 * torch.arange(10, dtype=dtype)
x = x.reshape(1, 1, 10).repeat(B, C, 1)
S = 4
T = 5
px = torch.zeros(B, S, T + 1) # log of an odds ratio
py = torch.zeros(B, S + 1, T) # log of an odds ratio


K = 4
N = K * 2
params = torch.arange(N + 1, dtype=dtype).unsqueeze(0) + torch.arange(C, dtype=dtype).unsqueeze(1) - 3
x.requires_grad = True
params.requires_grad = True
print("x = ", x)
print("params = ", params)
print("x.shape = ", x.shape)
m = mutual_information_recursion(px, py)

y = mutual_information(x, params, dim = 1)
print("m = ", m)


if True:
# Check
x2 = x.reshape(B, C, 5, 2)
assert torch.allclose(mutual_information(x, params, dim = 1), mutual_information(x2, params, dim = 1).reshape(x.shape))

x2 = x.reshape(B, 1, C, 10)
assert torch.allclose(mutual_information(x, params, dim = 1), mutual_information(x2, params, dim = 2).reshape(x.shape))



print("y = ", y)
y.sum().backward()

if torch.cuda.is_available():
# test that the CUDA forward is the same as the CPU forward.
device = torch.device('cuda:0')
x2 = x.to(device).detach()
x2.requires_grad = True
params2 = params.to(device).detach()
params2.requires_grad = True
y2 = mutual_information(x2, params2, dim = 1).to(torch.device('cpu'))
print("Checking CUDA is same")
if not torch.allclose(y, y2, atol=1.0e-06):
print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}")
assert(0);

y2.sum().backward()

if not torch.allclose(x.grad, x2.grad.to('cpu'), atol=1.0e-06):
print(f"Error: CPU x-grad versus CUDA grad not the same: {x.grad} vs. {x2.grad}, diff = {x2.grad.to('cpu')-x.grad}")
assert(0);
if not torch.allclose(params.grad, params2.grad.to('cpu'), atol=1.0e-06):
print(f"Error: CPU params-grad versus CUDA grad not the same: {params.grad} vs. {params2.grad}, diff = {params2.grad.to('cpu')-params.grad}")
assert(0);



print("x.grad = ", x.grad)
print("params.grad = ", params.grad)

# Just eyeballing the above to make sure it looks reasonable.


def test_mutual_information_deriv():
""" Tests derivatives in randomized way """
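
The rewritten basic test above only prints the result for all-zero inputs, and the CPU-versus-CUDA comparison that the old params-based test performed has no counterpart yet for the new API. Below is a hedged sketch of what such a check might look like; the test name, shapes, and tolerances are illustrative rather than taken from the repository.

import torch
from torch_mutual_information import mutual_information_recursion

def test_mutual_information_cpu_vs_cuda():
    # Compare CPU and CUDA forward/backward results on the same random inputs.
    if not torch.cuda.is_available():
        return
    B, S, T = 2, 4, 5
    px = torch.randn(B, S, T + 1, dtype=torch.float64, requires_grad=True)
    py = torch.randn(B, S + 1, T, dtype=torch.float64, requires_grad=True)
    m_cpu = mutual_information_recursion(px, py)

    px_cuda = px.detach().to('cuda').requires_grad_(True)
    py_cuda = py.detach().to('cuda').requires_grad_(True)
    m_cuda = mutual_information_recursion(px_cuda, py_cuda)

    assert torch.allclose(m_cpu, m_cuda.cpu(), atol=1.0e-06)

    m_cpu.sum().backward()
    m_cuda.sum().backward()
    assert torch.allclose(px.grad, px_cuda.grad.cpu(), atol=1.0e-06)
    assert torch.allclose(py.grad, py_cuda.grad.cpu(), atol=1.0e-06)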
