Skip to content

Commit

Permalink
start on mlperf models
Browse files Browse the repository at this point in the history
  • Loading branch information
geohot committed May 10, 2023
1 parent d13629c commit 46d4190
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 5 deletions.
2 changes: 1 addition & 1 deletion datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def fetch_cifar(train=True):
cifar10_mean = np.array([0.4913997551666284, 0.48215855929893703, 0.4465309133731618], dtype=np.float32).reshape(1,3,1,1)
cifar10_std = np.array([0.24703225141799082, 0.24348516474564, 0.26158783926049628], dtype=np.float32).reshape(1,3,1,1)
fn = os.path.dirname(__file__)+"/cifar-10-python.tar.gz"
download_file('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', fn, True)
download_file('https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', fn)
tt = tarfile.open(fn, mode='r:gz')
if train:
db = [pickle.load(tt.extractfile(f'cifar-10-batches-py/data_batch_{i}'), encoding="bytes") for i in range(1,6)]
Expand Down
19 changes: 19 additions & 0 deletions examples/mlperf/README
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
Each model should be a clean single file.
Models are imported from the top-level `models` directory.

Each model should be capable of loading weights from the reference implementation.

We will focus on these 5 models:

# Resnet50-v1.5 (classic) -- 8.2 GOPS/input
# Retinanet
# 3D UNET (upconvs)
# RNNT
# BERT-large (transformer)

These models are used in both the training and inference benchmarks:
https://mlcommons.org/en/training-normal-21/
https://mlcommons.org/en/inference-edge-30/
And we will submit to both.

NOTE: we submit in the Edge category since we don't have ECC RAM.
33 changes: 33 additions & 0 deletions examples/mlperf/model_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# load each model here, quick benchmark
from tinygrad.tensor import Tensor
from tinygrad.helpers import GlobalCounters

def test_model(model, *inputs):
  """Run `model` once on `inputs` and print the op count and summed time.

  The global counters are reset first, so the printed numbers cover only
  this single call.
  """
  GlobalCounters.reset()
  output = model(*inputs)
  # converting to numpy presumably forces the computation to run -- confirm
  output.numpy()
  # TODO: return event future to still get the time_sum_s without DEBUG=2
  print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms")

# Script entry point: set the Tensor flags for inference, then go through the
# MLPerf models one by one. Only the 3D UNET section currently runs any code.
if __name__ == "__main__":
# inference only for now
Tensor.training = False
Tensor.no_grad = True

# Resnet50-v1.5
# NOTE(review): the benchmark below sits inside a bare string literal, so it is
# disabled and never executes.
"""
from models.resnet import ResNet50
mdl = ResNet50()
img = Tensor.randn(1, 3, 224, 224)
test_model(mdl, img)
"""

# Retinanet
# (placeholder: no model wired up yet)

# 3D UNET
# Constructs the model and calls load_from_pretrained(); test_model is not
# called here, so no forward pass is benchmarked yet.
from models.unet3d import UNet3D
mdl = UNet3D()
mdl.load_from_pretrained()

# RNNT
# (placeholder: no model wired up yet)

# BERT-large
# (placeholder: no model wired up yet)
3 changes: 1 addition & 2 deletions examples/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,8 +615,7 @@ def __init__(self):
# load in weights
download_file(
'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt',
FILENAME,
skip_if_exists=True
FILENAME
)
dat = fake_torch_load_zipped(open(FILENAME, "rb"))
for k,v in dat['state_dict'].items():
Expand Down
2 changes: 1 addition & 1 deletion extra/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def fetch(url):
with open(fp, "rb") as f:
return f.read()

def download_file(url, fp, skip_if_exists=False):
def download_file(url, fp, skip_if_exists=True):
import requests, os
if skip_if_exists and os.path.isfile(fp) and os.stat(fp).st_size > 0:
return
Expand Down
3 changes: 2 additions & 1 deletion models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(self, num, num_classes):
self.layer2 = self._make_layer(self.block, 128, self.num_blocks[1], stride=2)
self.layer3 = self._make_layer(self.block, 256, self.num_blocks[2], stride=2)
self.layer4 = self._make_layer(self.block, 512, self.num_blocks[3], stride=2)
# TODO: replace with nn.Linear
self.fc = {"weight": Tensor.scaled_uniform(512 * self.block.expansion, num_classes), "bias": Tensor.zeros(num_classes)}

def _make_layer(self, block, planes, num_blocks, stride):
Expand All @@ -105,7 +106,7 @@ def __call__(self, x):

def load_from_pretrained(self):
# TODO replace with fake torch load

model_urls = {
18: 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
34: 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
Expand Down
31 changes: 31 additions & 0 deletions models/unet3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# https://github.com/wolny/pytorch-3dunet
from pathlib import Path
from extra.utils import download_file, fake_torch_load
import tinygrad.nn as nn

class SingleConv:
  """One normalization + convolution + ReLU stage of the 3D U-Net.

  The group norm is sized for `in_channels`, so it is applied to the input
  before the bias-free (3,3,3) convolution.
  """
  def __init__(self, in_channels, out_channels):
    self.groupnorm = nn.GroupNorm(1, in_channels)  # 1 group?
    self.conv = nn.Conv2d(in_channels, out_channels, (3,3,3), bias=False)

  def __call__(self, x):
    normed = self.groupnorm(x)
    convolved = self.conv(normed)
    return convolved.relu()

def get_basic_module(c0, c1, c2):
  """Build two chained SingleConv stages (c0 -> c1, then c1 -> c2), keyed by name."""
  first = SingleConv(c0, c1)
  second = SingleConv(c1, c2)
  return {"SingleConv1": first, "SingleConv2": second}

class UNet3D:
  """Layer skeleton of the 3D U-Net from https://github.com/wolny/pytorch-3dunet.

  Only layer construction and checkpoint inspection exist so far; the forward
  pass is still a stub.
  """
  def __init__(self):
    # channel widths used by the encoder/decoder levels
    ups = [16,32,64,128,256]
    # encoder channel triples: (1,16,32), (32,32,64), (64,64,128), (128,128,256)
    self.encoders = [get_basic_module(ups[i] if i != 0 else 1, ups[i], ups[i+1]) for i in range(4)]
    # decoder channel triples: (256+128,128,128), (128+64,64,64), (64+32,32,32).
    # The input width is the sum of two adjacent levels (presumably the deeper
    # features concatenated with the matching encoder output -- the forward pass
    # is not written yet). Indices must walk *down* the list (-2-i); the previous
    # -2+i walked back up, so the last decoder produced 16 channels while
    # final_conv below expects 32.
    self.decoders = [get_basic_module(ups[-1-i] + ups[-2-i], ups[-2-i], ups[-2-i]) for i in range(3)]
    self.final_conv = nn.Conv2d(32, 1, (1,1,1))

  def __call__(self, x):
    # TODO: make 2D conv generic for 3D, might already work with kernel_size=(3,3,3)
    pass

  def load_from_pretrained(self):
    """Download the reference checkpoint and print every tensor name and shape.

    The weights are not assigned to the model yet; this only lists what the
    checkpoint's 'model_state_dict' contains.
    """
    fn = Path(__file__).parent.parent / "weights/unet-3d.ckpt"
    download_file("https://oc.embl.de/index.php/s/61s67Mg5VQy7dh9/download?path=%2FLateral-Root-Primordia%2Funet_bce_dice_ds1x&files=best_checkpoint.pytorch", fn)
    # read inside a context manager so the file handle is closed promptly
    with open(fn, "rb") as f:
      state = fake_torch_load(f.read())['model_state_dict']
    for name, tensor in state.items():
      print(name, tensor.shape)

0 comments on commit 46d4190

Please sign in to comment.