Skip to content

Commit

Permalink
update checker
Browse files Browse the repository at this point in the history
  • Loading branch information
aksnzhy committed Nov 4, 2017
1 parent 36dc2e2 commit 8fd41e3
Show file tree
Hide file tree
Showing 13 changed files with 477 additions and 397 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,4 @@ add_subdirectory(src/data)
add_subdirectory(src/reader)
add_subdirectory(src/score)
add_subdirectory(src/loss)
#add_subdirectory(src/solver)
add_subdirectory(src/solver)
85 changes: 85 additions & 0 deletions src/base/format_print.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
//------------------------------------------------------------------------------
// Copyright (c) 2016 by contributors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//------------------------------------------------------------------------------

/*
Author: Chao Ma ([email protected])
This file defines the facilities for format printing.
*/

#ifndef XLEARN_BASE_FORMAT_PRINT_H_
#define XLEARN_BASE_FORMAT_PRINT_H_

#include <iostream>
#include <vector>
#include <string>

#include "src/base/common.h"

typedef std::vector<std::string> StringList;
typedef std::vector<int> IntList;

//------------------------------------------------------------------------------
// Example:
//
// column -> "Name", "ID", "Count", "Price"
// width -> 10, 10, 10, 10
//
// Output:
//
// Name ID Count Price
// Fruit 0x101 50 5.27
// Juice 0x102 20 8.73
// Meat 0x104 30 10.13
//------------------------------------------------------------------------------
inline void print_row(const StringList& column, const IntList& width) {
CHECK_EQ(column.size(), width.size());
for (size_t i = 0; i < column.size(); ++i) {
std::cout.width(width[i]);
std::cout << column[i];
}
std::cout << "\n";
}

//------------------------------------------------------------------------------
// Example:
//
// std -> "Hello World !"
//
// Output:
//
// -----------------
// | Hello World ! |
// -----------------
//------------------------------------------------------------------------------
inline void print_block(const std::string& str) {
CHECK_NE(str.empty(), true);
// Add two space and two lines
size_t size = str.size() + 4;
for (size_t i = 0; i < size; ++i) {
std::cout << "-";
}
std::cout << "\n";
std::cout << "| ";
std::cout << str;
std::cout << " |";
for (size_t i = 0; i < size; ++i) {
std::cout << "-";
}
std::cout << "\n";
}

#endif // XLEARN_BASE_FORMAT_PRINT_H_
78 changes: 78 additions & 0 deletions src/base/system.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
//------------------------------------------------------------------------------
// Copyright (c) 2016 by contributors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//------------------------------------------------------------------------------

/*
Author: Chao Ma ([email protected])
This file defines several system functions.
*/

#ifndef XLEARN_BASE_SYSTEM_H_
#define XLEARN_BASE_SYSTEM_H_

#include <sys/utsname.h>
#include <unistd.h>

#include <string>

#include "src/base/common.h"
#include "src/base/stringprint.h"

// Get host name
std::string get_host_name() {
struct utsname buf;
if (0 != uname(&buf)) {
*buf.nodename = '\0';
}
return std::string(buf.nodename);
}

// Get user name
std::string get_user_name() {
const char* username = getenv("USER");
return username != NULL ? username : getenv("USERNAME");
}

// Get current system time
std::string print_current_time() {
time_t current_time = time(NULL);
struct tm broken_down_time;
CHECK(localtime_r(&current_time, &broken_down_time) == &broken_down_time);
return StringPrintf("%04d%02d%02d-%02d%02d%02d",
1900 + broken_down_time.tm_year,
1 + broken_down_time.tm_mon,
broken_down_time.tm_mday,
broken_down_time.tm_hour,
broken_down_time.tm_min,
broken_down_time.tm_sec);
}

// The log file name = base + host_name + username +
// date_time + process_id
std::string get_log_file() {
CHECK(!hyper_param_.log_file.empty());
std::string filename_prefix;
SStringPrintf(&filename_prefix,
"%s.%s.%s.%s.%u",
hyper_param_.log_file.c_str(),
get_host_name().c_str(),
get_user_name().c_str(),
print_current_time().c_str(),
getpid());
return filename_prefix;
}

#endif // XLEARN_BASE_SYSTEM_H_
68 changes: 0 additions & 68 deletions src/data/hyper_parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,74 +113,6 @@ struct HyperParam {
/* True for using early-stop and
False for not */
bool early_stop = true;

// Check and fix the conflict of hyper-parameters
bool CheckConflict(std::string& err_info) {
err_info.clear();
bool bo = true;
// Confict for on-disk training
if (this->on_disk) {
if (this->cross_validation) {
err_info += "[Warning] On-disk training doesn't support "
"cross-validation and xLearn will disable it. \n";
this->cross_validation = false;
bo = false;
}
}
// Conflict for cross-validation
if (this->cross_validation) {
if (this->early_stop) {
err_info += "[Warning] cross-validation doesn't support "
"early-stopping and xLearn will disable it. \n";
this->early_stop = false;
bo = false;
}
if (!this->validate_set_file.empty()) {
err_info += "[Warning] xLearn has already been set to use "
"cross-validation and will ignore the validation file. \n";
this->validate_set_file.clear();
bo = false;
}
if (this->quiet) {
err_info += "[Warning] Quiet training cannot be used under "
"cross-validation. \n";
this->quiet = false;
bo = false;
}
}
// Conflict for early-stop
if (this->early_stop) {
if (this->validate_set_file.empty()) {
err_info += "[Warning] The validation file cannot be empty when "
"setting early-stopping. \n";
this->early_stop = false;
bo = false;
}
}
// Conflict for metric
if (this->loss_func.compare("cross-entropy") == 0) {
if (this->metric.compare("mae") == 0 ||
this->metric.compare("rmsd") == 0 ||
this->metric.compare("mape") == 0) {
err_info += "[Warning] The " + this->metric + " can only be used "
"in regression tasks. Change it to -x acc .\n";
this->metric = "acc";
bo = false;
}
}
if (this->loss_func.compare("squared") == 0) {
if (this->metric.compare("acc") == 0 ||
this->metric.compare("prec") == 0 ||
this->metric.compare("recall") == 0 ||
this->metric.compare("f1") == 0) {
err_info += "[Warning] The " + this->metric + " can only be used "
"in classification tasks. Change it to -x mae .\n";
this->metric = "mae";
bo = false;
}
}
return bo;
}
};

} // namespace XLEARN
Expand Down
47 changes: 35 additions & 12 deletions src/loss/metric_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ TEST(AccMetricTest, acc_test) {
metric.Accumulate(Y, pred);
metric_val = metric.GetMetric();
EXPECT_FLOAT_EQ(metric_val, (1.0 / 4.0));
EXPECT_EQ(metric.metric_type(), "Accuarcy");
}

TEST(PrecMetricTest, prec_test) {
Expand Down Expand Up @@ -92,6 +93,7 @@ TEST(PrecMetricTest, prec_test) {
metric.Accumulate(Y, pred);
metric_val = metric.GetMetric();
EXPECT_FLOAT_EQ(metric_val, (1.0 / 3.0));
EXPECT_EQ(metric.metric_type(), "Precision");
}

TEST(RecallMetricTest, recall_test) {
Expand Down Expand Up @@ -124,6 +126,7 @@ TEST(RecallMetricTest, recall_test) {
metric.Accumulate(Y, pred);
metric_val = metric.GetMetric();
EXPECT_FLOAT_EQ(metric_val, (1.0 / 2.0));
EXPECT_EQ(metric.metric_type(), "Recall");
}

TEST(F1MetricTest, f1_test) {
Expand Down Expand Up @@ -156,6 +159,7 @@ TEST(F1MetricTest, f1_test) {
metric.Accumulate(Y, pred);
metric_val = metric.GetMetric();
EXPECT_FLOAT_EQ(metric_val, (2.0 / 4.0));
EXPECT_EQ(metric.metric_type(), "F1");
}

TEST(MAEMetricTest, mae_test) {
Expand Down Expand Up @@ -188,6 +192,7 @@ TEST(MAEMetricTest, mae_test) {
metric.Accumulate(Y, pred);
metric_val = metric.GetMetric();
EXPECT_FLOAT_EQ(metric_val, 12.5);
EXPECT_EQ(metric.metric_type(), "MAE");
}

TEST(MAPEMetricTest, mape_test) {
Expand Down Expand Up @@ -220,26 +225,27 @@ TEST(MAPEMetricTest, mape_test) {
metric.Accumulate(Y, pred);
metric_val = metric.GetMetric();
EXPECT_FLOAT_EQ(metric_val, 0.478260869);
EXPECT_EQ(metric.metric_type(), "MAPE");
}

TEST(RSMDMetricTest, rsmd_test) {
TEST(RMSDMetricTest, rsmd_test) {
std::vector<real_t> Y;
Y.push_back(12);
Y.push_back(12);
Y.push_back(12);
Y.push_back(12);
Y.push_back(2);
Y.push_back(2);
Y.push_back(2);
Y.push_back(2);
std::vector<real_t> pred;
pred.push_back(11);
pred.push_back(11);
pred.push_back(11);
pred.push_back(11);
MAPEMetric metric;
pred.push_back(1);
pred.push_back(1);
pred.push_back(1);
pred.push_back(1);
RMSDMetric metric;
size_t threadNumber = std::thread::hardware_concurrency();
ThreadPool* pool = new ThreadPool(threadNumber);
metric.Initialize(pool);
metric.Accumulate(Y, pred);
real_t metric_val = metric.GetMetric();
EXPECT_FLOAT_EQ(metric_val, sqrt(0.006944444444));
EXPECT_FLOAT_EQ(metric_val, 1.0);
metric.Reset();
Y[0] = 23;
Y[1] = 23;
Expand All @@ -251,7 +257,24 @@ TEST(RSMDMetricTest, rsmd_test) {
pred[3] = 12;
metric.Accumulate(Y, pred);
metric_val = metric.GetMetric();
EXPECT_FLOAT_EQ(metric_val, sqrt(0.228733459357));
EXPECT_FLOAT_EQ(metric_val, sqrt(121));
EXPECT_EQ(metric.metric_type(), "RMSD");
}

Metric* CreateMetric(const char* format_name) {
return CREATE_METRIC(format_name);
}

TEST(MetricTest, Create_Metric) {
EXPECT_TRUE(CreateMetric("acc") != NULL);
EXPECT_TRUE(CreateMetric("prec") != NULL);
EXPECT_TRUE(CreateMetric("recall") != NULL);
EXPECT_TRUE(CreateMetric("f1") != NULL);
EXPECT_TRUE(CreateMetric("mae") != NULL);
EXPECT_TRUE(CreateMetric("mape") != NULL);
EXPECT_TRUE(CreateMetric("rmsd") != NULL);
EXPECT_TRUE(CreateMetric("") == NULL);
EXPECT_TRUE(CreateMetric("unknow_name") == NULL);
}

} // namespace xLearn
12 changes: 6 additions & 6 deletions src/solver/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# Build library solver
add_library(solver checker.cc trainer.cc solver.cc inference.cc)
add_library(solver checker.cc)

# Build xlearn exe
set(LIBS solver base data loss reader score)
# set(LIBS solver base data loss reader score)

add_executable(xlearn_train train_main.cc)
target_link_libraries(xlearn_train ${LIBS})
# add_executable(xlearn_train train_main.cc)
# target_link_libraries(xlearn_train ${LIBS})

add_executable(xlearn_predict predict_main.cc)
target_link_libraries(xlearn_predict ${LIBS})
# add_executable(xlearn_predict predict_main.cc)
# target_link_libraries(xlearn_predict ${LIBS})

# Install library and header files
install(TARGETS solver DESTINATION lib/solver)
Expand Down
Loading

0 comments on commit 8fd41e3

Please sign in to comment.