Skip to content

Commit

Permalink
Add custom color & thickness for predicted boxes
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwhite committed Aug 9, 2021
1 parent 6a41dc9 commit 7653161
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 17 deletions.
20 changes: 14 additions & 6 deletions deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def val_dataloader(self):

return loader

def predict_image(self, image=None, path=None, return_plot=False):
def predict_image(self, image=None, path=None, return_plot=False, color=None, thickness=1):
"""Predict a single image with a deepforest model
Args:
Expand Down Expand Up @@ -242,7 +242,9 @@ def predict_image(self, image=None, path=None, return_plot=False):
image=image,
return_plot=return_plot,
device=self.current_device,
iou_threshold=self.config["nms_thresh"])
iou_threshold=self.config["nms_thresh"],
color=color,
thickness=thickness)

#Set labels to character from numeric if returning boxes df
if not return_plot:
Expand All @@ -251,7 +253,7 @@ def predict_image(self, image=None, path=None, return_plot=False):

return result

def predict_file(self, csv_file, root_dir, savedir=None):
def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1):
"""Create a dataset and predict entire annotation file
Csv file format is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax" for the image name and bounding box position.
Expand All @@ -273,7 +275,9 @@ def predict_file(self, csv_file, root_dir, savedir=None):
root_dir=root_dir,
savedir=savedir,
device=self.current_device,
iou_threshold=self.config["nms_thresh"])
iou_threshold=self.config["nms_thresh"],
color=color,
thickness=thickness)

#Set labels to character from numeric
result["label"] = result.label.apply(lambda x: self.numeric_to_label_dict[x])
Expand All @@ -289,7 +293,9 @@ def predict_tile(self,
return_plot=False,
use_soft_nms=False,
sigma=0.5,
thresh=0.001):
thresh=0.001,
color=None,
thickness=1):
"""For images too large to input into the model, predict_tile cuts the
image into overlapping windows, predicts trees on each window and
reassambles into a single array.
Expand Down Expand Up @@ -331,7 +337,9 @@ def predict_tile(self,
use_soft_nms=use_soft_nms,
sigma=sigma,
thresh=thresh,
device=self.current_device)
device=self.current_device,
color=color,
thickness=thickness)

#edge case, if no boxes predictioned return None
if result is None:
Expand Down
14 changes: 8 additions & 6 deletions deepforest/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from deepforest import visualize
from deepforest import dataset

def predict_image(model, image, return_plot, device, iou_threshold=0.1):
def predict_image(model, image, return_plot, device, iou_threshold=0.1, color=None, thickness=1):
"""Predict an image with a deepforest model
Args:
Expand Down Expand Up @@ -54,14 +54,14 @@ def predict_image(model, image, return_plot, device, iou_threshold=0.1):
image = np.rollaxis(image, 0, 3)
image = image[:,:,::-1] * 255
image = image.astype("uint8")
image = visualize.plot_predictions(image, df)
image = visualize.plot_predictions(image, df, color=color, thickness=thickness)

return image
else:
return df


def predict_file(model, csv_file, root_dir, savedir, device, iou_threshold=0.1):
def predict_file(model, csv_file, root_dir, savedir, device, iou_threshold=0.1, color=(0,165,255), thickness=1):
"""Create a dataset and predict entire annotation file
Csv file format is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax" for the image name and bounding box position.
Expand Down Expand Up @@ -109,7 +109,7 @@ def predict_file(model, csv_file, root_dir, savedir, device, iou_threshold=0.1):
#Plot annotations if they exist
annotations = df[df.image_path == paths[index]]

image = visualize.plot_predictions(image, annotations, color=(0,165,255))
image = visualize.plot_predictions(image, annotations, color=color, thickness=thickness)
cv2.imwrite("{}/{}.png".format(savedir, os.path.splitext(paths[index])[0]), image)


Expand All @@ -131,7 +131,9 @@ def predict_tile(model,
return_plot=False,
use_soft_nms=False,
sigma=0.5,
thresh=0.001):
thresh=0.001,
color=None,
thickness=1):
"""For images too large to input into the model, predict_tile cuts the
image into overlapping windows, predicts trees on each window and
reassambles into a single array.
Expand Down Expand Up @@ -235,7 +237,7 @@ def predict_tile(model,
if return_plot:
# Draw predictions on BGR
image = image[:,:,::-1]
image = visualize.plot_predictions(image, mosaic_df)
image = visualize.plot_predictions(image, mosaic_df, color=color, thickness=thickness)
# Mantain consistancy with predict_image
return image
else:
Expand Down
12 changes: 7 additions & 5 deletions deepforest/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import pandas.api.types as ptypes
import cv2
import random
import warnings

def view_dataset(ds, savedir=None):
def view_dataset(ds, savedir=None, color=None, thickness=1):
"""Plot annotations on images for debugging purposes
Args:
ds: a deepforest pytorch dataset, see deepforest.dataset or deepforest.load_dataset() to start from a csv file
Expand All @@ -20,7 +21,7 @@ def view_dataset(ds, savedir=None):
image_path, image, targets = i
df = format_boxes(targets[0], scores=False)
image = np.moveaxis(image[0].numpy(),0,2)
image = plot_predictions(image, df)
image = plot_predictions(image, df, color=color, thickness=thickness)

if savedir:
cv2.imwrite("{}/{}".format(savedir, image_path[0]), image)
Expand Down Expand Up @@ -91,7 +92,7 @@ def plot_prediction_dataframe(df, root_dir, ground_truth=None, savedir=None):

return written_figures

def plot_predictions(image, df, color=None):
def plot_predictions(image, df, color=None, thickness=1):
"""Plot a set of boxes on an image
By default this function does not show, but only plots an axis
Label column must be numeric!
Expand All @@ -110,12 +111,13 @@ def plot_predictions(image, df, color=None):
image = image.copy()
if not color:
if not ptypes.is_numeric_dtype(df.label):
raise ValueError("Label column is not numeric, please convert to numeric to correctly color image {}".format(df.label.head()))
warnings.warn("No color was provided and the label column is not numeric. Using a single default color.")
color=(0,165,255)

for index, row in df.iterrows():
if not color:
color = label_to_color(row["label"])
cv2.rectangle(image, (int(row["xmin"]), int(row["ymin"])), (int(row["xmax"]), int(row["ymax"])), color=color, thickness=1, lineType=cv2.LINE_AA)
cv2.rectangle(image, (int(row["xmin"]), int(row["ymin"])), (int(row["xmax"]), int(row["ymax"])), color=color, thickness=thickness, lineType=cv2.LINE_AA)

return image

Expand Down

0 comments on commit 7653161

Please sign in to comment.