Skip to content

Commit

Permalink
Update README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz authored May 27, 2021
1 parent 38119dc commit 5174cbe
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,19 @@ pip install deepforest-pytorch

# Usage

## Train a model
# Use Benchmark release

```Python
from deepforest import main
m = main.deepforest()
m.use_release()
```

## Train a new model

```Python
m.create_trainer()
m.run_train()
m.trainer.fit(m)
m.evaluate(csv_file=m.config["validation"]["csv_file"], root_dir=m.config["validation"]["root_dir"])
```
[Google colab demo on model training](https://colab.research.google.com/drive/1AJUcw5dEpXeDPHd0sotAz5lpWedFYSIL#offline=true&sandboxMode=true)
Expand All @@ -47,7 +53,7 @@ df = trained_model.predict_file(csv_file, root_dir = os.path.dirname(csv_file))
## Predict a large tile

```Python
prediction = trained_model.predict_tile(raster_path = raster_path,
predicted_boxes = trained_model.predict_tile(raster_path = raster_path,
patch_size = 300,
patch_overlap = 0.5,
return_plot = False)
Expand All @@ -58,7 +64,7 @@ prediction = trained_model.predict_tile(raster_path = raster_path,
```Python
csv_file = get_data("example.csv")
root_dir = os.path.dirname(csv_file)
precision, recall = m.evaluate(csv_file, root_dir, iou_threshold = 0.5)
results = m.evaluate(csv_file, root_dir, iou_threshold = 0.5)
```

# Config
Expand All @@ -68,7 +74,7 @@ DeepForest comes with a default config file (deepforest_config.yml) to control t
```Python
from deepforest import main
m = main.deepforest()
m.config["train"]["batch_size"] = 10
m.config["batch_size"] = 10
```
Config parameters are documented [here](https://deepforest-pytorch.readthedocs.io/en/latest/ConfigurationFile.html).

Expand All @@ -82,6 +88,6 @@ cd NeonTreeEvaluation
```
```Python
results = m.evaluate(csv_file = "evaluation/RGB/benchmark_annotations.csv", root_dir = "evaluation/RGB/")
results["recall"]
results["precision"]
results["box_recall"]
results["box_precision"]
```

0 comments on commit 5174cbe

Please sign in to comment.