[toc]
- 继承
torch.utils.data.Dataset
- 实现
__getitem__
和__len__
两个magic methods 【个人倾向于返回字典形式】 - 理解
MNIST
类,以及transforms
模块 - 利用
torchvision.datasets
中的数据集 - 理解
ImageFolder
类及其classes
与class_to_idx
属性
- 利用
torch.utils.data.DataLoader
类 - 理解
__iter__
这个magic method - 区分
Dataloader
与Dataset
的__len__
- 利用 内置函数
enumerate
与tqdm
模块 - 有需要可以更改
collate_fn
- 继承
torch.nn.Module
,注意super().__init__()
- 理解
__call__
这个magic method 与自定义forward
关系 - 注意
PyTorch
中数据的摆放(B, C, H ,W)
- 调用
torchvison.models
中现成的网络 - 注意
torch.nn.Module.dict_state()
torch.save()
torch.load()
以及torch.nn.Module.load_state_dict()
及其中参数 - 利用
torch.utils.model_zoo.load_url()
下载预训练参数
- 调用
torch.optim
模块中的优化器 - 注意
params
参数 - 通过
optimizer.zero_grad()
loss.backward()
optimizer.step()
开始训练
-
综上所述,完成训练!
-
美化代码,下次一定!