Skip to content

Commit 96a10a5

Browse files
committed
Add num_class method
1 parent 0c8bfbd commit 96a10a5

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

src/booster.rs

+19
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,17 @@ impl Booster {
140140
Ok(out_result)
141141
}
142142

143+
/// Get number of classes.
144+
pub fn num_class(&self) -> Result<i32> {
145+
let mut num_class = 0;
146+
lgbm_call!(lightgbm_sys::LGBM_BoosterGetNumClasses(
147+
self.handle,
148+
&mut num_class
149+
))?;
150+
151+
Ok(num_class)
152+
}
153+
143154
/// Get Feature Num.
144155
pub fn num_feature(&self) -> Result<i32> {
145156
let mut out_len = 0;
@@ -269,6 +280,14 @@ mod tests {
269280
assert_eq!(num_feature, 28);
270281
}
271282

283+
#[test]
284+
fn num_class() {
285+
let params = _default_params();
286+
let bst = _train_booster(&params);
287+
let num_class = bst.num_class().unwrap();
288+
assert_eq!(num_class, 1);
289+
}
290+
272291
#[test]
273292
fn feature_importance() {
274293
let params = _default_params();

0 commit comments

Comments
 (0)