diff --git a/ledger/src/shred/merkle.rs b/ledger/src/shred/merkle.rs index 4b6cd792f79471..90b686f5c8f74b 100644 --- a/ledger/src/shred/merkle.rs +++ b/ledger/src/shred/merkle.rs @@ -167,11 +167,15 @@ impl ShredData { let proof_size = self.proof_size()?; let offset = Self::SIZE_OF_HEADERS + Self::capacity(proof_size)?; let size = SIZE_OF_MERKLE_ROOT + usize::from(proof_size) * SIZE_OF_MERKLE_PROOF_ENTRY; - MerkleBranch::try_from( + let merkle_branch = MerkleBranch::try_from( self.payload .get(offset..offset + size) .ok_or(Error::InvalidPayloadSize(self.payload.len()))?, - ) + )?; + if merkle_branch.proof.len() != usize::from(proof_size) { + return Err(Error::InvalidMerkleProof); + } + Ok(merkle_branch) } fn merkle_tree_node(&self) -> Result { @@ -230,13 +234,9 @@ impl ShredData { } fn sanitize(&self, verify_merkle_proof: bool) -> Result<(), Error> { - match self.common_header.shred_variant { - ShredVariant::MerkleData(proof_size) => { - if self.merkle_branch()?.proof.len() != usize::from(proof_size) { - return Err(Error::InvalidProofSize(proof_size)); - } - } - _ => return Err(Error::InvalidShredVariant), + let shred_variant = self.common_header.shred_variant; + if !matches!(shred_variant, ShredVariant::MerkleData(_)) { + return Err(Error::InvalidShredVariant); } if !verify_merkle_proof { debug_assert_matches!(self.verify_merkle_proof(), Ok(true)); @@ -283,11 +283,15 @@ impl ShredCode { let proof_size = self.proof_size()?; let offset = Self::SIZE_OF_HEADERS + Self::capacity(proof_size)?; let size = SIZE_OF_MERKLE_ROOT + usize::from(proof_size) * SIZE_OF_MERKLE_PROOF_ENTRY; - MerkleBranch::try_from( + let merkle_branch = MerkleBranch::try_from( self.payload .get(offset..offset + size) .ok_or(Error::InvalidPayloadSize(self.payload.len()))?, - ) + )?; + if merkle_branch.proof.len() != usize::from(proof_size) { + return Err(Error::InvalidMerkleProof); + } + Ok(merkle_branch) } fn merkle_tree_node(&self) -> Result { @@ -364,13 +368,9 @@ impl ShredCode { } fn sanitize(&self, verify_merkle_proof: bool) -> Result<(), Error> { - match self.common_header.shred_variant { - ShredVariant::MerkleCode(proof_size) => { - if self.merkle_branch()?.proof.len() != usize::from(proof_size) { - return Err(Error::InvalidProofSize(proof_size)); - } - } - _ => return Err(Error::InvalidShredVariant), + let shred_variant = self.common_header.shred_variant; + if !matches!(shred_variant, ShredVariant::MerkleCode(_)) { + return Err(Error::InvalidShredVariant); } if !verify_merkle_proof { debug_assert_matches!(self.verify_merkle_proof(), Ok(true));