Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
wz authored and wz committed Sep 4, 2020
1 parent 97aa5cd commit f081e72
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ flower_data
*.config
*.gz
*.onnx
*.xml
*.bin
*.mapping
checkpoint
data
VOCdevkit
Expand Down
4 changes: 3 additions & 1 deletion pytorch_classification/Test5_resnet/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import matplotlib.pyplot as plt
import json

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_transform = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
Expand All @@ -31,7 +33,7 @@
model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
model.load_state_dict(torch.load(model_weight_path, map_location=device))
model.eval()
with torch.no_grad():
# predict class
Expand Down

0 comments on commit f081e72

Please sign in to comment.