Skip to content

Commit

Permalink
libdf: Get rid of nightly via more unsafe code
Browse files Browse the repository at this point in the history
  • Loading branch information
Rikorose committed Jun 13, 2023
1 parent 85a6402 commit df191c4
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 9 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/rust_lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ jobs:
matrix:
include:
- {command: fmt, rust: nightly, args: '--all -- --check', hdf5: false}
- {command: clippy, rust: nightly, args: '--all-features --tests -- -D warnings', hdf5: true, alsa: true}
- {command: test, rust: nightly, args: '--all-features -p deep_filter', hdf5: true}
- {command: clippy, rust: stable, args: '--all-features --tests -- -D warnings', hdf5: true, alsa: true}
- {command: test, rust: stable, args: '--all-features -p deep_filter', hdf5: true}
- {command: build, rust: stable, args: '-p deep_filter', hdf5: false}
- {command: build, rust: stable, args: '-p DeepFilterLib', hdf5: false}
- {command: build, rust: stable, args: '-p DeepFilterDataLoader', hdf5: true}
Expand Down
1 change: 1 addition & 0 deletions libDF/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ tract = [
]
default-model = [] # Include default DFN3 model
default-model-ll = [] # Include default DFN3 low latency model
nightly-features = []
capi = ["tract", "default-model", "dep:ndarray"]
hdf5-static = ["hdf5?/static"]

Expand Down
1 change: 0 additions & 1 deletion libDF/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#![allow(dead_code)]
#![cfg_attr(feature = "tract", feature(get_mut_unchecked))]

use std::ops::MulAssign;
use std::sync::Arc;
Expand Down
15 changes: 9 additions & 6 deletions libDF/src/tract.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::fs::File;
use std::io::{Cursor, Read};
use std::path::{Path, PathBuf};
use std::rc::Rc;
#[cfg(feature = "timings")]
use std::time::Instant;

Expand Down Expand Up @@ -382,8 +381,8 @@ impl DfTract {

for (nsy_ch, mut erb_ch, mut cplx_ch, state) in izip!(
spec.axis_iter(Axis(0)),
tvalue_as_mut(&mut self.erb_buf).to_array_view_mut()?.axis_iter_mut(Axis(0)),
tvalue_as_mut(&mut self.cplx_buf).to_array_view_mut()?.axis_iter_mut(Axis(0)),
tvalue_to_array_view_mut(&mut self.erb_buf).axis_iter_mut(Axis(0)),
tvalue_to_array_view_mut(&mut self.cplx_buf).axis_iter_mut(Axis(0)),
self.df_states.iter_mut()
) {
let nsy_ch = as_slice_complex(nsy_ch.as_slice().unwrap());
Expand Down Expand Up @@ -965,11 +964,15 @@ pub fn as_array_mut_complex<'a>(
ArrayViewMutD::from_shape_ptr(shape, ptr)
}
}
pub fn tvalue_as_mut(x: &mut TValue) -> &mut Tensor {
pub fn tvalue_to_array_view_mut(x: &mut TValue) -> ArrayViewMutD<f32> {
unsafe {
match x {
TValue::Var(x) => Rc::get_mut_unchecked(x),
TValue::Const(x) => Arc::get_mut_unchecked(x),
TValue::Var(x) => {
ArrayViewMutD::from_shape_ptr(x.shape(), x.as_ptr_unchecked::<f32>() as *mut f32)
}
TValue::Const(x) => {
ArrayViewMutD::from_shape_ptr(x.shape(), x.as_ptr_unchecked::<f32>() as *mut f32)
}
}
}
}

0 comments on commit df191c4

Please sign in to comment.