
Commit 6ededa1

Merge pull request #3 from postgresml/levkk-predict-single
Use single-dim vector for predictions
2 parents 19c76ec + b4e7778 commit 6ededa1
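
In practice, the change described above means callers now concatenate their feature rows into one flat, row-major `f32` buffer before calling `predict`. A minimal sketch of that flattening (the `flatten_rows` helper and the values are illustrative, not part of this commit):

```rust
// Illustrative only: build the flat, row-major feature buffer that the
// reworked predict() expects. The helper and values are not from the commit.
fn flatten_rows(rows: &[Vec<f32>]) -> Vec<f32> {
    // Concatenate rows back to back: [row0..., row1..., row2...]
    rows.iter().flatten().copied().collect()
}

fn main() {
    let rows = vec![vec![0.5_f32; 28], vec![0.0; 28], vec![0.9; 28]];
    let flat = flatten_rows(&rows);
    assert_eq!(flat.len(), 28 * 3);
    // `flat` can now be passed as `bst.predict(&flat, 28)`, where `bst` is a
    // trained Booster and 28 is the number of features per row.
}
```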

File tree

1 file changed: +22 -31 lines changed


src/booster.rs

````diff
@@ -95,21 +95,20 @@ impl Booster {
     ///
     /// Input data example
     /// ```
-    /// let data = vec![vec![1.0, 0.1, 0.2],
-    ///                vec![0.7, 0.4, 0.5],
-    ///                vec![0.1, 0.7, 1.0]];
+    /// let data = vec![1.0, 0.1, 0.2,
+    ///                0.7, 0.4, 0.5,
+    ///                0.1, 0.7, 1.0];
     /// ```
     ///
     /// Output data example
     /// ```
-    /// let output = vec![vec![1.0, 0.109, 0.433]];
+    /// let output = vec![1.0, 0.109, 0.433];
     /// ```
-    pub fn predict(&self, data: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>> {
+    pub fn predict(&self, data: &[f32], num_features: i32) -> Result<Vec<f32>> {
         let data_length = data.len();
-        let feature_length = data[0].len();
+        let num_rows = data_length / num_features as usize;
         let params = CString::new("").unwrap();
         let mut out_length: c_longlong = 0;
-        let flat_data = data.into_iter().flatten().collect::<Vec<_>>();

         // get num_class
         let mut num_class = 0;
@@ -118,33 +117,24 @@ impl Booster {
             &mut num_class
         ))?;

-        let out_result: Vec<f64> = vec![Default::default(); data_length * num_class as usize];
+        let out_result: Vec<f32> = vec![Default::default(); num_rows * num_class as usize];

         lgbm_call!(lightgbm_sys::LGBM_BoosterPredictForMat(
             self.handle,
-            flat_data.as_ptr() as *const c_void,
-            lightgbm_sys::C_API_DTYPE_FLOAT64 as i32,
-            data_length as i32,
-            feature_length as i32,
+            data.as_ptr() as *const c_void,
+            lightgbm_sys::C_API_DTYPE_FLOAT32 as i32,
+            num_rows as i32,
+            num_features,
             1_i32,
-            0_i32,
+            lightgbm_sys::C_API_PREDICT_NORMAL as i32,
             0_i32,
             -1_i32,
             params.as_ptr() as *const c_char,
             &mut out_length,
-            out_result.as_ptr() as *mut c_double
+            out_result.as_ptr() as *mut c_double,
         ))?;

-        // reshape for multiclass [1,2,3,4,5,6] -> [[1,2,3], [4,5,6]] # 3 class
-        let reshaped_output = if num_class > 1 {
-            out_result
-                .chunks(num_class as usize)
-                .map(|x| x.to_vec())
-                .collect()
-        } else {
-            vec![out_result]
-        };
-        Ok(reshaped_output)
+        Ok(out_result)
     }

     /// Get Feature Num.
@@ -257,13 +247,14 @@ mod tests {
             }
         };
         let bst = _train_booster(&params);
-        let feature = vec![vec![0.5; 28], vec![0.0; 28], vec![0.9; 28]];
-        let result = bst.predict(feature).unwrap();
-        let mut normalized_result = Vec::new();
-        for r in &result[0] {
-            normalized_result.push(if r > &0.5 { 1 } else { 0 });
-        }
-        assert_eq!(normalized_result, vec![0, 0, 1]);
+        let mut features = Vec::new();
+        features.extend(vec![0.5; 28]);
+        features.extend(vec![0.0; 28]);
+        features.extend(vec![0.9; 28]);
+
+        assert_eq!(features.len(), 28 * 3);
+        let result = bst.predict(&features, 28).unwrap();
+        assert_eq!(result.len(), 3);
     }

     #[test]
````
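
For context, a hedged usage sketch of the reworked API (not part of the commit). It assumes the crate's `Booster::from_file` loader and a saved model at `model.txt`, both placeholders here, and that the feature count passed to `predict` matches the model:

```rust
use lightgbm::Booster;

fn main() {
    // Placeholder model path; from_file is assumed from the crate's existing API.
    let bst = Booster::from_file("model.txt").unwrap();

    // Three rows of 28 features each, flattened row-major into one slice.
    let num_features = 28;
    let mut features: Vec<f32> = Vec::new();
    features.extend(vec![0.5; 28]);
    features.extend(vec![0.0; 28]);
    features.extend(vec![0.9; 28]);

    // The output is flat as well: num_rows * num_class scores.
    let scores = bst.predict(&features, num_features).unwrap();

    // With the per-row reshape removed from predict(), multiclass callers can
    // recover rows themselves; for binary models num_class is 1.
    let num_class = scores.len() / 3; // 3 rows in this example
    for row_scores in scores.chunks(num_class) {
        println!("{:?}", row_scores);
    }
}
```

Passing the slice straight through to `LGBM_BoosterPredictForMat` avoids the per-call flatten and reshape the old implementation performed, at the cost of callers handling the row-major layout (and any multiclass chunking) themselves.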
