Skip to content

Commit 19c76ec

Browse files
authored
Merge pull request #2 from postgresml/levkk-from-vec
Create datasets from one-dim vec
2 parents c40fab6 + 1434e70 commit 19c76ec

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

src/dataset.rs

+58
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,51 @@ impl Dataset {
4343
Self { handle }
4444
}
4545

46+
/// Create a new `Dataset` from a dense array in row-major order
47+
/// without allocating rows in memory.
48+
///
49+
/// Example
50+
/// ```
51+
/// use lightgbm::Dataset;
52+
///
53+
/// let data = vec![1.0, 0.1, 0.2, 0.1,
54+
/// 0.7, 0.4, 0.5, 0.1,
55+
/// 0.9, 0.8, 0.5, 0.1,
56+
/// 0.2, 0.2, 0.8, 0.7,
57+
/// 0.1, 0.7, 1.0, 0.9];
58+
/// let label = vec![0.0, 0.0, 0.0, 1.0, 1.0];
59+
/// let dataset = Dataset::from_vec(&data, &label, 4).unwrap();
60+
/// ```
61+
pub fn from_vec(data: &[f32], labels: &[f32], num_features: i32) -> Result<Self> {
62+
let data_length = data.len() as i32;
63+
let feature_length = num_features;
64+
let params = CString::new("").unwrap();
65+
let label_str = CString::new("label").unwrap();
66+
let reference = std::ptr::null_mut(); // not use
67+
let mut handle = std::ptr::null_mut();
68+
69+
lgbm_call!(lightgbm_sys::LGBM_DatasetCreateFromMat(
70+
data.as_ptr() as *const c_void,
71+
lightgbm_sys::C_API_DTYPE_FLOAT32 as i32,
72+
data_length,
73+
feature_length,
74+
1_i32,
75+
params.as_ptr() as *const c_char,
76+
reference,
77+
&mut handle
78+
))?;
79+
80+
lgbm_call!(lightgbm_sys::LGBM_DatasetSetField(
81+
handle,
82+
label_str.as_ptr() as *const c_char,
83+
labels.as_ptr() as *const c_void,
84+
data_length as i32,
85+
lightgbm_sys::C_API_DTYPE_FLOAT32 as i32
86+
))?;
87+
88+
Ok(Self::new(handle))
89+
}
90+
4691
/// Create a new `Dataset` from dense array in row-major order.
4792
///
4893
/// Example
@@ -225,6 +270,19 @@ mod tests {
225270
assert!(dataset.is_ok());
226271
}
227272

273+
#[test]
274+
fn from_vec() {
275+
let data = vec![
276+
1.0, 0.1, 0.2, 0.1, 0.7, 0.4, 0.5, 0.1, 0.9, 0.8, 0.5, 0.1, 0.2, 0.2, 0.8, 0.7, 0.1,
277+
0.7, 1.0, 0.9,
278+
];
279+
280+
let labels = vec![0.0, 0.0, 0.0, 1.0, 1.0];
281+
282+
let dataset = Dataset::from_vec(&data, &labels, 4);
283+
assert!(dataset.is_ok());
284+
}
285+
228286
#[cfg(feature = "dataframe")]
229287
#[test]
230288
fn from_dataframe() {

0 commit comments

Comments
 (0)