Mondrian Forests #10
Open · wants to merge 60 commits into base: main
Changes from 1 commit

Commits (60)
6023dd1  Update ClassifierOutput docstring (MarcoDiFrancesco, Apr 11, 2024)
feba8a0  Add RegressionOutput to common (MarcoDiFrancesco, Apr 11, 2024)
c13d3c6  Merge branch 'online-ml:main' into main (MarcoDiFrancesco, Apr 11, 2024)
308a082  Add boilerplate code for mondrian forest (MarcoDiFrancesco, Apr 12, 2024)
3ba0e3a  Add keystroke dataset (MarcoDiFrancesco, Apr 12, 2024)
2f9e03d  Add all functions calls with unimplemented errors (MarcoDiFrancesco, Apr 15, 2024)
7b63db5  Add predict steps to be refactored (MarcoDiFrancesco, Apr 15, 2024)
d5bb6db  Add get features function (MarcoDiFrancesco, Apr 16, 2024)
b5b7ec4  Add Array library (MarcoDiFrancesco, Apr 16, 2024)
d613df2  Add randomization for cache tests (MarcoDiFrancesco, Apr 16, 2024)
2174472  Disable test github actions and enable only check (MarcoDiFrancesco, Apr 17, 2024)
1c91530  Remove verbose from build and test (MarcoDiFrancesco, Apr 17, 2024)
44cfba4  Add Stats struct and impl (MarcoDiFrancesco, Apr 18, 2024)
4c6ebe4  Add rust caching in actions (MarcoDiFrancesco, Apr 18, 2024)
1ccabc4  Split MondrianTree and MondrianForest (MarcoDiFrancesco, Apr 22, 2024)
ac71b06  Refactor to use Tree Vector indicies instead of pointers (MarcoDiFrancesco, Apr 23, 2024)
8aad4ed  Change actions cargo.lock to cargo.toml (MarcoDiFrancesco, Apr 23, 2024)
8c91dd8  Add print function for MondrianTree (MarcoDiFrancesco, Apr 23, 2024)
6b38849  Adding print functions to mondriantree and node (MarcoDiFrancesco, Apr 23, 2024)
107354a  Implement and test predict_proba (MarcoDiFrancesco, Apr 24, 2024)
4385fe8  Add unit test for predict_proba (MarcoDiFrancesco, Apr 24, 2024)
49d4e3e  Add final implementation of inference (predict_proba) (MarcoDiFrancesco, Apr 24, 2024)
a16d3e7  Add random distribution to extend mondrian block (MarcoDiFrancesco, Apr 25, 2024)
de5d67a  Add full extend_mondrian_block implementation (MarcoDiFrancesco, Apr 25, 2024)
667d35e  Add synthetic dataset and tree integrity tests (MarcoDiFrancesco, Apr 25, 2024)
f79864d  Fix pointer of grandpa on extend_mondrian_block (MarcoDiFrancesco, Apr 26, 2024)
989c176  Add recursive repr mondrian forest (MarcoDiFrancesco, Apr 26, 2024)
75e5feb  Add score function (MarcoDiFrancesco, Apr 29, 2024)
da4a00a  Remove debug statements (MarcoDiFrancesco, Apr 30, 2024)
717161f  Adjust code to River behaviour (MarcoDiFrancesco, Apr 30, 2024)
a9ca4bc  Adapt _go_downwards from River (MarcoDiFrancesco, May 3, 2024)
ccc9b1d  Update function names from nel215 to River (MarcoDiFrancesco, May 3, 2024)
30fb86b  Comment debug prints (MarcoDiFrancesco, May 3, 2024)
a619415  Remove unused imports (MarcoDiFrancesco, May 3, 2024)
da23d14  Add synthetic dataset download (MarcoDiFrancesco, May 3, 2024)
85030ad  Rename MondrianForest to MondrianForestClassifier (MarcoDiFrancesco, May 6, 2024)
c4753f1  Update readme with classification run instructions (MarcoDiFrancesco, May 6, 2024)
a08f922  Add update_leaf flag to create_leaf (MarcoDiFrancesco, May 13, 2024)
a00cfe5  Fix mondrian forest classifier test (MarcoDiFrancesco, May 13, 2024)
4d9ef48  Remove create_leaf flag (MarcoDiFrancesco, May 20, 2024)
0217db2  Add create leafs when reaching a leaf (MarcoDiFrancesco, May 24, 2024)
1e5a874  Add assert to check for NaN probability (MarcoDiFrancesco, May 24, 2024)
6971c21  Revert removal of split_time (MarcoDiFrancesco, May 24, 2024)
782d1f2  Add test cases (MarcoDiFrancesco, May 29, 2024)
a5bd895  Remove unused `child_is_on_edge_parent` test case (MarcoDiFrancesco, May 29, 2024)
3544c28  Add debug statement for overwriting variance aware estimation (MarcoDiFrancesco, May 29, 2024)
9083d8e  Add synthetic regression target boilerplate (MarcoDiFrancesco, Jun 4, 2024)
43cce28  Add Classification and Regression division of MF (MarcoDiFrancesco, Jun 7, 2024)
e58638b  Add regression task and parent_has_finite_values test (MarcoDiFrancesco, Jun 11, 2024)
fed6daf  Fix child_inside_parent test (MarcoDiFrancesco, Jun 11, 2024)
760de79  Remove prints in excess (MarcoDiFrancesco, Jun 11, 2024)
54bb202  Add regression metrics (MarcoDiFrancesco, Jun 12, 2024)
0d74d3f  Fix test keystroke dataset (MarcoDiFrancesco, Jun 12, 2024)
c60b381  Change description of synthetic dataset (MarcoDiFrancesco, Jun 12, 2024)
ec2109a  Add baseline comparison for regression (MarcoDiFrancesco, Jun 24, 2024)
b77ba69  Add machine degradation dataset (MarcoDiFrancesco, Jul 9, 2024)
a6c1b8b  Add genesis demostrator dataset (MarcoDiFrancesco, Jul 10, 2024)
4a4b9f5  Update machine degradation with redirect (MarcoDiFrancesco, Jul 10, 2024)
23c109e  Update src/datasets/synthetic_regression.rs (smastelini, Jul 29, 2024)
38e64ee  Update src/datasets/synthetic.rs (smastelini, Jul 29, 2024)
Comment debug prints
MarcoDiFrancesco committed May 3, 2024
commit 30fb86b3307aa215c77996f50310395bb2bc408d
examples/classification/synthetic.rs (10 changes: 3 additions & 7 deletions)
@@ -4,9 +4,6 @@ use light_river::classification::mondrian_tree::MondrianTree;
use light_river::common::ClassifierOutput;
use light_river::common::ClassifierTarget;
use light_river::datasets::synthetic::Synthetic;
use light_river::metrics::rocauc::ROCAUC;
use light_river::metrics::traits::ClassificationMetric;
use light_river::stream::data_stream::DataStream;
use light_river::stream::iter_csv::IterCsv;
use ndarray::{s, Array1};
use num::ToPrimitive;
@@ -43,17 +40,16 @@ fn main() {
let window_size: usize = 1000;
let n_trees: usize = 1;

let transactions_f = Synthetic::load_data().unwrap();
let transactions_f = Synthetic::load_data();
let features = get_features(transactions_f);

let transactions_c = Synthetic::load_data().unwrap();
let transactions_c = Synthetic::load_data();
let labels = get_labels(transactions_c);
println!("labels: {labels:?}, features: {features:?}");
let mut mf: MondrianForest<f32> = MondrianForest::new(window_size, n_trees, &features, &labels);

let mut score_total = 0.0;

let transactions = Synthetic::load_data().unwrap();
let transactions = Synthetic::load_data();
for (idx, transaction) in transactions.enumerate() {
let data = transaction.unwrap();

src/classification/mondrian_forest.rs (2 changes: 1 addition & 1 deletion)
@@ -69,7 +69,7 @@ impl<F: FType> MondrianForest<F> {
"Probability should not be NaN. Found: {:?}.",
probs.to_vec()
);
total_probs += &probs; // Assuming `probs` is an Array1<F>
total_probs += &probs;
}

// Average the probabilities by the number of trees
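The collapsed lines around this hunk accumulate one probability vector per tree and then divide by the number of trees. A minimal standalone sketch of that accumulate-and-average pattern, using plain `f32` and `ndarray` in place of the crate's generic `FType` (the function name and signature here are illustrative, not the crate's API):

```rust
use ndarray::Array1;

/// Average per-tree class probabilities, assuming every tree returns a
/// vector of the same length (one entry per known label).
fn average_tree_probs(per_tree: &[Array1<f32>], n_labels: usize) -> Array1<f32> {
    let mut total_probs = Array1::<f32>::zeros(n_labels);
    for probs in per_tree {
        // Mirrors the assert in the diff: probabilities must never be NaN.
        assert!(!probs.iter().any(|p| p.is_nan()), "Probability should not be NaN");
        total_probs += probs; // element-wise accumulation
    }
    // Average over the number of trees (assumes per_tree is non-empty).
    total_probs / per_tree.len() as f32
}
```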
src/classification/mondrian_node.rs (24 changes: 4 additions & 20 deletions)
@@ -50,8 +50,6 @@ impl<F: FType> Node<F> {
/// e.g. y=2, stats.counts=[0, 1, 10] -> False
/// e.g. y=2, stats.counts=[0, 0, 10] -> True
/// e.g. y=1, stats.counts=[0, 0, 10] -> False
///
/// From: River function
pub fn is_dirac(&self, y: usize) -> bool {
return self.stats.counts.sum() == self.stats.counts[y];
}
@@ -123,11 +121,6 @@ impl<F: FType> Stats<F> {
}
/// Return probabilities of sample 'x' belonging to each class.
///
/// e.g. probs: [0.1, 0.2, 0.7]
///
/// TODO: Remove the assert that check for exact values, I was testing if unit tests make sense, but as
/// shown below this does not show the error. The function is just too complex.
///
/// # Example
/// ```
/// use light_river::classification::alias::FType;
@@ -146,15 +139,10 @@ impl<F: FType> Stats<F> {
///
/// let x = Array1::from_vec(vec![1.5, 3.0]);
/// let probs = stats.predict_proba(&x);
/// let expected = vec![0.998075, 0.001924008, 0.0];
/// assert!(
/// (probs.clone() - Array1::from_vec(expected)).mapv(|a: f32| a.abs()).iter().all(|&x| x < 1e-4),
/// "Probabilities do not match expected values"
/// );
/// // Check all values inside [0, 1] range
/// assert!(probs.clone().iter().all(|&x| x >= 0.0 && x <= 1.0), "Probabilities should be in [0, 1] range");
/// // Check sum is 1
/// assert!((probs.clone().sum() - 1.0).abs() < 1e-4, "Sum of probabilities should be 1");
/// assert!((probs.clone().sum() - 1.0f32).abs() < 1e-4, "Sum of probabilities should be 1");
/// ```
pub fn predict_proba(&self, x: &Array1<F>) -> Array1<F> {
let mut probs = Array1::zeros(self.num_labels);
@@ -169,16 +157,14 @@ impl<F: FType> Stats<F> {
.zip(self.counts.iter())
.enumerate()
{
// println!("predict_proba() - mid - index: {:?}, sum: {:?}, sq_sum: {:?}, count: {:?}", index, sum.to_vec(), sq_sum.to_vec(), count);
let epsilon = F::epsilon(); // F::from_f32(1e-9).unwrap();
let epsilon = F::epsilon();
let count_f = F::from_usize(count).unwrap();
let avg = &sum / count_f;
let var = (&sq_sum / count_f) - (&avg * &avg) + epsilon;
let sigma = (&var * count_f) / (count_f - F::one() + epsilon);
// println!("predict_proba() - mid - avg: {:?}, var: {:?}, sigma: {:?}", avg.to_vec(), var.to_vec(), sigma.to_vec());
let pi = F::from_f32(std::f32::consts::PI).unwrap() * F::from_f32(2.0).unwrap();
let z = pi.powi(x.len() as i32) * sigma.mapv(|s| s * s).sum().sqrt();
// Same as dot product
// Dot product
let dot_feature = (&(x - &avg) * &(x - &avg)).sum();
let dot_sigma = (&sigma * &sigma).sum();
let exponent = -F::from_f32(0.5).unwrap() * dot_feature / dot_sigma;
@@ -192,9 +178,6 @@ impl<F: FType> Stats<F> {
probs[index] = prob;
}

// println!("predict_proba() post - probs: {:?}", probs.to_vec());
// println!();

// Check at least one probability is non-zero. Otherwise we have division by zero.
assert!(
!probs.iter().all(|&x| x == F::zero()),
@@ -205,6 +188,7 @@ impl<F: FType> Stats<F> {
for prob in probs.iter_mut() {
*prob /= sum_prob;
}
// println!("predict_proba() post - probs: {:?}", probs.to_vec());
probs
}
}
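Reading only the lines of `predict_proba` visible above (the final per-class weighting is collapsed in the diff), the function keeps, for every class $k$, element-wise running sums $s_k$, squared sums $q_k$ and a count $n_k$, and turns them into a diagonal-Gaussian score for an input $x$ with $d$ features ($\varepsilon$ is machine epsilon; vector operations are element-wise):

```latex
\mu_k = \frac{s_k}{n_k}, \qquad
v_k = \frac{q_k}{n_k} - \mu_k^{2} + \varepsilon, \qquad
\sigma_k = \frac{n_k\, v_k}{n_k - 1 + \varepsilon},
\\[6pt]
z_k = (2\pi)^{d} \sqrt{\sum\nolimits_j \sigma_{k,j}^{2}}, \qquad
e_k = -\frac{1}{2}\,\frac{\sum_j (x_j - \mu_{k,j})^{2}}{\sum_j \sigma_{k,j}^{2}} .
```

The collapsed lines turn $\exp(e_k)$ and $z_k$ into an unnormalised per-class score, and the visible tail of the function normalises the scores so they sum to one, after asserting that at least one score is non-zero.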
src/classification/mondrian_tree.rs (81 changes: 32 additions & 49 deletions)
@@ -113,7 +113,7 @@ impl<F: FType> MondrianTree<F> {
/// Note: in the nel215 codebase this operates on multiple records; here it
/// handles a single record, so it is equivalent to "predict()".
pub fn predict_proba(&self, x: &Array1<F>) -> Array1<F> {
println!("predict_proba() - tree size: {}", self.nodes.len());
// println!("predict_proba() - tree size: {}", self.nodes.len());
// self.test_tree();
self.predict(x, self.root.unwrap(), F::one())
}
@@ -155,10 +155,10 @@ impl<F: FType> MondrianTree<F> {
extensions_sum: F,
) -> F {
if self.nodes[node_idx].is_dirac(y) {
println!(
"go_downwards() - node: {node_idx} - extensions_sum: {:?} - all same class",
extensions_sum
);
// println!(
// "go_downwards() - node: {node_idx} - extensions_sum: {:?} - all same class",
// extensions_sum
// );
return F::zero();
}

@@ -167,10 +167,10 @@ impl<F: FType> MondrianTree<F> {

// From River: If the node is a leaf we must split it
if self.nodes[node_idx].is_leaf {
println!(
"go_downwards() - node: {node_idx} - extensions_sum: {:?} - split is_leaf",
extensions_sum
);
// println!(
// "go_downwards() - node: {node_idx} - extensions_sum: {:?} - split is_leaf",
// extensions_sum
// );
return split_time;
}

@@ -180,19 +180,18 @@ impl<F: FType> MondrianTree<F> {
let child_time = self.nodes[child_idx].time;
// 2. We check if splitting time occurs before child creation time
if split_time < child_time {
println!(
"go_downwards() - node: {node_idx} - extensions_sum: {:?} - split mid tree",
extensions_sum
);
// Go to next child????
// println!(
// "go_downwards() - node: {node_idx} - extensions_sum: {:?} - split mid tree",
// extensions_sum
// );
return split_time;
}
println!("go_downwards() - node: {node_idx} - extensions_sum: {:?} - not increased enough to split (mid node)", extensions_sum);
// println!("go_downwards() - node: {node_idx} - extensions_sum: {:?} - not increased enough to split (mid node)", extensions_sum);
} else {
println!(
"go_downwards() - node: {node_idx} - extensions_sum: {:?} - not outside box",
extensions_sum
);
// println!(
// "go_downwards() - node: {node_idx} - extensions_sum: {:?} - not outside box",
// extensions_sum
// );
}

F::zero()
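The computation of `split_time` itself is collapsed above. For orientation only: in the Mondrian forest extension step (Lakshminarayanan et al., 2014, "Mondrian Forests: Efficient Online Random Forests"), the candidate split time at a node $u$ is the parent's creation time plus an exponential draw whose rate is the total box extension needed to cover $x$, which is the role played by the `extensions_sum` argument (computed the same way as `eta` in `predict` further down). This is a reading of the published algorithm, not of the collapsed code:

```latex
\eta(x) = \sum_j \big[ \max(x_j - h_{u,j},\, 0) + \max(l_{u,j} - x_j,\, 0) \big],
\qquad E \sim \mathrm{Exp}\!\big(\eta(x)\big),
\qquad t_{\mathrm{split}} = \tau_{\mathrm{parent}(u)} + E,
```

where $[l_u, h_u]$ is the node's bounding box (`min_list`, `max_list`). The visible branches then return $t_{\mathrm{split}}$ when the split should actually take place (at a leaf, or when it precedes the child's creation time, as the comments say) and zero otherwise.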
@@ -345,58 +344,42 @@ impl<F: FType> MondrianTree<F> {
None => Some(self.create_leaf(x, y, None, F::zero())),
Some(root_idx) => Some(self.go_downwards(root_idx, x, y)),
};
println!("partial_fit() tree post {}", self);
// println!("partial_fit() tree post {}", self);
}

fn fit(&self) {
unimplemented!("Make the program first work with 'partial_fit', then implement this")
}

/// Function in River: "go_downwards()"
///
/// Recursive function to predict probabilities.
fn predict(&self, x: &Array1<F>, node_idx: usize, p_not_separated_yet: F) -> Array1<F> {
let node = &self.nodes[node_idx];

// Step 1: Calculate the time feature from the parent node.
let d = node.time - self.get_parent_time(node_idx);

// Step 2: If 'x' is outside the box, calculate distance of 'x' from the box
let dist_max = (x - &node.max_list).mapv(|v| F::max(v, F::zero()));
let dist_min = (&node.min_list - x).mapv(|v| F::max(v, F::zero()));
let eta = dist_min.sum() + dist_max.sum();
// It works, but check again once 'max_list' and 'min_list' are not 0s
// println!("x: {:?}, node.max_list {:?}, max(max_list) {:?}, node.min_list {:?}, max(min_list) {:?}",
// x.to_vec(), node.max_list.to_vec(), dist_max.to_vec(), node.min_list.to_vec(), dist_min.to_vec());

// Step 3: Probability 'p' of the box not splitting.
// Probability 'p' of the box not splitting.
// eta (box dist): larger distance, more prob of splitting
// d (time diff with parent): more dist with parent, more prob of splitting
let p = F::one() - (-d * eta).exp();
// println!("predict() -> pre create_result() - node_idx {}", node.stats);
// d (time delta with parent): more dist with parent, more prob of splitting
let p = {
let d = node.time - self.get_parent_time(node_idx);
// If 'x' is outside the box, calculate distance of 'x' from the box
let dist_max = (x - &node.max_list).mapv(|v| F::max(v, F::zero()));
let dist_min = (&node.min_list - x).mapv(|v| F::max(v, F::zero()));
let eta = dist_min.sum() + dist_max.sum();
F::one() - (-d * eta).exp()
};

// Step 4: Generate a result for the current node using its statistics.
// Generate a result for the current node using its statistics.
let res = node.stats.create_result(x, p_not_separated_yet * p);
// println!("predict() -> post create_result() - node.stats: {}", node.stats);
// println!(
// "predict() - res: {:?}, p_not_separated_yet: {:?}, p: {:?}",
// res, p_not_separated_yet, p
// );

let w = p_not_separated_yet * (F::one() - p);
if node.is_leaf {
let w = p_not_separated_yet * (F::one() - p);
let res2 = node.stats.create_result(x, w);
// println!("predict() - ischild - res: {:?}, res2: {:?}", res.to_vec(), res2.to_vec());
return res + res2;
} else {
let child_idx = if x[node.feature] <= node.threshold {
node.left
} else {
node.right
};
let child_res =
self.predict(x, child_idx.unwrap(), p_not_separated_yet * (F::one() - p));
// println!("predict() - notchild - res: {:?}, child_res: {:?}", res.to_vec(), child_res.to_vec());
let child_res = self.predict(x, child_idx.unwrap(), w);
return res + child_res;
}
}
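Putting the visible pieces of `predict` together, and writing $R_u(x, w)$ for `node.stats.create_result(x, w)`, the recursion above can be summarised as follows for a node $u$ with creation time $\tau_u$, parent creation time $\tau_{\mathrm{parent}(u)}$ and bounding box $[l_u, h_u]$ (`min_list`, `max_list`). This is a paraphrase of the code, not an additional specification:

```latex
\eta_u(x) = \sum_j \big[ \max(x_j - h_{u,j},\,0) + \max(l_{u,j} - x_j,\,0) \big],
\qquad p_u = 1 - \exp\!\big( -(\tau_u - \tau_{\mathrm{parent}(u)})\,\eta_u(x) \big),
\\[6pt]
\mathrm{pred}_u(x, w) =
\begin{cases}
R_u(x,\, w\,p_u) + R_u\big(x,\, w\,(1 - p_u)\big) & u \text{ is a leaf},\\[2pt]
R_u(x,\, w\,p_u) + \mathrm{pred}_{c(u,x)}\big(x,\, w\,(1 - p_u)\big) & \text{otherwise},
\end{cases}
```

where $w$ is `p_not_separated_yet` (1 at the root) and $c(u, x)$ is the left child when $x_{\mathrm{feature}} \le \mathrm{threshold}$ and the right child otherwise.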
src/datasets/synthetic.rs (9 changes: 3 additions & 6 deletions)
@@ -8,14 +8,11 @@ use std::{fs::File, path::Path};
/// Add 'synthetic.csv' to project root directory.
pub struct Synthetic;
impl Synthetic {
pub fn load_data() -> Result<IterCsv<f32, File>, Box<dyn std::error::Error>> {
pub fn load_data() -> IterCsv<f32, File> {
// let file_name = "syntetic_dataset_paper.csv";
let file_name = "syntetic_dataset_int.csv";
let file = File::open(file_name)?;
let file = File::open(file_name).unwrap();
let y_cols = Some(Target::Name("label".to_string()));
match IterCsv::<f32, File>::new(file, y_cols) {
Ok(x) => Ok(x),
Err(e) => Err(Box::new(e)),
}
IterCsv::<f32, File>::new(file, y_cols).unwrap()
}
}
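After this change, callers can no longer propagate I/O or parsing errors from `load_data`; a missing `syntetic_dataset_int.csv` in the project root now panics inside the loader. A minimal usage sketch mirroring the example file earlier in the diff (the loop body is up to the caller):

```rust
use light_river::datasets::synthetic::Synthetic;

fn main() {
    // Panics if syntetic_dataset_int.csv is not in the project root.
    let transactions = Synthetic::load_data();
    let mut n_rows = 0usize;
    for transaction in transactions {
        // Each item is still a Result: row-level CSV errors surface here.
        let _data = transaction.unwrap();
        n_rows += 1;
    }
    println!("read {n_rows} rows");
}
```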