Skip to content

Commit

Permalink
revise
Browse files Browse the repository at this point in the history
  • Loading branch information
ssssww0905 committed Nov 19, 2021
1 parent bbab582 commit 963e2d5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
12 changes: 6 additions & 6 deletions 1_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,9 @@ def eg_1_2_0():
print("type(train_dataset): {}".format(type(train_dataset))) # <class 'torchvision.datasets.mnist.MNIST'>
index = 0
print("train_dataset[{}]: {}".format(index, train_dataset[index])) # (PIL.Image.Image, 5)
print("len(train_dataset): {}".format(len(train_dataset)))

import matplotlib.pyplot as plt
plt.imshow(train_dataset[index][0], cmap ='gray')
print("len(train_dataset): {}".format(len(train_dataset)))


def eg_1_2_1():
Expand Down Expand Up @@ -130,13 +129,14 @@ def eg_1_4_0():
]
)
train_dataset = ImageFolder(root=os.path.join("./flower_data", "train"),
transform=transform)
transform=transform, target_transform=None)

index = 0
print("type(train_dataset[{}]): {}".format(index, type(train_dataset[index]))) # <class 'tuple'>
print("type(train_dataset[{}][0]): {}".format(index, type(train_dataset[index][0]))) # <class 'torch.Tensor'>
print("train_dataset[{}][0].shape: {}".format(index, train_dataset[index][0].shape)) # torch.Size([3, 224, 224])
print("type(train_dataset[{}][1]): {}".format(index, type(train_dataset[index][1]))) # <class 'int'>
print("train_dataset[{}][1]: {}".format(index, train_dataset[index][1])) # 0


def eg_1_4_1():
Expand All @@ -155,12 +155,12 @@ def eg_1_4_1():
]
)
train_dataset = ImageFolder(root=os.path.join("./flower_data", "train"),
transform=transform)
transform=transform, target_transform=None)

print("train_dataset.classes: {}".format(train_dataset.classes)) # ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
print("train_dataset.class_to_idx: {}".format(train_dataset.class_to_idx)) # {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}


from torch.utils.data import Dataset
if __name__ == "__main__":
"""
1.0 torch.utils.data.Dataset
Expand All @@ -177,7 +177,7 @@ def eg_1_4_1():
# eg_1_2_1()
# eg_1_3()
# eg_1_4_0()
# eg_1_4_1()
eg_1_4_1()

print("~~~~~~下课~~~~~~")

Expand Down
12 changes: 7 additions & 5 deletions 2_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

def eg_2_1():
"""
Eg2.1 : __iter__
Eg2.1 : __iter__ [magic method]
"""
from torch.utils.data import DataLoader
train_loader = DataLoader(dataset=train_dataset,
Expand All @@ -48,7 +48,7 @@ def eg_2_1():
print("type(batch): {}".format(type(batch))) # <class 'list'>
print("len(batch): {}".format(len(batch))) # 2
print("type(batch[0]): {}".format(type(batch[0]))) # <class 'torch.Tensor'>
print("type(batch[1]): {}".format(type(batch[0]))) # <class 'torch.Tensor'>
print("type(batch[1]): {}".format(type(batch[1]))) # <class 'torch.Tensor'>
print("batch[0].shape: {}".format(batch[0].shape)) # torch.Size([10000, 1, 28, 28])
print("batch[1].shape: {}".format(batch[1].shape)) # torch.Size([10000])
break
Expand Down Expand Up @@ -104,7 +104,9 @@ def eg_2_4():
Eg2.4 : collate_fn
"""
def collate_fn(batch):
print("type(batch): {}, len(batch): {}".format(type(batch), len(batch))) # <class 'list'>, 10000
print("type(batch): {}".format(type(batch))) # <class 'list'>
print("len(batch): {}".format(len(batch))) # 10000
print("type(batch[0]): {}".format(type(batch[0]))) # <class 'tuple'>
x = [i[0] for i in batch]
y = [i[1] for i in batch]
x = torch.cat(x)[:,None,...]
Expand All @@ -129,8 +131,8 @@ def collate_fn(batch):
if __name__ == "__main__":
"""
2.0 torch.utils.data.DataLoader https://pytorch.org/docs/stable/data.html
2.1 __iter__ [magic methods]
2.2 __len__ [magic methods]
2.1 __iter__ [magic method]
2.2 __len__ [magic method]
2.3.0 enumerate
2.3.1 tqdm
2.4 collate_fn
Expand Down

0 comments on commit 963e2d5

Please sign in to comment.