diff --git a/deepforest/main.py b/deepforest/main.py index 3619c9db..0f5147ac 100644 --- a/deepforest/main.py +++ b/deepforest/main.py @@ -204,7 +204,7 @@ def val_dataloader(self): return loader - def predict_image(self, image=None, path=None, return_plot=False): + def predict_image(self, image=None, path=None, return_plot=False, color=None, thickness=1): """Predict a single image with a deepforest model Args: @@ -242,7 +242,9 @@ def predict_image(self, image=None, path=None, return_plot=False): image=image, return_plot=return_plot, device=self.current_device, - iou_threshold=self.config["nms_thresh"]) + iou_threshold=self.config["nms_thresh"], + color=color, + thickness=thickness) #Set labels to character from numeric if returning boxes df if not return_plot: @@ -251,7 +253,7 @@ def predict_image(self, image=None, path=None, return_plot=False): return result - def predict_file(self, csv_file, root_dir, savedir=None): + def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1): """Create a dataset and predict entire annotation file Csv file format is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax" for the image name and bounding box position. @@ -273,7 +275,9 @@ def predict_file(self, csv_file, root_dir, savedir=None): root_dir=root_dir, savedir=savedir, device=self.current_device, - iou_threshold=self.config["nms_thresh"]) + iou_threshold=self.config["nms_thresh"], + color=color, + thickness=thickness) #Set labels to character from numeric result["label"] = result.label.apply(lambda x: self.numeric_to_label_dict[x]) @@ -289,7 +293,9 @@ def predict_tile(self, return_plot=False, use_soft_nms=False, sigma=0.5, - thresh=0.001): + thresh=0.001, + color=None, + thickness=1): """For images too large to input into the model, predict_tile cuts the image into overlapping windows, predicts trees on each window and reassambles into a single array. @@ -331,7 +337,9 @@ def predict_tile(self, use_soft_nms=use_soft_nms, sigma=sigma, thresh=thresh, - device=self.current_device) + device=self.current_device, + color=color, + thickness=thickness) #edge case, if no boxes predictioned return None if result is None: diff --git a/deepforest/predict.py b/deepforest/predict.py index 274959dc..95b38016 100644 --- a/deepforest/predict.py +++ b/deepforest/predict.py @@ -15,7 +15,7 @@ from deepforest import visualize from deepforest import dataset -def predict_image(model, image, return_plot, device, iou_threshold=0.1): +def predict_image(model, image, return_plot, device, iou_threshold=0.1, color=None, thickness=1): """Predict an image with a deepforest model Args: @@ -54,14 +54,14 @@ def predict_image(model, image, return_plot, device, iou_threshold=0.1): image = np.rollaxis(image, 0, 3) image = image[:,:,::-1] * 255 image = image.astype("uint8") - image = visualize.plot_predictions(image, df) + image = visualize.plot_predictions(image, df, color=color, thickness=thickness) return image else: return df -def predict_file(model, csv_file, root_dir, savedir, device, iou_threshold=0.1): +def predict_file(model, csv_file, root_dir, savedir, device, iou_threshold=0.1, color=(0,165,255), thickness=1): """Create a dataset and predict entire annotation file Csv file format is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax" for the image name and bounding box position. @@ -109,7 +109,7 @@ def predict_file(model, csv_file, root_dir, savedir, device, iou_threshold=0.1): #Plot annotations if they exist annotations = df[df.image_path == paths[index]] - image = visualize.plot_predictions(image, annotations, color=(0,165,255)) + image = visualize.plot_predictions(image, annotations, color=color, thickness=thickness) cv2.imwrite("{}/{}.png".format(savedir, os.path.splitext(paths[index])[0]), image) @@ -131,7 +131,9 @@ def predict_tile(model, return_plot=False, use_soft_nms=False, sigma=0.5, - thresh=0.001): + thresh=0.001, + color=None, + thickness=1): """For images too large to input into the model, predict_tile cuts the image into overlapping windows, predicts trees on each window and reassambles into a single array. @@ -235,7 +237,7 @@ def predict_tile(model, if return_plot: # Draw predictions on BGR image = image[:,:,::-1] - image = visualize.plot_predictions(image, mosaic_df) + image = visualize.plot_predictions(image, mosaic_df, color=color, thickness=thickness) # Mantain consistancy with predict_image return image else: diff --git a/deepforest/visualize.py b/deepforest/visualize.py index 1c376934..225d0f5c 100644 --- a/deepforest/visualize.py +++ b/deepforest/visualize.py @@ -9,8 +9,9 @@ import pandas.api.types as ptypes import cv2 import random +import warnings -def view_dataset(ds, savedir=None): +def view_dataset(ds, savedir=None, color=None, thickness=1): """Plot annotations on images for debugging purposes Args: ds: a deepforest pytorch dataset, see deepforest.dataset or deepforest.load_dataset() to start from a csv file @@ -20,7 +21,7 @@ def view_dataset(ds, savedir=None): image_path, image, targets = i df = format_boxes(targets[0], scores=False) image = np.moveaxis(image[0].numpy(),0,2) - image = plot_predictions(image, df) + image = plot_predictions(image, df, color=color, thickness=thickness) if savedir: cv2.imwrite("{}/{}".format(savedir, image_path[0]), image) @@ -91,7 +92,7 @@ def plot_prediction_dataframe(df, root_dir, ground_truth=None, savedir=None): return written_figures -def plot_predictions(image, df, color=None): +def plot_predictions(image, df, color=None, thickness=1): """Plot a set of boxes on an image By default this function does not show, but only plots an axis Label column must be numeric! @@ -110,12 +111,13 @@ def plot_predictions(image, df, color=None): image = image.copy() if not color: if not ptypes.is_numeric_dtype(df.label): - raise ValueError("Label column is not numeric, please convert to numeric to correctly color image {}".format(df.label.head())) + warnings.warn("No color was provided and the label column is not numeric. Using a single default color.") + color=(0,165,255) for index, row in df.iterrows(): if not color: color = label_to_color(row["label"]) - cv2.rectangle(image, (int(row["xmin"]), int(row["ymin"])), (int(row["xmax"]), int(row["ymax"])), color=color, thickness=1, lineType=cv2.LINE_AA) + cv2.rectangle(image, (int(row["xmin"]), int(row["ymin"])), (int(row["xmax"]), int(row["ymax"])), color=color, thickness=thickness, lineType=cv2.LINE_AA) return image