Skip to content

Commit

Permalink
Merge pull request maciejkula#33 from maciejkula/more_json_tests
Browse files Browse the repository at this point in the history
Add additional JSON serialization tests.
  • Loading branch information
maciejkula authored Aug 30, 2016
2 parents 75512db + 4eb480b commit 95abe0d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
14 changes: 12 additions & 2 deletions src/ensemble/random_forest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ mod tests {
use rand::{StdRng, SeedableRng};

use bincode;
use rustc_serialize::json;

#[cfg(feature = "all_tests")]
use datasets::newsgroups;
Expand Down Expand Up @@ -418,9 +419,18 @@ mod tests {
let decoded: OneVsRestWrapper<RandomForest> =
bincode::rustc_serialize::decode(&encoded).unwrap();

let test_prediction = decoded.predict(&x_test).unwrap();
let bincode_prediction = decoded.predict(&x_test).unwrap();

test_accuracy += accuracy_score(&target.get_rows(&test_idx), &test_prediction);
// JSON encoding
let encoded = json::encode(&model).unwrap();
let decoded: OneVsRestWrapper<RandomForest> =
json::decode(&encoded).unwrap();

let json_prediction = decoded.predict(&x_test).unwrap();

assert!(allclose(&json_prediction, &bincode_prediction));

test_accuracy += accuracy_score(&target.get_rows(&test_idx), &json_prediction);
}

test_accuracy /= no_splits as f32;
Expand Down
10 changes: 5 additions & 5 deletions src/trees/decision_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1182,18 +1182,18 @@ mod tests {
let decoded: OneVsRestWrapper<DecisionTree> =
bincode::rustc_serialize::decode(&encoded).unwrap();

let test_prediction = decoded.predict(&x_test).unwrap();

test_accuracy += accuracy_score(&target.get_rows(&test_idx), &test_prediction);
let bincode_prediction = decoded.predict(&x_test).unwrap();

// JSON encoding
let encoded = json::encode(&model).unwrap();
let decoded: OneVsRestWrapper<DecisionTree> =
json::decode(&encoded).unwrap();

let test_prediction = decoded.predict(&x_test).unwrap();
let json_prediction = decoded.predict(&x_test).unwrap();

test_accuracy += accuracy_score(&target.get_rows(&test_idx), &test_prediction);
assert!(allclose(&json_prediction, &bincode_prediction));

test_accuracy += accuracy_score(&target.get_rows(&test_idx), &json_prediction);
}

test_accuracy /= no_splits as f32;
Expand Down

0 comments on commit 95abe0d

Please sign in to comment.