@@ -95,21 +95,20 @@ impl Booster {
95
95
///
96
96
/// Input data example
97
97
/// ```
98
- /// let data = vec![vec![ 1.0, 0.1, 0.2] ,
99
- /// vec![ 0.7, 0.4, 0.5] ,
100
- /// vec![ 0.1, 0.7, 1.0] ];
98
+ /// let data = vec![1.0, 0.1, 0.2,
99
+ /// 0.7, 0.4, 0.5,
100
+ /// 0.1, 0.7, 1.0];
101
101
/// ```
102
102
///
103
103
/// Output data example
104
104
/// ```
105
- /// let output = vec![vec![ 1.0, 0.109, 0.433] ];
105
+ /// let output = vec![1.0, 0.109, 0.433];
106
106
/// ```
107
- pub fn predict ( & self , data : Vec < Vec < f64 > > ) -> Result < Vec < Vec < f64 > > > {
107
+ pub fn predict ( & self , data : & [ f32 ] , num_features : i32 ) -> Result < Vec < f32 > > {
108
108
let data_length = data. len ( ) ;
109
- let feature_length = data [ 0 ] . len ( ) ;
109
+ let num_rows = data_length / num_features as usize ;
110
110
let params = CString :: new ( "" ) . unwrap ( ) ;
111
111
let mut out_length: c_longlong = 0 ;
112
- let flat_data = data. into_iter ( ) . flatten ( ) . collect :: < Vec < _ > > ( ) ;
113
112
114
113
// get num_class
115
114
let mut num_class = 0 ;
@@ -118,33 +117,24 @@ impl Booster {
118
117
& mut num_class
119
118
) ) ?;
120
119
121
- let out_result: Vec < f64 > = vec ! [ Default :: default ( ) ; data_length * num_class as usize ] ;
120
+ let out_result: Vec < f32 > = vec ! [ Default :: default ( ) ; num_rows * num_class as usize ] ;
122
121
123
122
lgbm_call ! ( lightgbm_sys:: LGBM_BoosterPredictForMat (
124
123
self . handle,
125
- flat_data . as_ptr( ) as * const c_void,
126
- lightgbm_sys:: C_API_DTYPE_FLOAT64 as i32 ,
127
- data_length as i32 ,
128
- feature_length as i32 ,
124
+ data . as_ptr( ) as * const c_void,
125
+ lightgbm_sys:: C_API_DTYPE_FLOAT32 as i32 ,
126
+ num_rows as i32 ,
127
+ num_features ,
129
128
1_i32 ,
130
- 0_i32 ,
129
+ lightgbm_sys :: C_API_PREDICT_NORMAL as i32 ,
131
130
0_i32 ,
132
131
-1_i32 ,
133
132
params. as_ptr( ) as * const c_char,
134
133
& mut out_length,
135
- out_result. as_ptr( ) as * mut c_double
134
+ out_result. as_ptr( ) as * mut c_double,
136
135
) ) ?;
137
136
138
- // reshape for multiclass [1,2,3,4,5,6] -> [[1,2,3], [4,5,6]] # 3 class
139
- let reshaped_output = if num_class > 1 {
140
- out_result
141
- . chunks ( num_class as usize )
142
- . map ( |x| x. to_vec ( ) )
143
- . collect ( )
144
- } else {
145
- vec ! [ out_result]
146
- } ;
147
- Ok ( reshaped_output)
137
+ Ok ( out_result)
148
138
}
149
139
150
140
/// Get Feature Num.
@@ -257,13 +247,14 @@ mod tests {
257
247
}
258
248
} ;
259
249
let bst = _train_booster ( & params) ;
260
- let feature = vec ! [ vec![ 0.5 ; 28 ] , vec![ 0.0 ; 28 ] , vec![ 0.9 ; 28 ] ] ;
261
- let result = bst. predict ( feature) . unwrap ( ) ;
262
- let mut normalized_result = Vec :: new ( ) ;
263
- for r in & result[ 0 ] {
264
- normalized_result. push ( if r > & 0.5 { 1 } else { 0 } ) ;
265
- }
266
- assert_eq ! ( normalized_result, vec![ 0 , 0 , 1 ] ) ;
250
+ let mut features = Vec :: new ( ) ;
251
+ features. extend ( vec ! [ 0.5 ; 28 ] ) ;
252
+ features. extend ( vec ! [ 0.0 ; 28 ] ) ;
253
+ features. extend ( vec ! [ 0.9 ; 28 ] ) ;
254
+
255
+ assert_eq ! ( features. len( ) , 28 * 3 ) ;
256
+ let result = bst. predict ( & features, 28 ) . unwrap ( ) ;
257
+ assert_eq ! ( result. len( ) , 3 ) ;
267
258
}
268
259
269
260
#[ test]
0 commit comments