Skip to content

Commit

Permalink
根据权重名称 判断加载 res50
Browse files Browse the repository at this point in the history
  • Loading branch information
ldcah committed Jun 30, 2020
1 parent ab81602 commit a4d62a2
Showing 1 changed file with 5 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from PIL import Image
from torchvision import transforms
from model import resnet34
from model import resnet34, resnet50


class MyModel4Prdict():
Expand All @@ -27,7 +27,10 @@ def __init__(self, model_weight_path, json_file):
self.class_indict = json.load(json_filef)

# create model
self.model = resnet34(num_classes=len(self.class_indict))
if (str.find(model_weight_path, "res50.pth") > 0):
self.model = resnet50(num_classes=len(self.class_indict)) # res50
else:
self.model = resnet34(num_classes=len(self.class_indict))
# load model weights
self.model.load_state_dict(torch.load(model_weight_path))
self.model.eval()
Expand Down

0 comments on commit a4d62a2

Please sign in to comment.