forked from Ligo-Biosciences/AlphaFold3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_datamodules.py
38 lines (28 loc) · 1.23 KB
/
test_datamodules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from pathlib import Path
import pytest
import torch
from src.data.mnist_datamodule import MNISTDataModule
@pytest.mark.parametrize("batch_size", [32, 128])
def test_mnist_datamodule(batch_size: int) -> None:
    """Walk `MNISTDataModule` through its lifecycle and check each stage.

    1. ``prepare_data()`` must create the ``MNIST`` and ``MNIST/raw``
       directories under the data dir while leaving the train/val/test
       splits unpopulated.
    2. ``setup()`` must populate all three splits and expose a dataloader
       for each; together the splits must hold exactly 70,000 samples.
    3. The first training batch must contain ``batch_size`` inputs and
       ``batch_size`` targets, with ``float32`` inputs and ``int64`` targets.

    :param batch_size: Batch size handed to the datamodule's dataloaders.
    """
    data_dir = "data/"
    mnist_root = Path(data_dir) / "MNIST"

    datamodule = MNISTDataModule(data_dir=data_dir, batch_size=batch_size)

    # Stage 1: preparing the data must not build any split yet.
    datamodule.prepare_data()
    assert not datamodule.data_train
    assert not datamodule.data_val
    assert not datamodule.data_test
    assert mnist_root.exists()
    assert (mnist_root / "raw").exists()

    # Stage 2: setup() materializes every split and its loader.
    datamodule.setup()
    assert datamodule.data_train
    assert datamodule.data_val
    assert datamodule.data_test
    assert datamodule.train_dataloader()
    assert datamodule.val_dataloader()
    assert datamodule.test_dataloader()

    # Together the three splits must account for all 70,000 samples.
    splits = (datamodule.data_train, datamodule.data_val, datamodule.data_test)
    assert sum(len(split) for split in splits) == 70_000

    # Stage 3: inspect a single training batch (an (inputs, targets) pair).
    inputs, targets = next(iter(datamodule.train_dataloader()))
    assert len(inputs) == batch_size
    assert len(targets) == batch_size
    assert inputs.dtype == torch.float32
    assert targets.dtype == torch.int64