Skip to content

Commit

Permalink
add operator creating fun
Browse files Browse the repository at this point in the history
  • Loading branch information
AnderBiguri committed Oct 22, 2024
1 parent f1f5c74 commit 3786b4e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
1 change: 0 additions & 1 deletion Python/demos/d25_Pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ def test_3D_operators(
TIGRE_GPU_ID = gpu.getGpuIds(listGpuNames[3])
TIGRE_GPU_ID = TIGRE_GPU_ID[3]
PYTORCH_GPU_ID = TIGRE_GPU_ID.devices[0]
PYTORCH_GPU_ID = 3

print(f'Using GPU {TIGRE_GPU_ID} for TIGRE and GPU {PYTORCH_GPU_ID} for PyTorch')
#geo=get_default_2Dgeometry()
Expand Down
38 changes: 33 additions & 5 deletions Python/tigre/utilities/pytorch_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ def backward(ctx, grad_output:torch.Tensor): #type:ignore
return volume_output, None, None, None, None

class A(torch.nn.Module):
def __init__(self, geo:Geometry, gpuids:List[str]):
def __init__(self, geo:Geometry, angles, gpuids:List[str]):
super(A, self).__init__()
assert geo.angles is not None, 'Initialise the angles'
geo.angles = angles
self.geo = geo
self.gpuids = gpuids
self.volume_dimension = geo.nVoxel.tolist()
Expand Down Expand Up @@ -137,13 +137,41 @@ def backward(ctx, grad_output:torch.Tensor): #type:ignore
return sinogram_output, None, None, None, None

class At(torch.nn.Module):
def __init__(self, geo:Geometry, gpuids:List[str]):
def __init__(self, geo:Geometry, angles:np.array, gpuids:List[str]):
super(At, self).__init__()
assert geo.angles is not None, 'Initialise the angles'
geo.angles = angles
self.geo = geo
self.gpuids = gpuids
self.volume_dimension = geo.nVoxel.tolist()
self.sinogram_dimension = [len(geo.angles)] + geo.nDetector.tolist()

def forward(self, x:torch.Tensor):
return ATBFunction.apply(x, self.geo, self.gpuids, self.volume_dimension, self.sinogram_dimension)
return ATBFunction.apply(x, self.geo, self.gpuids, self.volume_dimension, self.sinogram_dimension)





def create_pytorch_operator(geo, angles, gpuids):
return A(geo, angles, gpuids), At(geo, angles, gpuids)



## This may be useful for non-torch stuff, but doesn't work for torch autograd.
# I'll leave it here for now.
class Operator:
def __init__(self, geo, angles, gpuids):
super(Operator, self).__init__()
self.geo = geo
self.angles = angles
self.gpuids = gpuids
self.ax = A(self.geo, self.angles, self.gpuids)
self.atb = At(self.geo, self.angles, self.gpuids)
def __call__(self, x):
return self.forward(x)
def forward(self, x):
return self.ax(x)
def T(self,b):
return self.backward(b)
def backward(self, b):
return self.atb(b)

0 comments on commit 3786b4e

Please sign in to comment.