Skip to content

Latest commit

 

History

History
 
 

01_手把手教你跑通第一个神经网络

手把手教你跑通第一个神经网络

参考 PYTORCH DOCUMENTATION

[toc]

# 1_dataset.py

  1. 继承 torch.utils.data.Dataset
  2. 实现 __getitem____len__ 两个magic methods 【个人倾向于返回字典形式】
  3. 理解 MNIST 类,以及 transforms 模块
  4. 利用 torchvision.datasets 中的数据集
  5. 理解 ImageFolder 类及其 classesclass_to_idx 属性

# 2_dataloader.py

  1. 利用 torch.utils.data.DataLoader
  2. 理解 __iter__ 这个magic method
  3. 区分 DataloaderDataset__len__
  4. 利用 内置函数 enumeratetqdm 模块
  5. 有需要可以更改 collate_fn

# 3_model.py

  1. 继承 torch.nn.Module,注意 super().__init__()
  2. 理解 __call__ 这个magic method 与自定义 forward 关系
  3. 注意 PyTorch 中数据的摆放 (B, C, H ,W)
  4. 调用 torchvison.models 中现成的网络
  5. 注意 torch.nn.Module.dict_state() torch.save() torch.load() 以及 torch.nn.Module.load_state_dict() 及其中参数
  6. 利用 torch.utils.model_zoo.load_url() 下载预训练参数

# 4_optimizer.py

  1. 调用 torch.optim 模块中的优化器
  2. 注意 params参数
  3. 通过 optimizer.zero_grad() loss.backward() optimizer.step() 开始训练

# 5_train.py

  • 综上所述,完成训练!

  • 美化代码,下次一定!