conv3d + unet3d (tinygrad#772)
* conv3d, needs test

* test passes, padding wrong on unet

* unet3d

* no conv3d on images
geohot authored May 12, 2023
1 parent 46d4190 commit 810f03d
Showing 7 changed files with 57 additions and 29 deletions.
2 changes: 2 additions & 0 deletions examples/mlperf/model_spec.py
@@ -27,6 +27,8 @@ def test_model(model, *inputs):
 from models.unet3d import UNet3D
 mdl = UNet3D()
 mdl.load_from_pretrained()
+img = Tensor.randn(1, 1, 5, 224, 224)
+test_model(mdl, img)

 # RNNT

31 changes: 21 additions & 10 deletions models/unet3d.py
@@ -1,31 +1,42 @@
 # https://github.com/wolny/pytorch-3dunet
 from pathlib import Path
-from extra.utils import download_file, fake_torch_load
+from extra.utils import download_file, fake_torch_load, get_child
 import tinygrad.nn as nn

 class SingleConv:
   def __init__(self, in_channels, out_channels):
     self.groupnorm = nn.GroupNorm(1, in_channels) # 1 group?
-    self.conv = nn.Conv2d(in_channels, out_channels, (3,3,3), bias=False)
+    # TODO: make 2D conv generic for 3D, might already work with kernel_size=(3,3,3)
+    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False)
   def __call__(self, x):
     return self.conv(self.groupnorm(x)).relu()

-def get_basic_module(c0, c1, c2): return {"SingleConv1": SingleConv(c0, c1), "SingleConv2": SingleConv(c1, c2)}
+class BasicModule:
+  def __init__(self, c0, c1, c2):
+    self.basic_module = {"SingleConv1": SingleConv(c0, c1), "SingleConv2": SingleConv(c1, c2)}
+  def __call__(self, x):
+    return self.basic_module['SingleConv2'](self.basic_module['SingleConv1'](x))

 class UNet3D:
   def __init__(self):
     ups = [16,32,64,128,256]
-    self.encoders = [get_basic_module(ups[i] if i != 0 else 1, ups[i], ups[i+1]) for i in range(4)]
-    self.decoders = [get_basic_module(ups[-1-i] + ups[-2+i], ups[-2+i], ups[-2+i]) for i in range(3)]
+    self.encoders = [BasicModule(ups[i] if i != 0 else 1, ups[i], ups[i+1]) for i in range(4)]
+    self.decoders = [BasicModule(ups[-1-i] + ups[-2-i], ups[-2-i], ups[-2-i]) for i in range(3)]
     self.final_conv = nn.Conv2d(32, 1, (1,1,1))

   def __call__(self, x):
-    # TODO: make 2D conv generic for 3D, might already work with kernel_size=(3,3,3)
-    pass
+    intermediates = [x]
+    for e in self.encoders: intermediates.append(e(intermediates[-1]))
+    ret = intermediates[-1]
+    for d,i in zip(self.decoders, intermediates[:-1][::-1]): ret = d(ret.cat(i, dim=1))
+    return ret

   def load_from_pretrained(self):
     fn = Path(__file__).parent.parent / "weights/unet-3d.ckpt"
     download_file("https://oc.embl.de/index.php/s/61s67Mg5VQy7dh9/download?path=%2FLateral-Root-Primordia%2Funet_bce_dice_ds1x&files=best_checkpoint.pytorch", fn)
-    state = fake_torch_load(open(fn, "rb").read())['model_state_dict']
-    for x in state.keys():
-      print(x, state[x].shape)
+    state_dict = fake_torch_load(open(fn, "rb").read())['model_state_dict']
+    for k, v in state_dict.items():
+      print(k, v.shape)
+      obj = get_child(self, k)
+      assert obj.shape == v.shape, (k, obj.shape, v.shape)
+      obj.assign(v.numpy())
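
The rewired __call__ above runs the four encoders, then feeds each decoder the concatenation (skip connection) of the running tensor and the matching encoder output. A minimal usage sketch, mirroring the examples/mlperf/model_spec.py change above; per the commit message the unet padding is still wrong and final_conv is not applied yet, so treat the output as provisional:

    from tinygrad.tensor import Tensor
    from models.unet3d import UNet3D

    mdl = UNet3D()
    mdl.load_from_pretrained()             # fetches the pytorch-3dunet checkpoint and copies it in via get_child
    img = Tensor.randn(1, 1, 5, 224, 224)  # (batch, channels, depth, height, width)
    out = mdl(img)                         # encoders -> decoders with skip connections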
20 changes: 16 additions & 4 deletions test/test_ops.py
@@ -3,7 +3,7 @@
 import numpy as np
 import unittest
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import getenv
+from tinygrad.helpers import getenv, IMAGE
 from tinygrad.lazy import Device

 FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
@@ -346,6 +346,18 @@ def test_simple_conv2d(self):
       lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
       lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)

+  @unittest.skipIf(IMAGE>0, "no conv3d on images")
+  def test_simple_conv3d(self):
+    helper_test_op([(1,4,9,9,9), (4,4,3,3,3)],
+      lambda x,w: torch.nn.functional.conv3d(x,w).relu(),
+      lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
+
+  @unittest.skipIf(IMAGE>0, "no conv3d on images")
+  def test_padded_conv3d(self):
+    helper_test_op([(1,4,9,9,9), (4,4,3,3,3)],
+      lambda x,w: torch.nn.functional.conv3d(x,w,padding=1).relu(),
+      lambda x,w: Tensor.conv2d(x,w,padding=[1,1,1,1,1,1]).relu(), atol=1e-4, grad_rtol=1e-5)
+
   def test_simple_conv2d_m4(self):
     helper_test_op([(1,16,18,18), (16,16,3,3)],
       lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
@@ -580,10 +592,10 @@ def test_stack(self):

     for dim in range(-1, 3):
       helper_test_op([(45, 65, 3), (45, 65, 3), (45, 65, 3)], lambda x, y, z: torch.stack((x, y, z), dim=dim), lambda x, y, z: Tensor.stack([x, y, z], dim=dim))

     with self.assertRaises(IndexError):
       Tensor.stack([x], dim=77)

   def test_repeat(self):
     x = Tensor.randn(45, 65, 3)
     base_repeats = [2, 4, 3]
@@ -597,7 +609,7 @@ def test_repeat(self):

     with self.assertRaises(AssertionError):
       x.repeat((2, 0, 4))


   def test_clip(self):
     helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2))
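
For reference, the padded conv3d test above can be reproduced outside the test harness. A standalone sketch, assuming torch and numpy are installed (the test suite already requires both):

    import numpy as np
    import torch
    from tinygrad.tensor import Tensor

    x = np.random.randn(1, 4, 9, 9, 9).astype(np.float32)  # (batch, channels, depth, height, width)
    w = np.random.randn(4, 4, 3, 3, 3).astype(np.float32)  # (cout, cin, *kernel)

    ref = torch.nn.functional.conv3d(torch.tensor(x), torch.tensor(w), padding=1).relu().numpy()
    out = Tensor(x).conv2d(Tensor(w), padding=[1, 1, 1, 1, 1, 1]).relu().numpy()
    np.testing.assert_allclose(ref, out, atol=1e-4, rtol=1e-4)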
6 changes: 3 additions & 3 deletions tinygrad/nn/__init__.py
@@ -36,9 +36,9 @@ def __call__(self, x:Tensor):
 # TODO: is this good weight init?
 class Conv2d:
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
-    self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else (kernel_size[0], kernel_size[1])
+    self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
     self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
-    self.weight = Tensor.glorot_uniform(out_channels, in_channels//groups, self.kernel_size[0], self.kernel_size[1])
+    self.weight = Tensor.glorot_uniform(out_channels, in_channels//groups, *self.kernel_size)
     self.bias = Tensor.zeros(out_channels) if bias else None

   def __call__(self, x):
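
With kernel_size kept as a full tuple and the weight built from *self.kernel_size, the same Conv2d module now holds an N-D kernel. A minimal sketch of the 3D case as models/unet3d.py uses it; with this padding the spatial dims are preserved:

    import tinygrad.nn as nn
    from tinygrad.tensor import Tensor

    conv = nn.Conv2d(4, 8, kernel_size=(3,3,3), padding=(1,1,1,1,1,1), bias=False)
    print(conv.weight.shape)               # (8, 4, 3, 3, 3)
    y = conv(Tensor.randn(1, 4, 5, 9, 9))  # 5-D input -> output shape (1, 8, 5, 9, 9)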
@@ -65,7 +65,7 @@ def __call__(self, x:Tensor):

     if self.weight is None or self.bias is None: return x
     # elementwise_affine on channels
-    return x * self.weight.reshape(1, -1, 1, 1) + self.bias.reshape(1, -1, 1, 1)
+    return x * self.weight.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)]) + self.bias.reshape(1, -1, *[1 for _ in range(len(x.shape)-2)])

 class LayerNorm:
   def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
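
models/unet3d.py applies GroupNorm to 5-D activations, so the affine scale and shift above now reshape to match the input rank instead of assuming a 4-D tensor. A small sketch:

    import tinygrad.nn as nn
    from tinygrad.tensor import Tensor

    gn = nn.GroupNorm(1, 16)              # one group over all 16 channels, as in SingleConv
    y = gn(Tensor.randn(1, 16, 5, 9, 9))  # weight/bias broadcast as (1, -1, 1, 1, 1)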
2 changes: 1 addition & 1 deletion tinygrad/ops.py
@@ -95,7 +95,7 @@ def __call__(self, rawbufs:List[RawBuffer], jit=False, force_wait=False) -> Opti
if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs, allow_cache=(getenv("OPTLOCAL") >= 2))
if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(28-len(self.name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(29-len(self.name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS, {self.mem_estimate/(et*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += 1
GlobalCounters.global_ops += self.op_estimate
2 changes: 1 addition & 1 deletion tinygrad/shape/shapetracker.py
@@ -184,7 +184,7 @@ def expand(self, new_shape: Tuple[int, ...]):

   def reshape(self, new_shape: Tuple[int, ...]):
     if self.shape == new_shape: return self
-    assert all(isinstance(x, int) and x != 0 for x in new_shape), f"shape must be ints and can't contain 0 {new_shape}"
+    assert all(isinstance(x, int) and x > 0 for x in new_shape), f"shape must be ints and can't contain 0 or negative numbers {new_shape}"
     assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"

     # check if this is adding or removing 1s (only)
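
The reshape guard now rejects negative sizes as well as zeros; any -1 placeholder should already have been resolved at the Tensor level before the ShapeTracker sees it. A quick sketch of the predicate on its own, with hypothetical shapes:

    check = lambda shape: all(isinstance(x, int) and x > 0 for x in shape)
    print(check((3, 4, 5)))   # True
    print(check((3, 0, 5)))   # False, already rejected before this change
    print(check((3, -1, 5)))  # False, newly rejected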
23 changes: 13 additions & 10 deletions tinygrad/tensor.py
@@ -281,7 +281,10 @@ def unsqueeze(self, dim):
     return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])

   # (padding_left, padding_right, padding_top, padding_bottom)
-  def pad2d(self, padding:Union[List[int], Tuple[int, ...]]): return self.slice(((0,self.shape[0]), (0,self.shape[1]), (-padding[2],self.shape[2]+padding[3]), (-padding[0],self.shape[3]+padding[1])))
+  def pad2d(self, padding:Union[List[int], Tuple[int, ...]]):
+    slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
+    return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc)
+
   @property
   def T(self) -> Tensor: return self.transpose()
   def transpose(self, ax1=1, ax2=0) -> Tensor:
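
pad2d previously hard-coded a 4-D tensor and a 4-element padding; it now pairs the padding values with the trailing dims, last dim first, so a 6-element padding pads the three spatial dims of a 5-D tensor. A small worked example:

    from tinygrad.tensor import Tensor

    x = Tensor.randn(1, 1, 5, 9, 9)
    y = x.pad2d(padding=[1, 1, 1, 1, 1, 1])  # pairs: (left,right) width, (top,bottom) height, (front,back) depth
    print(y.shape)                           # (1, 1, 7, 11, 11)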
@@ -360,24 +363,24 @@ def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pai
   def max_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))

   def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
-    (bs,cin_,_,_), (cout,cin,H,W) = self.shape, weight.shape
-    assert groups*cin == cin_, f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
-    padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
+    (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
+    assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})"
+    padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) >= 4 else [padding[1], padding[1], padding[0], padding[0]])

     # conv2d is a pooling op (with padding)
-    x = self.pad2d(padding_)._pool((H,W), stride, dilation) # (bs, groups*cin, oy, ox, H, W)
-    rcout, oy, ox = cout//groups, x.shape[2], x.shape[3]
-    x = x.reshape(bs, groups, cin, 1, oy, ox, H, W).expand(bs, groups, cin, rcout, oy, ox, H, W).permute(0,1,3,4,5,2,6,7)
+    x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W)
+    rcout, oyx = cout//groups, x.shape[2:-len(HW)]
+    x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))])

     # expand the channels with the pool
     # TODO: this reduces the number of kernels, but it's slower!
     #x = self.pad2d(padding_)._pool((H,W), stride, dilation, _insert_dims=(cout//groups,)) # (bs, groups*cin, rcout, oy, ox, H, W)
     #rcout, oy, ox = x.shape[2:5]
     #x = x.reshape(bs, groups, cin, rcout, oy, ox, H, W).permute(0,1,3,4,5,2,6,7)

-    # conv! broadcasted to (bs, groups, rcout, oy, ox, cin, H, W)
-    ret = (x * weight.reshape(1, groups, rcout, 1, 1, cin, H, W)).sum((-3, -2, -1), keepdim=True).reshape(bs, cout, oy, ox)
-    return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
+    # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW)
+    ret = (x * weight.reshape(1, groups, rcout, *[1 for _ in range(len(oyx))], cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)
+    return ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))

   def dot(self, w:Tensor) -> Tensor:
     x = self.reshape(*self.shape[0:-1], 1, self.shape[-1])
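
For the 3D case in the new tests (input (1,4,9,9,9), weight (4,4,3,3,3), padding=1), the reshape/expand/permute above lines the dims up as (bs, groups, rcout, *oyx, cin, *HW) before the multiply-and-sum. A short sketch of that bookkeeping with those numbers plugged in:

    bs, groups, cin, rcout = 1, 1, 4, 4
    HW, oyx = (3, 3, 3), (9, 9, 9)          # kernel dims; output dims with padding=1, stride=1
    pooled   = (bs, groups*cin, *oyx, *HW)  # shape after pad2d + _pool
    expanded = (bs, groups, cin, rcout, *oyx, *HW)
    perm = (0, 1, 3, *[4+i for i in range(len(oyx))], 2, *[4+len(oyx)+i for i in range(len(HW))])
    print(perm)  # (0, 1, 3, 4, 5, 6, 2, 7, 8, 9) -> (bs, groups, rcout, *oyx, cin, *HW)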
