diff --git a/lpips/__init__.py b/lpips/__init__.py index 9dc89877..89e5bba2 100755 --- a/lpips/__init__.py +++ b/lpips/__init__.py @@ -73,6 +73,7 @@ def load_image(path): import cv2 return cv2.imread(path)[:,:,::-1] else: + import matplotlib.pyplot as plt img = (255*plt.imread(path)[:,:,:3]).astype('uint8') return img