from __future__ import print_function
import ml_datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from thinc.util import xp2torch, get_shuffled_batches
import tqdm


class Net(nn.Module):
    """A simple feed-forward network for MNIST digit classification."""

    def __init__(self, n_class, n_in, n_hidden, dropout=0.2):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(n_in, n_hidden)
        # Use plain Dropout (not Dropout2d) on fully-connected activations,
        # and respect the dropout argument instead of a hard-coded 0.2.
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.dropout2 = nn.Dropout(dropout)
        self.fc3 = nn.Linear(n_hidden, n_class)

    def forward(self, x):
        # Flatten each 28x28 image into a 784-dimensional vector.
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return F.log_softmax(x, dim=-1)


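# Illustrative shape check (names here are only for illustration): the
# network maps a batch of flattened images to one log-probability per
# class, so a batch of 4 inputs should yield a (4, 10) output.
#
#     net = Net(10, 28 * 28, 32)
#     fake_batch = torch.zeros(4, 28 * 28)
#     assert net(fake_batch).shape == (4, 10)

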
def load_mnist():
    from thinc.backends import NumpyOps

    ops = NumpyOps()
    mnist_train, mnist_dev, _ = ml_datasets.mnist()
    train_X, train_Y = ops.unzip(mnist_train)
    dev_X, dev_Y = ops.unzip(mnist_dev)
    # F.nll_loss expects integer class labels, so cast to int64.
    train_Y = train_Y.astype("int64")
    dev_Y = dev_Y.astype("int64")
    return (train_X, train_Y), (dev_X, dev_Y)


def evaluate(model, dev_X, dev_Y, batch_size):
    """Report average loss and accuracy on the dev set."""
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for i in range(0, len(dev_X), batch_size):
            data = xp2torch(dev_X[i : i + batch_size])
            target = xp2torch(dev_Y[i : i + batch_size])
            output = model(data)
            # Sum the batch losses so we can average over the whole dev set.
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            # The prediction is the index of the highest log-probability.
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(dev_X)

    print('\nDev set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(dev_X),
        100. * correct / len(dev_X)))


def main(n_hidden=32, dropout=0.2, batch_size=128, n_epoch=10):
    torch.set_num_threads(1)
    (train_X, train_Y), (dev_X, dev_Y) = load_mnist()
    model = Net(10, 28 * 28, n_hidden, dropout=dropout)
    optimizer = optim.Adam(model.parameters())

    for epoch in range(n_epoch):
        model.train()
        train_batches = list(get_shuffled_batches(train_X, train_Y, batch_size))
        for images, true_labels in tqdm.tqdm(train_batches):
            images = xp2torch(images)
            true_labels = xp2torch(true_labels)

            optimizer.zero_grad()
            guess_labels = model(images)
            # The model outputs log-probabilities, so train with the
            # negative log-likelihood loss.
            loss = F.nll_loss(guess_labels, true_labels)
            loss.backward()
            optimizer.step()
        evaluate(model, dev_X, dev_Y, batch_size)


if __name__ == '__main__':
    import plac
    plac.call(main)
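
# Usage sketch: plac maps main's keyword arguments onto the command line,
# so the defaults can be overridden when invoking the script (the file
# name below is hypothetical):
#
#     python pytorch_mnist.py
#
# Alternatively, call main directly from Python:
#
#     main(n_hidden=64, dropout=0.5, batch_size=256, n_epoch=5)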