forked from tinygrad/tinygrad
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
88 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
Each model should be a clean single file. | ||
They are imported from the top level `models` directory | ||
|
||
It should be capable of loading weights from the reference imp. | ||
|
||
We will focus on these 5 models: | ||
|
||
# Resnet50-v1.5 (classic) -- 8.2 GOPS/input | ||
# Retinanet | ||
# 3D UNET (upconvs) | ||
# RNNT | ||
# BERT-large (transformer) | ||
|
||
They are used in both the training and inference benchmark: | ||
https://mlcommons.org/en/training-normal-21/ | ||
https://mlcommons.org/en/inference-edge-30/ | ||
And we will submit to both. | ||
|
||
NOTE: we are Edge since we don't have ECC RAM |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# load each model here, quick benchmark | ||
from tinygrad.tensor import Tensor | ||
from tinygrad.helpers import GlobalCounters | ||
|
||
def test_model(model, *inputs): | ||
GlobalCounters.reset() | ||
model(*inputs).numpy() | ||
# TODO: return event future to still get the time_sum_s without DEBUG=2 | ||
print(f"{GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.time_sum_s*1000:.2f} ms") | ||
|
||
if __name__ == "__main__": | ||
# inference only for now | ||
Tensor.training = False | ||
Tensor.no_grad = True | ||
|
||
# Resnet50-v1.5 | ||
""" | ||
from models.resnet import ResNet50 | ||
mdl = ResNet50() | ||
img = Tensor.randn(1, 3, 224, 224) | ||
test_model(mdl, img) | ||
""" | ||
|
||
# Retinanet | ||
|
||
# 3D UNET | ||
from models.unet3d import UNet3D | ||
mdl = UNet3D() | ||
mdl.load_from_pretrained() | ||
|
||
# RNNT | ||
|
||
# BERT-large |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# https://github.com/wolny/pytorch-3dunet | ||
from pathlib import Path | ||
from extra.utils import download_file, fake_torch_load | ||
import tinygrad.nn as nn | ||
|
||
class SingleConv: | ||
def __init__(self, in_channels, out_channels): | ||
self.groupnorm = nn.GroupNorm(1, in_channels) # 1 group? | ||
self.conv = nn.Conv2d(in_channels, out_channels, (3,3,3), bias=False) | ||
def __call__(self, x): | ||
return self.conv(self.groupnorm(x)).relu() | ||
|
||
def get_basic_module(c0, c1, c2): return {"SingleConv1": SingleConv(c0, c1), "SingleConv2": SingleConv(c1, c2)} | ||
|
||
class UNet3D: | ||
def __init__(self): | ||
ups = [16,32,64,128,256] | ||
self.encoders = [get_basic_module(ups[i] if i != 0 else 1, ups[i], ups[i+1]) for i in range(4)] | ||
self.decoders = [get_basic_module(ups[-1-i] + ups[-2+i], ups[-2+i], ups[-2+i]) for i in range(3)] | ||
self.final_conv = nn.Conv2d(32, 1, (1,1,1)) | ||
|
||
def __call__(self, x): | ||
# TODO: make 2D conv generic for 3D, might already work with kernel_size=(3,3,3) | ||
pass | ||
|
||
def load_from_pretrained(self): | ||
fn = Path(__file__).parent.parent / "weights/unet-3d.ckpt" | ||
download_file("https://oc.embl.de/index.php/s/61s67Mg5VQy7dh9/download?path=%2FLateral-Root-Primordia%2Funet_bce_dice_ds1x&files=best_checkpoint.pytorch", fn) | ||
state = fake_torch_load(open(fn, "rb").read())['model_state_dict'] | ||
for x in state.keys(): | ||
print(x, state[x].shape) |