added a learning rate finder to get the best learning rate
DeepLearning VM committed Jan 10, 2020
1 parent 1da2405 commit 9ac8524
Showing 1 changed file with 11 additions and 0 deletions.
src/Main.py (+11 −0)
@@ -3,10 +3,12 @@
import os, time, torch

import numpy as np
+import matplotlib.pyplot as plt

from torchvision import models, transforms
from torch import nn, optim
from torch.utils.data import DataLoader, SubsetRandomSampler
+from torch_lr_finder import LRFinder
from Models.Trainer import Trainer
from Data.SnakeDataset import SnakeDataset

@@ -45,6 +47,12 @@ def train_test(trainer, path):
    trainer.train(path)
    #trainer.evaluate(path)

+def find_lr(model, optimizer, criterion, save_folder):
+    lr_finder = LRFinder(model, optimizer, criterion)
+    lr_finder.range_test(data_loaders["train"], end_lr=100, num_iter=100)
+    lr_finder.plot()
+    plt.savefig(save_folder + "/LRvsLoss.png")
+
# Pass in a dictionary for data_map like the following (note that for full datasets positions DON'T have to be included):
#data_map = {
#"train": [path_to_data, path_to_csv, position_start, position_end],
@@ -82,6 +90,7 @@ def get_loaders(data_map, transforms=transforms.ToTensor(), shuffle=True, batch_
    nn.Linear(1000, 85)
)
optimizer = optim.AdamW(model.parameters())
find_lr(model, optimizer, criterion, "Saved/MobileNetV2 - Retrained")
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, 1, epochs=100, steps_per_epoch=len(data_loaders["train"]))
trainer = Trainer(model, image_transforms, criterion, optimizer, scheduler, "Saved/MobileNetV2 - Retrained/Model.tar", data_loaders)
train_test(trainer, "Saved/MobileNetV2 - Retrained")
@@ -91,6 +100,7 @@ def get_loaders(data_map, transforms=transforms.ToTensor(), shuffle=True, batch_
model = models.squeezenet1_1(pretrained=True)
model.classifier[1] = nn.Conv2d(512, 85, (1, 1), (1, 1))
optimizer = optim.AdamW(model.parameters())
find_lr(model, optimizer, criterion, "Saved/SqueezeNet - Subset - Retrained")
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, 1, epochs=100, steps_per_epoch=len(data_loaders["train"]))
trainer = Trainer(model, image_transforms, criterion, optimizer, scheduler, "Saved/SqueezeNet - Subset - Retrained/Model.tar", data_loaders)
train_test(trainer, "Saved/SqueezeNet - Subset - Retrained")
@@ -105,6 +115,7 @@ def get_loaders(data_map, transforms=transforms.ToTensor(), shuffle=True, batch_
    nn.Linear(1000, 85)
)
optimizer = optim.AdamW(model.parameters())
find_lr(model, optimizer, criterion, "Saved/ResNet50 - Subset - Retrained")
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, 1, epochs=100, steps_per_epoch=len(data_loaders["train"]))
trainer = Trainer(model, image_transforms, criterion, optimizer, scheduler, "Saved/ResNet50 - Subset - Retrained/Model.tar", data_loaders)
train_test(trainer, "Saved/ResNet50 - Subset - Retrained")
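
One caveat the commit leaves implicit: LRFinder.range_test updates the model and optimizer while it sweeps learning rates, so the training runs above start from that perturbed state, and OneCycleLR still uses a hardcoded max_lr of 1. Below is a minimal sketch of one way to close the loop, reusing the imports, model, optimizer, criterion, and data_loaders already set up in Main.py; the steepest-slope heuristic and the name suggested_lr are illustrative assumptions, not part of the commit.

lr_finder = LRFinder(model, optimizer, criterion)
lr_finder.range_test(data_loaders["train"], end_lr=100, num_iter=100)

# history records one "lr" and one "loss" value per range-test iteration
lrs = np.array(lr_finder.history["lr"])
losses = np.array(lr_finder.history["loss"])

# Assumed heuristic: pick the LR at which the loss is falling fastest
suggested_lr = lrs[np.gradient(losses).argmin()]

# range_test modified the weights; reset() restores model and optimizer
lr_finder.reset()

scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=suggested_lr,
                                          epochs=100, steps_per_epoch=len(data_loaders["train"]))

With reset() in place the finder run leaves no trace on the weights, and the measured value replaces the hardcoded max_lr.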
