Skip to content

Commit

Permalink
[luci] Introduce DeadNodeQueryService (Samsung#642)
Browse files Browse the repository at this point in the history
This commit introduce DeadNodeQueryService to luci

ONE-DCO-1.0-Signed-off-by: seongwoo <[email protected]>
  • Loading branch information
mhs4670go authored May 11, 2020
1 parent f971386 commit 1a32552
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 2 deletions.
2 changes: 2 additions & 0 deletions compiler/luci/lang/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ target_include_directories(luci_lang PRIVATE src)
target_include_directories(luci_lang PUBLIC include)
target_link_libraries(luci_lang PUBLIC loco)
target_link_libraries(luci_lang PUBLIC oops)
target_link_libraries(luci_lang PRIVATE logo)
target_link_libraries(luci_lang PRIVATE nncc_common)

install(TARGETS luci_lang DESTINATION lib)
Expand All @@ -20,3 +21,4 @@ nnas_find_package(GTest REQUIRED)
GTest_AddTest(luci_lang_test ${TESTS})
target_include_directories(luci_lang_test PRIVATE src)
target_link_libraries(luci_lang_test luci_lang)
target_link_libraries(luci_lang_test logo)
27 changes: 27 additions & 0 deletions compiler/luci/lang/src/CircleDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include <loco/IR/GraphInputIndex.h>
#include <loco/IR/GraphOutputIndex.h>

#include <logo/DeadNodeQueryService.h>

#include <cassert>
#include <memory>

Expand Down Expand Up @@ -68,6 +70,30 @@ struct GoiQueryServiceImpl final : public loco::GraphOutputIndexQueryService
}
};

struct DeadNodeQueryServiceImpl final : public logo::DeadNodeQueryService
{
bool isDeadNode(loco::Node *node) final
{
auto g = node->graph();
auto input_nodes_vec = loco::input_nodes(g);
auto output_nodes_vec = loco::output_nodes(g);

auto input_nodes = std::set<loco::Node *>(input_nodes_vec.begin(), input_nodes_vec.end());
auto output_nodes = std::set<loco::Node *>(output_nodes_vec.begin(), output_nodes_vec.end());
auto active_nodes = loco::active_nodes(output_nodes_vec);

if (active_nodes.find(node) != active_nodes.end())
return false;
// input and output nodes are not dead node even if it is not active.
if (input_nodes.find(node) != input_nodes.end())
return false;
if (output_nodes.find(node) != output_nodes.end())
return false;

return true;
}
};

} // namespace

namespace luci
Expand All @@ -77,6 +103,7 @@ CircleDialect::CircleDialect()
{
service<loco::GraphInputIndexQueryService>(std::make_unique<GiiQueryServiceImpl>());
service<loco::GraphOutputIndexQueryService>(std::make_unique<GoiQueryServiceImpl>());
service<logo::DeadNodeQueryService>(std::make_unique<DeadNodeQueryServiceImpl>());
}

loco::Dialect *CircleDialect::get(void)
Expand Down
58 changes: 58 additions & 0 deletions compiler/luci/lang/src/CircleDialect.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
*/

#include "luci/IR/CircleDialect.h"
#include "luci/IR/CircleNodes.h"

#include <loco.h>
#include <logo/DeadNodeQueryService.h>

#include <gtest/gtest.h>

Expand All @@ -27,3 +31,57 @@ TEST(CircleDialectTest, get_P)
// The return value SHOULD be stable across multiple invocations
ASSERT_EQ(luci::CircleDialect::get(), d);
}

TEST(CircleDialectTest, check_if_dead_node_service)
{
/**
* [CircleInput1] [CircleInput2] [CircleInput3]
* \ / (dangling input)
* \ /
* [CircleAdd] [CircleBatchMatMul]
* | (dangling node)
* |
* [CircleOutput1] [CircleOutput2]
* (dangling output)
*/
auto g = loco::make_graph();

auto graph_input1 = g->inputs()->create();
auto circle_input1 = g->nodes()->create<luci::CircleInput>();
circle_input1->index(graph_input1->index());

auto graph_input2 = g->inputs()->create();
auto circle_input2 = g->nodes()->create<luci::CircleInput>();
circle_input2->index(graph_input2->index());

// dangling output
auto graph_input3 = g->inputs()->create();
auto dangling_input = g->nodes()->create<luci::CircleInput>();
dangling_input->index(graph_input3->index());

auto active_node = g->nodes()->create<luci::CircleAdd>();
active_node->x(circle_input1);
active_node->y(circle_input2);

auto dangling_node = g->nodes()->create<luci::CircleBatchMatMul>();

auto graph_output1 = g->outputs()->create();
auto circle_output1 = g->nodes()->create<luci::CircleOutput>();
circle_output1->index(graph_output1->index());
circle_output1->from(active_node);

// dangling output
auto graph_output2 = g->outputs()->create();
auto circle_output2 = g->nodes()->create<luci::CircleOutput>();
circle_output2->index(graph_output2->index());

auto service = active_node->dialect()->service<logo::DeadNodeQueryService>();

ASSERT_TRUE(service->isDeadNode(dangling_node));
ASSERT_FALSE(service->isDeadNode(dangling_input));
ASSERT_FALSE(service->isDeadNode(active_node));
ASSERT_FALSE(service->isDeadNode(circle_input1));
ASSERT_FALSE(service->isDeadNode(circle_input2));
ASSERT_FALSE(service->isDeadNode(circle_output1));
ASSERT_FALSE(service->isDeadNode(circle_output2));
}
4 changes: 2 additions & 2 deletions compiler/luci/pass/src/CircleOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include "luci/Pass/TypeInferencePass.h"

// logo passes
#include <logo/RemoveDeadNodePass.h>
#include <logo/RemoveDeadNodeWithQueryPass.h>

#include "ProgressReporter.h"

Expand Down Expand Up @@ -89,7 +89,7 @@ void CircleOptimizer::optimize(loco::Graph *g) const
// Shape inference is needed for added nodes doing above transformations
phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
phase.emplace_back(std::make_unique<logo::RemoveDeadNodePass>());
phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
/* TRANSFORM DECLARATION END */

ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/requires.cmake
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
require("loco")
require("locop")
require("logo")
require("logo-core")
require("mio-circle")
require("oops")
Expand Down

0 comments on commit 1a32552

Please sign in to comment.