Skip to content

Commit

Permalink
Revert "fixup resnet"
Browse files Browse the repository at this point in the history
This reverts commit 4eabe67.
  • Loading branch information
geohot committed Jan 16, 2022
1 parent 4eabe67 commit 55d792b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 14 deletions.
2 changes: 1 addition & 1 deletion examples/train_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tinygrad.tensor import Device
from extra.utils import get_parameters
from extra.training import train, evaluate
from models.resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from models.resnet import ResNet18, ResNet34, ResNet50
from tinygrad.optim import Adam
from test.test_mnist import fetch_mnist

Expand Down
22 changes: 9 additions & 13 deletions models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __call__(self, x):
return out

class ResNet:
def __init__(self, block, num_blocks, num_classes=10, url=None, pretrained=False):
def __init__(self, block, num_blocks, num_classes=10, url=None):
self.url = url
self.in_planes = 64

Expand All @@ -64,9 +64,6 @@ def __init__(self, block, num_blocks, num_classes=10, url=None, pretrained=False
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.fc = {"weight": Tensor.uniform(512 * block.expansion, num_classes), "bias": Tensor.zeros(num_classes)}

if pretrained:
self.load_from_pretrained()

def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks-1)
layers = []
Expand Down Expand Up @@ -95,13 +92,12 @@ def load_from_pretrained(self):
for k, v in state_dict.items():
obj = get_child(self, k)
dat = v.detach().numpy().T if "fc.weight" in k else v.detach().numpy()
assert obj.shape == dat.shape or k.startswith("fc.")
if obj.shape == dat.shape:
obj.assign(dat)

ResNet18 = lambda num_classes=1000, pretrained=False: ResNet(BasicBlock, [2,2,2,2], num_classes, 'https://download.pytorch.org/models/resnet18-5c106cde.pth', pretrained=pretrained)
ResNet34 = lambda num_classes=1000, pretrained=False: ResNet(BasicBlock, [3,4,6,3], num_classes, 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', pretrained=pretrained)
ResNet50 = lambda num_classes=1000, pretrained=False: ResNet(Bottleneck, [3,4,6,3], num_classes, 'https://download.pytorch.org/models/resnet50-19c8e357.pth', pretrained=pretrained)
ResNet101 = lambda num_classes=1000, pretrained=False: ResNet(Bottleneck, [3,4,23,3], num_classes, 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', pretrained=pretrained)
ResNet152 = lambda num_classes=1000, pretrained=False: ResNet(Bottleneck, [3,8,36,3], num_classes, 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', pretrained=pretrained)
assert obj.shape == dat.shape
obj.assign(dat)

ResNet18 = lambda: ResNet(BasicBlock, [2,2,2,2], 1000, 'https://download.pytorch.org/models/resnet18-5c106cde.pth')
ResNet34 = lambda: ResNet(BasicBlock, [3,4,6,3], 1000, 'https://download.pytorch.org/models/resnet34-333f7ec4.pth')
ResNet50 = lambda: ResNet(Bottleneck, [3,4,6,3], 1000, 'https://download.pytorch.org/models/resnet50-19c8e357.pth')
ResNet101 = lambda: ResNet(Bottleneck, [3,4,23,3], 1000, 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
ResNet101 = lambda: ResNet(Bottleneck, [3,8,36,3], 1000, 'https://download.pytorch.org/models/resnet152-b121ed2d.pth')

0 comments on commit 55d792b

Please sign in to comment.