diff --git a/mnist_forward_forward/main.py b/mnist_forward_forward/main.py index 25bc187b29..7b651d3083 100644 --- a/mnist_forward_forward/main.py +++ b/mnist_forward_forward/main.py @@ -154,10 +154,10 @@ def train(self, x_pos, x_neg): ] ) train_loader = DataLoader( - MNIST("./data/", train=True, download=True, transform=transform), **train_kwargs + MNIST("../data/", train=True, download=True, transform=transform), **train_kwargs ) test_loader = DataLoader( - MNIST("./data/", train=False, download=True, transform=transform), **test_kwargs + MNIST("../data/", train=False, download=True, transform=transform), **test_kwargs ) net = Net([784, 500, 500]) x, y = next(iter(train_loader))