diff --git a/core/src/pow/types.rs b/core/src/pow/types.rs index d23f2b704a..cee2a4af24 100644 --- a/core/src/pow/types.rs +++ b/core/src/pow/types.rs @@ -391,21 +391,39 @@ impl Proof { } } -#[inline(always)] +fn extract_bits(bits: &Vec, bit_start: usize, bit_count: usize, read_from: usize) -> u64 { + let mut buf: [u8; 8] = [0; 8]; + buf.copy_from_slice(&bits[read_from..read_from + 8]); + if bit_count == 64 { + return u64::from_le_bytes(buf); + } + let skip_bits = bit_start - read_from * 8; + let bit_mask = (1 << bit_count) - 1; + u64::from_le_bytes(buf) >> skip_bits & bit_mask +} + fn read_number(bits: &Vec, bit_start: usize, bit_count: usize) -> u64 { if bit_count == 0 { return 0; } - let mut buf: [u8; 8] = [0; 8]; - let mut byte_start = bit_start / 8; - if byte_start + 8 > bits.len() { - byte_start = bits.len() - 8; - } - buf.copy_from_slice(&bits[byte_start..byte_start + 8]); - buf.reverse(); - let mut nonce = u64::from_be_bytes(buf); - nonce = nonce << 64 - (bit_start - byte_start * 8) - bit_count; - nonce >> 64 - bit_count + // find where the first byte to read starts + let mut read_from = bit_start / 8; + // move back if we are too close to the end of bits + if read_from + 8 > bits.len() { + read_from = bits.len() - 8; + } + // calculate max bit we can read up to (+64 bits from the start) + let max_bit_end = (read_from + 8) * 8; + // calculate max bit we want to read + let max_pos = bit_start + bit_count; + // check if we can read it all at once + if max_pos <= max_bit_end { + extract_bits(bits, bit_start, bit_count, read_from) + } else { + let low = extract_bits(bits, bit_start, 8, read_from); + let high = extract_bits(bits, bit_start + 8, bit_count - 8, read_from + 1); + (high << 8) + low + } } impl Readable for Proof { @@ -420,6 +438,9 @@ impl Readable for Proof { let nonce_bits = edge_bits as usize; let bits_len = nonce_bits * global::proofsize(); let bytes_len = BitVec::bytes_len(bits_len); + if bytes_len < 8 { + return Err(ser::Error::CorruptedData); + } let bits = reader.read_fixed_bytes(bytes_len)?; for n in 0..global::proofsize() { @@ -478,3 +499,47 @@ impl BitVec { self.bits[pos / 8] |= 1 << (pos % 8) as u8; } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::ser::{BinReader, BinWriter, ProtocolVersion}; + use rand::Rng; + use std::io::Cursor; + + #[test] + fn test_proof_rw() { + for edge_bits in 10..64 { + let mut proof = Proof::new(gen_proof(edge_bits as u32)); + proof.edge_bits = edge_bits; + let mut buf = Cursor::new(Vec::new()); + let mut w = BinWriter::new(&mut buf, ProtocolVersion::local()); + if let Err(e) = proof.write(&mut w) { + panic!("failed to write proof {:?}", e); + } + buf.set_position(0); + let mut r = BinReader::new(&mut buf, ProtocolVersion::local()); + match Proof::read(&mut r) { + Err(e) => panic!("failed to read proof: {:?}", e), + Ok(p) => assert_eq!(p, proof), + } + } + } + + fn gen_proof(bits: u32) -> Vec { + let mut rng = rand::thread_rng(); + let mut v = Vec::with_capacity(42); + for _ in 0..42 { + v.push(rng.gen_range( + u64::pow(2, bits - 1), + if bits == 64 { + std::u64::MAX + } else { + u64::pow(2, bits) + }, + )) + } + v + } + +}