Skip to content

Commit

Permalink
Prediction: implement load model from feature proto
Browse files Browse the repository at this point in the history
  • Loading branch information
kechxu committed Jan 5, 2019
1 parent d7e0c97 commit 9f9f192
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 30 deletions.
2 changes: 2 additions & 0 deletions modules/prediction/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ cc_library(
"-DMODULE_NAME=\\\"prediction\\\"",
],
deps = [
"//cyber/common:file",
"//modules/common/adapters:adapter_gflags",
"//modules/prediction/evaluator:evaluator_manager",
"//modules/prediction/predictor:predictor_manager",
"//modules/prediction/proto:offline_features_proto",
"//modules/prediction/scenario:scenario_manager",
"//modules/prediction/util:data_extraction",
],
Expand Down
4 changes: 4 additions & 0 deletions modules/prediction/container/obstacles/obstacle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ void Obstacle::Insert(const PerceptionObstacle& perception_obstacle,
Trim();
}

void Obstacle::InsertFeature(const Feature& feature) {
InsertFeatureToHistory(feature);
}

bool Obstacle::IsInJunction(const std::string& junction_id) {
// TODO(all) Consider if need to use vehicle front rather than position
if (feature_history_.size() == 0) {
Expand Down
6 changes: 6 additions & 0 deletions modules/prediction/container/obstacles/obstacle.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ class Obstacle {
void Insert(const perception::PerceptionObstacle& perception_obstacle,
const double timestamp);

/**
* @brief Insert a feature proto message.
* @param feature proto message.
*/
void InsertFeature(const Feature& feature);

/**
* @brief Get the type of perception obstacle's type.
* @return The type pf perception obstacle.
Expand Down
15 changes: 15 additions & 0 deletions modules/prediction/container/obstacles/obstacles_container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,21 @@ void ObstaclesContainer::InsertPerceptionObstacle(
}
}

void ObstaclesContainer::InsertFeatureProto(const Feature& feature) {
if (!feature.has_id()) {
AERROR << "Invalid feature, no ID found.";
return;
}
int id = feature.id();
Obstacle* obstacle_ptr = obstacles_.Get(id);
if (obstacle_ptr != nullptr) {
obstacle_ptr->InsertFeature(feature);
} else {
Obstacle obstacle;
obstacle.InsertFeature(feature);
obstacles_.Put(id, std::move(obstacle));
}
}

void ObstaclesContainer::BuildLaneGraph() {
// Go through every obstacle in the current frame, after some
Expand Down
6 changes: 6 additions & 0 deletions modules/prediction/container/obstacles/obstacles_container.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ class ObstaclesContainer : public Container {
const perception::PerceptionObstacle& perception_obstacle,
const double timestamp);

/**
* @brief Insert a feature proto message into the container
* @param feature proto message
*/
void InsertFeatureProto(const Feature& feature);

/**
* @brief Build lane graph for obstacles
*/
Expand Down
64 changes: 34 additions & 30 deletions modules/prediction/evaluator/evaluator_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ void EvaluatorManager::Run(
AdapterConfig::PERCEPTION_OBSTACLES);
CHECK_NOTNULL(container);

Evaluator* evaluator = nullptr;
for (const auto& perception_obstacle :
perception_obstacles.perception_obstacle()) {
if (!perception_obstacle.has_id()) {
Expand All @@ -132,39 +131,44 @@ void EvaluatorManager::Run(
continue;
}

switch (perception_obstacle.type()) {
case PerceptionObstacle::VEHICLE: {
if (obstacle->HasJunctionFeatureWithExits() &&
!obstacle->IsClosedToJunctionExit()) {
evaluator = GetEvaluator(vehicle_in_junction_evaluator_);
CHECK_NOTNULL(evaluator);
} else if (obstacle->IsOnLane()) {
evaluator = GetEvaluator(vehicle_on_lane_evaluator_);
CHECK_NOTNULL(evaluator);
}
break;
}
case PerceptionObstacle::BICYCLE: {
if (obstacle->IsOnLane()) {
evaluator = GetEvaluator(cyclist_on_lane_evaluator_);
CHECK_NOTNULL(evaluator);
}
break;
}
case PerceptionObstacle::PEDESTRIAN: {
break;
EvaluateObstacle(obstacle);
}
}

void EvaluatorManager::EvaluateObstacle(Obstacle* obstacle) {
Evaluator* evaluator = nullptr;
switch (obstacle->type()) {
case PerceptionObstacle::VEHICLE: {
if (obstacle->HasJunctionFeatureWithExits() &&
!obstacle->IsClosedToJunctionExit()) {
evaluator = GetEvaluator(vehicle_in_junction_evaluator_);
CHECK_NOTNULL(evaluator);
} else if (obstacle->IsOnLane()) {
evaluator = GetEvaluator(vehicle_on_lane_evaluator_);
CHECK_NOTNULL(evaluator);
}
default: {
if (obstacle->IsOnLane()) {
evaluator = GetEvaluator(default_on_lane_evaluator_);
CHECK_NOTNULL(evaluator);
}
break;
break;
}
case PerceptionObstacle::BICYCLE: {
if (obstacle->IsOnLane()) {
evaluator = GetEvaluator(cyclist_on_lane_evaluator_);
CHECK_NOTNULL(evaluator);
}
break;
}
if (evaluator != nullptr) {
evaluator->Evaluate(obstacle);
case PerceptionObstacle::PEDESTRIAN: {
break;
}
default: {
if (obstacle->IsOnLane()) {
evaluator = GetEvaluator(default_on_lane_evaluator_);
CHECK_NOTNULL(evaluator);
}
break;
}
}
if (evaluator != nullptr) {
evaluator->Evaluate(obstacle);
}
}

Expand Down
2 changes: 2 additions & 0 deletions modules/prediction/evaluator/evaluator_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class EvaluatorManager {
*/
void Run(const perception::PerceptionObstacles& perception_obstacles);

void EvaluateObstacle(Obstacle* obstacle);

private:
/**
* @brief Register an evaluator by type
Expand Down
18 changes: 18 additions & 0 deletions modules/prediction/prediction_component.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,19 @@
#include <algorithm>
#include <vector>

#include "cyber/common/file.h"
#include "cyber/record/record_reader.h"
#include "modules/common/adapters/adapter_gflags.h"
#include "modules/common/util/message_util.h"

#include "modules/prediction/common/feature_output.h"
#include "modules/prediction/common/junction_analyzer.h"
#include "modules/prediction/common/prediction_gflags.h"
#include "modules/prediction/common/prediction_system_gflags.h"
#include "modules/prediction/common/validation_checker.h"
#include "modules/prediction/evaluator/evaluator_manager.h"
#include "modules/prediction/predictor/predictor_manager.h"
#include "modules/prediction/proto/offline_features.pb.h"
#include "modules/prediction/scenario/scenario_manager.h"
#include "modules/prediction/util/data_extraction.h"

Expand Down Expand Up @@ -73,6 +76,21 @@ void PredictionComponent::ProcessOfflineData(const std::string& filename) {
}
}

void PredictionComponent::OfflineProcessFeatureProtoFile(
const std::string& features_proto_file_name) {
auto obstacles_container_ptr = ContainerManager::Instance()->GetContainer<
ObstaclesContainer>(AdapterConfig::PERCEPTION_OBSTACLES);
obstacles_container_ptr->Clear();
Features features;
apollo::cyber::common::GetProtoFromBinaryFile(
features_proto_file_name, &features);
for (const Feature& feature : features.feature()) {
obstacles_container_ptr->InsertFeatureProto(feature);
Obstacle* obstacle_ptr = obstacles_container_ptr->GetObstacle(feature.id());
EvaluatorManager::Instance()->EvaluateObstacle(obstacle_ptr);
}
}

bool PredictionComponent::Init() {
component_start_time_ = Clock::NowInSeconds();

Expand Down
6 changes: 6 additions & 0 deletions modules/prediction/prediction_component.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ class PredictionComponent
bool Proc(
const std::shared_ptr<perception::PerceptionObstacles> &) override;

/**
* @brief Load and process feature proto file.
* @param a bin file including a sequence of feature proto.
*/
void OfflineProcessFeatureProtoFile(const std::string& features_proto_file);

private:
void OnLocalization(const localization::LocalizationEstimate &localization);

Expand Down

0 comments on commit 9f9f192

Please sign in to comment.