Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add eeg metrics example for Rust #383

Merged
merged 3 commits into from
Nov 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/run_windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,11 @@ jobs:
- name: EEG Metrics CI C# Test
run: .\csharp-package\brainflow\tests\eeg_metrics\bin\Release\test.exe
shell: cmd
- name: EEG Metrics CI Rust Test
run: |
cd %GITHUB_WORKSPACE%\rust-package\brainflow
cargo run --example=eeg_metrics
shell: cmd
# Start Deploy Stage
- name: Install Python AWS Tools
run: |
Expand Down
52 changes: 52 additions & 0 deletions rust-package/brainflow/examples/eeg_metrics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use std::{thread, time::Duration};

use brainflow::{
board_shim, brainflow_input_params::BrainFlowInputParamsBuilder,
brainflow_model_params::BrainFlowModelParamsBuilder, data_filter, ml_model, BoardIds,
BrainFlowClassifiers, BrainFlowMetrics,
};

fn main() {
brainflow::board_shim::enable_dev_board_logger().unwrap();
let params = BrainFlowInputParamsBuilder::default().build();
let board_id = BoardIds::SyntheticBoard as i32;
let board = board_shim::BoardShim::new(board_id, params).unwrap();

board.prepare_session().unwrap();
board.start_stream(45000, "").unwrap();
thread::sleep(Duration::from_secs(5));
board.stop_stream().unwrap();
let data = board.get_board_data(None).unwrap();
board.release_session().unwrap();

let eeg_channels = board_shim::get_eeg_channels(board_id).unwrap();
let sampling_rate = board_shim::get_sampling_rate(board_id).unwrap();
let mut bands =
data_filter::get_avg_band_powers(data, eeg_channels, sampling_rate, true).unwrap();
let mut feature_vector = bands.0;
feature_vector.append(&mut bands.1);
println!("feature_vector: {:?}", feature_vector);

// calc concentration
let concentration_params = BrainFlowModelParamsBuilder::new()
.metric(BrainFlowMetrics::Concentration)
.classifier(BrainFlowClassifiers::Knn)
.build();
let concentration = ml_model::MlModel::new(concentration_params).unwrap();
concentration.prepare().unwrap();
println!(
"Concentration: {:?}",
concentration.predict(&mut feature_vector)
);
concentration.release().unwrap();

// calc relaxation
let relaxation_params = BrainFlowModelParamsBuilder::new()
.metric(BrainFlowMetrics::Relaxation)
.classifier(BrainFlowClassifiers::Regression)
.build();
let relaxation = ml_model::MlModel::new(relaxation_params).unwrap();
relaxation.prepare().unwrap();
println!("Relaxation: {:?}", relaxation.predict(&mut feature_vector));
relaxation.release().unwrap();
}
39 changes: 33 additions & 6 deletions rust-package/brainflow/src/brainflow_model_params.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,43 @@
use getset::Getters;
use serde::{Deserialize, Serialize};
use serde::{ser::SerializeStruct, Serialize};

use crate::{BrainFlowClassifiers, BrainFlowMetrics};

/// Inputs parameters for MlModel.
#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Getters)]
#[derive(Debug, Getters)]
#[getset(get = "pub", set = "pub")]
pub struct BrainFlowModelParams {
metric: usize,
classifier: usize,
metric: BrainFlowMetrics,
classifier: BrainFlowClassifiers,
file: String,
other_info: String,
}

impl Serialize for BrainFlowModelParams {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut state = serializer.serialize_struct("BrainFlowModelParams", 4)?;
state.serialize_field("metric", &(self.metric as usize))?;
state.serialize_field("classifier", &(self.classifier as usize))?;
state.serialize_field("file", &self.file.to_string())?;
state.serialize_field("other_info", &self.other_info.to_string())?;
state.end()
}
}

impl Default for BrainFlowModelParams {
fn default() -> Self {
Self {
metric: BrainFlowMetrics::Concentration,
classifier: BrainFlowClassifiers::Knn,
file: Default::default(),
other_info: Default::default(),
}
}
}

/// Builder for [BrainFlowModelParams].
#[derive(Default)]
pub struct BrainFlowModelParamsBuilder {
Expand All @@ -24,13 +51,13 @@ impl BrainFlowModelParamsBuilder {
}

/// Metric to calculate.
pub fn metric(mut self, metric: usize) -> Self {
pub fn metric(mut self, metric: BrainFlowMetrics) -> Self {
self.params.metric = metric;
self
}

/// Classifier to use.
pub fn classifier(mut self, classifier: usize) -> Self {
pub fn classifier(mut self, classifier: BrainFlowClassifiers) -> Self {
self.params.classifier = classifier;
self
}
Expand Down
37 changes: 14 additions & 23 deletions rust-package/brainflow/src/data_filter.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use getset::Getters;
use ndarray::{
Array1, Array2, Array3, ArrayBase, ArrayView2, AsArray, Ix2, SliceInfo, SliceInfoElem,
};
use ndarray::{Array1, Array2, Array3, ArrayBase};
use num::Complex;
use num_complex::Complex64;
use std::os::raw::c_int;
Expand Down Expand Up @@ -490,33 +488,26 @@ pub fn get_psd_welch(
}

/// Calculate avg and stddev of BandPowers across all channels, bands are 1-4,4-8,8-13,13-30,30-50.
pub fn get_avg_band_powers<'a, Data>(
data: Data,
pub fn get_avg_band_powers(
data: Array2<f64>,
eeg_channels: Vec<usize>,
sampling_rate: usize,
apply_filters: bool,
) -> Result<(Vec<f64>, Vec<f64>)>
where
Data: AsArray<'a, f64, Ix2>,
{
let data = data.into();
let data: ArrayView2<f64> = unsafe {
data.slice(
SliceInfo::new(
eeg_channels
.into_iter()
.map(|c| SliceInfoElem::Index(c as isize))
.collect::<Vec<SliceInfoElem>>(),
)
.unwrap(),
)
};
) -> Result<(Vec<f64>, Vec<f64>)> {
let shape = data.shape();
let (rows, cols) = (shape[0], shape[1]);
let (rows, cols) = (eeg_channels.len(), shape[1]);
let mut raw_data = data
.outer_iter()
.enumerate()
.filter(|(i, _)| eeg_channels.contains(i))
.map(|(_, x)| x)
.flatten()
.copied()
.collect::<Vec<f64>>();

let mut avg_band_powers = Vec::with_capacity(5);
let mut stddev_band_powers = Vec::with_capacity(5);
let mut raw_data: Vec<&f64> = data.iter().collect();

let res = unsafe {
data_handler::get_avg_band_powers(
raw_data.as_mut_ptr() as *mut c_double,
Expand Down
1 change: 0 additions & 1 deletion rust-package/brainflow/src/ffi/board_controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#![allow(non_camel_case_types)]


extern "C" {
pub fn get_board_descr(
board_id: ::std::os::raw::c_int,
Expand Down
1 change: 0 additions & 1 deletion rust-package/brainflow/src/ffi/data_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#![allow(non_camel_case_types)]


extern "C" {
pub fn perform_lowpass(
data: *mut f64,
Expand Down
1 change: 0 additions & 1 deletion rust-package/brainflow/src/ffi/ml_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#![allow(non_camel_case_types)]


extern "C" {
pub fn prepare(json_params: *const ::std::os::raw::c_char) -> ::std::os::raw::c_int;
}
Expand Down
2 changes: 2 additions & 0 deletions rust-package/brainflow/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ pub use ffi::constants::BoardIds;
pub use ffi::constants::BrainFlowClassifiers;
/// Enum to store all Brainflow Exit Codes.
pub use ffi::constants::BrainFlowExitCodes;
/// Enum to store BrainFlow metrics
pub use ffi::constants::BrainFlowMetrics;
/// Enum to store all supported Detrend Operations.
pub use ffi::constants::DetrendOperations;
/// Enum to store all supported Filter Types.
Expand Down