Skip to content

Commit

Permalink
fix deadlock while using rayon
Browse files Browse the repository at this point in the history
  • Loading branch information
informationsea committed Nov 16, 2023
1 parent bc0c556 commit f919d96
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 5 deletions.
3 changes: 3 additions & 0 deletions bgzip/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ pub mod deflate;
/// BGZ header parser
pub mod header;
pub mod index;
#[cfg(feature = "rayon")]
pub(crate) mod rayon;
pub mod read;

pub use deflate::Compression;
/// Tabix file parser. (This module is alpha state.)
pub mod tabix;
Expand Down
24 changes: 24 additions & 0 deletions bgzip/src/rayon.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use std::sync::mpsc::{Receiver, RecvError, RecvTimeoutError, TryRecvError};

const TIMEOUT_DURATION: std::time::Duration = std::time::Duration::from_millis(10);

pub(crate) fn receive_or_yield<R>(receiver: &Receiver<R>) -> std::result::Result<R, RecvError> {
loop {
match receiver.try_recv() {
Ok(t) => return Ok(t),
Err(TryRecvError::Empty) => match rayon::yield_now() {
None => return receiver.recv(),
Some(rayon::Yield::Executed) => continue,
Some(rayon::Yield::Idle) => match receiver.recv_timeout(TIMEOUT_DURATION) {
Ok(t) => return Ok(t),
Err(RecvTimeoutError::Timeout) => {
//dbg!("receive idle");
continue;
}
Err(RecvTimeoutError::Disconnected) => return Err(RecvError),
},
},
Err(TryRecvError::Disconnected) => return Err(RecvError),
}
}
}
42 changes: 39 additions & 3 deletions bgzip/src/read/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::io::{BufRead, Read};
use std::sync::mpsc::{channel, Receiver, Sender};

use crate::deflate::*;
use crate::rayon::receive_or_yield;
use crate::BGZFError;

const EOF_BLOCK: [u8; 10] = [3, 0, 0, 0, 0, 0, 0, 0, 0, 0];
Expand Down Expand Up @@ -183,9 +184,7 @@ impl<R: Read> BufRead for BGZFMultiThreadReader<R> {
}

while !self.read_waiting_blocks.contains_key(&self.next_read_index) {
let block = self
.reader_receiver
.recv()
let block = receive_or_yield(&self.reader_receiver)
.expect("reader receive error")
.map_err(|e| -> std::io::Error { e.into() })?;
// eprintln!("fetch: {}", block.index);
Expand Down Expand Up @@ -226,6 +225,43 @@ impl<R: Read> Read for BGZFMultiThreadReader<R> {
#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_many_data() -> anyhow::Result<()> {
let mut expected_reader = flate2::read::MultiGzDecoder::new(std::fs::File::open(
"testfiles/common_all_20180418_half.vcf.gz",
)?);
let mut expected_buf = Vec::new();
expected_reader.read_to_end(&mut expected_buf)?;
const LOOP_NUM: usize = 100;
let expected_buf: &'static [u8] = Box::leak(expected_buf.into_boxed_slice());

let (tx, rx) = channel();
for i in 0..LOOP_NUM {
let tx = tx.clone();
rayon::spawn(move || {
//eprintln!("start");
let mut reader = BGZFMultiThreadReader::new(
std::fs::File::open("testfiles/common_all_20180418_half.vcf.gz").unwrap(),
)
.unwrap();
//eprintln!("open");
let mut read_buf = Vec::new();
reader.read_to_end(&mut read_buf).unwrap();
//eprintln!("end");
assert_eq!(expected_buf.len(), read_buf.len());
assert_eq!(expected_buf, read_buf);
tx.send(i).unwrap();
});
}

for _i in 0..LOOP_NUM {
eprintln!("Finish {} / {}", rx.recv()?, _i);
}

Ok(())
}

#[test]
fn test_thread_read() -> anyhow::Result<()> {
let mut expected_reader = flate2::read::MultiGzDecoder::new(std::fs::File::open(
Expand Down
44 changes: 42 additions & 2 deletions bgzip/src/write/thread.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::index::BGZFIndexEntry;
use crate::rayon::receive_or_yield;
use crate::{deflate::*, index::BGZFIndex, BGZFError};
use std::collections::HashMap;
use std::convert::TryInto;
Expand Down Expand Up @@ -137,8 +138,7 @@ impl<W: Write> BGZFMultiThreadWriter<W> {
let mut current_block = block;
while self.next_compress_index != self.next_write_index {
let next_data = if current_block {
self.writer_receiver
.recv()
receive_or_yield(&self.writer_receiver)
.map_err(|_| Error::new(ErrorKind::Other, "Closed channel"))?
} else {
match self.writer_receiver.try_recv() {
Expand Down Expand Up @@ -286,6 +286,46 @@ mod test {
const WRITE_UNIT: usize = 2000;
const BUF_SIZE: usize = 1000 * 1000 * 10;

#[test]
fn test_write_many() -> anyhow::Result<()> {
let mut reader = flate2::read::MultiGzDecoder::new(std::fs::File::open(
"testfiles/common_all_20180418_half.vcf.gz",
)?);
let mut write_data = Vec::new();
reader.read_to_end(&mut write_data)?;
const LOOP_NUM: usize = 100;
let expected_buf: &'static [u8] = Box::leak(write_data.into_boxed_slice());
let (tx, rx) = channel();

for i in 0..LOOP_NUM {
let tx = tx.clone();
rayon::spawn(move || {
let mut to_write_buf = Vec::new();
let mut writer =
BGZFMultiThreadWriter::new(&mut to_write_buf, Compression::default());
writer.write_all(expected_buf).expect("Failed to write");
writer.flush().expect("Failed to flush");
std::mem::drop(writer);

let mut to_read_buf = Vec::new();
let mut reader = flate2::read::MultiGzDecoder::new(&to_write_buf[..]);
reader
.read_to_end(&mut to_read_buf)
.expect("Failed to read");

assert_eq!(to_read_buf.len(), expected_buf.len());
assert_eq!(to_read_buf, expected_buf);

tx.send(i).unwrap();
});
}

for _i in 0..LOOP_NUM {
eprintln!("Finish {} / {}", rx.recv()?, _i);
}
Ok(())
}

#[test]
fn test_thread_writer() -> anyhow::Result<()> {
let mut rand = rand_pcg::Pcg64Mcg::seed_from_u64(0x9387402456157523);
Expand Down

0 comments on commit f919d96

Please sign in to comment.