diff --git a/.github/workflows/run_windows.yml b/.github/workflows/run_windows.yml index d4d7c2d03..651347d97 100644 --- a/.github/workflows/run_windows.yml +++ b/.github/workflows/run_windows.yml @@ -274,6 +274,11 @@ jobs: - name: EEG Metrics CI C# Test run: .\csharp-package\brainflow\tests\eeg_metrics\bin\Release\test.exe shell: cmd + - name: EEG Metrics CI Rust Test + run: | + cd %GITHUB_WORKSPACE%\rust-package\brainflow + cargo run --example=eeg_metrics + shell: cmd # Start Deploy Stage - name: Install Python AWS Tools run: | diff --git a/rust-package/brainflow/examples/eeg_metrics.rs b/rust-package/brainflow/examples/eeg_metrics.rs new file mode 100644 index 000000000..b51cef2c6 --- /dev/null +++ b/rust-package/brainflow/examples/eeg_metrics.rs @@ -0,0 +1,52 @@ +use std::{thread, time::Duration}; + +use brainflow::{ + board_shim, brainflow_input_params::BrainFlowInputParamsBuilder, + brainflow_model_params::BrainFlowModelParamsBuilder, data_filter, ml_model, BoardIds, + BrainFlowClassifiers, BrainFlowMetrics, +}; + +fn main() { + brainflow::board_shim::enable_dev_board_logger().unwrap(); + let params = BrainFlowInputParamsBuilder::default().build(); + let board_id = BoardIds::SyntheticBoard as i32; + let board = board_shim::BoardShim::new(board_id, params).unwrap(); + + board.prepare_session().unwrap(); + board.start_stream(45000, "").unwrap(); + thread::sleep(Duration::from_secs(5)); + board.stop_stream().unwrap(); + let data = board.get_board_data(None).unwrap(); + board.release_session().unwrap(); + + let eeg_channels = board_shim::get_eeg_channels(board_id).unwrap(); + let sampling_rate = board_shim::get_sampling_rate(board_id).unwrap(); + let mut bands = + data_filter::get_avg_band_powers(data, eeg_channels, sampling_rate, true).unwrap(); + let mut feature_vector = bands.0; + feature_vector.append(&mut bands.1); + println!("feature_vector: {:?}", feature_vector); + + // calc concentration + let concentration_params = BrainFlowModelParamsBuilder::new() + .metric(BrainFlowMetrics::Concentration) + .classifier(BrainFlowClassifiers::Knn) + .build(); + let concentration = ml_model::MlModel::new(concentration_params).unwrap(); + concentration.prepare().unwrap(); + println!( + "Concentration: {:?}", + concentration.predict(&mut feature_vector) + ); + concentration.release().unwrap(); + + // calc relaxation + let relaxation_params = BrainFlowModelParamsBuilder::new() + .metric(BrainFlowMetrics::Relaxation) + .classifier(BrainFlowClassifiers::Regression) + .build(); + let relaxation = ml_model::MlModel::new(relaxation_params).unwrap(); + relaxation.prepare().unwrap(); + println!("Relaxation: {:?}", relaxation.predict(&mut feature_vector)); + relaxation.release().unwrap(); +} diff --git a/rust-package/brainflow/src/brainflow_model_params.rs b/rust-package/brainflow/src/brainflow_model_params.rs index 0ab7c30d0..5909c219b 100644 --- a/rust-package/brainflow/src/brainflow_model_params.rs +++ b/rust-package/brainflow/src/brainflow_model_params.rs @@ -1,16 +1,43 @@ use getset::Getters; -use serde::{Deserialize, Serialize}; +use serde::{ser::SerializeStruct, Serialize}; + +use crate::{BrainFlowClassifiers, BrainFlowMetrics}; /// Inputs parameters for MlModel. -#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Getters)] +#[derive(Debug, Getters)] #[getset(get = "pub", set = "pub")] pub struct BrainFlowModelParams { - metric: usize, - classifier: usize, + metric: BrainFlowMetrics, + classifier: BrainFlowClassifiers, file: String, other_info: String, } +impl Serialize for BrainFlowModelParams { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut state = serializer.serialize_struct("BrainFlowModelParams", 4)?; + state.serialize_field("metric", &(self.metric as usize))?; + state.serialize_field("classifier", &(self.classifier as usize))?; + state.serialize_field("file", &self.file.to_string())?; + state.serialize_field("other_info", &self.other_info.to_string())?; + state.end() + } +} + +impl Default for BrainFlowModelParams { + fn default() -> Self { + Self { + metric: BrainFlowMetrics::Concentration, + classifier: BrainFlowClassifiers::Knn, + file: Default::default(), + other_info: Default::default(), + } + } +} + /// Builder for [BrainFlowModelParams]. #[derive(Default)] pub struct BrainFlowModelParamsBuilder { @@ -24,13 +51,13 @@ impl BrainFlowModelParamsBuilder { } /// Metric to calculate. - pub fn metric(mut self, metric: usize) -> Self { + pub fn metric(mut self, metric: BrainFlowMetrics) -> Self { self.params.metric = metric; self } /// Classifier to use. - pub fn classifier(mut self, classifier: usize) -> Self { + pub fn classifier(mut self, classifier: BrainFlowClassifiers) -> Self { self.params.classifier = classifier; self } diff --git a/rust-package/brainflow/src/data_filter.rs b/rust-package/brainflow/src/data_filter.rs index db60df109..6c5e4af3b 100644 --- a/rust-package/brainflow/src/data_filter.rs +++ b/rust-package/brainflow/src/data_filter.rs @@ -1,7 +1,5 @@ use getset::Getters; -use ndarray::{ - Array1, Array2, Array3, ArrayBase, ArrayView2, AsArray, Ix2, SliceInfo, SliceInfoElem, -}; +use ndarray::{Array1, Array2, Array3, ArrayBase}; use num::Complex; use num_complex::Complex64; use std::os::raw::c_int; @@ -490,33 +488,26 @@ pub fn get_psd_welch( } /// Calculate avg and stddev of BandPowers across all channels, bands are 1-4,4-8,8-13,13-30,30-50. -pub fn get_avg_band_powers<'a, Data>( - data: Data, +pub fn get_avg_band_powers( + data: Array2, eeg_channels: Vec, sampling_rate: usize, apply_filters: bool, -) -> Result<(Vec, Vec)> -where - Data: AsArray<'a, f64, Ix2>, -{ - let data = data.into(); - let data: ArrayView2 = unsafe { - data.slice( - SliceInfo::new( - eeg_channels - .into_iter() - .map(|c| SliceInfoElem::Index(c as isize)) - .collect::>(), - ) - .unwrap(), - ) - }; +) -> Result<(Vec, Vec)> { let shape = data.shape(); - let (rows, cols) = (shape[0], shape[1]); + let (rows, cols) = (eeg_channels.len(), shape[1]); + let mut raw_data = data + .outer_iter() + .enumerate() + .filter(|(i, _)| eeg_channels.contains(i)) + .map(|(_, x)| x) + .flatten() + .copied() + .collect::>(); let mut avg_band_powers = Vec::with_capacity(5); let mut stddev_band_powers = Vec::with_capacity(5); - let mut raw_data: Vec<&f64> = data.iter().collect(); + let res = unsafe { data_handler::get_avg_band_powers( raw_data.as_mut_ptr() as *mut c_double, diff --git a/rust-package/brainflow/src/ffi/board_controller.rs b/rust-package/brainflow/src/ffi/board_controller.rs index 5da6e89d0..195111e3d 100644 --- a/rust-package/brainflow/src/ffi/board_controller.rs +++ b/rust-package/brainflow/src/ffi/board_controller.rs @@ -2,7 +2,6 @@ #![allow(non_camel_case_types)] - extern "C" { pub fn get_board_descr( board_id: ::std::os::raw::c_int, diff --git a/rust-package/brainflow/src/ffi/data_handler.rs b/rust-package/brainflow/src/ffi/data_handler.rs index e6cc0bc90..2bf53fab6 100644 --- a/rust-package/brainflow/src/ffi/data_handler.rs +++ b/rust-package/brainflow/src/ffi/data_handler.rs @@ -2,7 +2,6 @@ #![allow(non_camel_case_types)] - extern "C" { pub fn perform_lowpass( data: *mut f64, diff --git a/rust-package/brainflow/src/ffi/ml_model.rs b/rust-package/brainflow/src/ffi/ml_model.rs index dec005a11..023c7558f 100644 --- a/rust-package/brainflow/src/ffi/ml_model.rs +++ b/rust-package/brainflow/src/ffi/ml_model.rs @@ -2,7 +2,6 @@ #![allow(non_camel_case_types)] - extern "C" { pub fn prepare(json_params: *const ::std::os::raw::c_char) -> ::std::os::raw::c_int; } diff --git a/rust-package/brainflow/src/lib.rs b/rust-package/brainflow/src/lib.rs index ccf7c8f38..d489ba005 100644 --- a/rust-package/brainflow/src/lib.rs +++ b/rust-package/brainflow/src/lib.rs @@ -31,6 +31,8 @@ pub use ffi::constants::BoardIds; pub use ffi::constants::BrainFlowClassifiers; /// Enum to store all Brainflow Exit Codes. pub use ffi::constants::BrainFlowExitCodes; +/// Enum to store BrainFlow metrics +pub use ffi::constants::BrainFlowMetrics; /// Enum to store all supported Detrend Operations. pub use ffi::constants::DetrendOperations; /// Enum to store all supported Filter Types.