Skip to content

Commit

Permalink
Merge pull request input-output-hk#90 from ApexTheory/86-batch-merkle…
Browse files Browse the repository at this point in the history
…-proof-ser-de

Batch Merkle Proof Serialization & Deserialization
  • Loading branch information
kushti authored Nov 12, 2021
2 parents f260a09 + 1753dff commit f08152c
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package scorex.crypto.authds.merkle.serialization

import com.google.common.primitives.{Bytes, Ints}
import scorex.crypto.authds.merkle.BatchMerkleProof
import scorex.crypto.authds.{EmptyByteArray, Side}
import scorex.crypto.hash.{CryptographicHash, Digest, Digest32}

import scala.util.Try

class BatchMerkleProofSerializer[D <: Digest32, HF <: CryptographicHash[D]](implicit val hf: HF) {

private val digestSize = hf.DigestSize
private val indexSize = 4
private val sideSize = 1
private val indicesSize = digestSize + indexSize
private val proofsSize = digestSize + sideSize

def serialize(bmp: BatchMerkleProof[D]): Array[Byte] =
Bytes.concat(
Ints.toByteArray(bmp.indices.size),
Ints.toByteArray(bmp.proofs.size),
indicesToBytes(bmp.indices),
proofsToBytes(bmp.proofs)
)

def deserialize(bytes: Array[Byte]): Try[BatchMerkleProof[D]] = Try {

if (bytes.length < 8) {
throw new Error("Deserialization error, empty input.")
}

val numIndices = Ints.fromByteArray(bytes.slice(0, 4))
val numProofs = Ints.fromByteArray(bytes.slice(4, 8))
val (indices, proofs) = bytes.drop(8).splitAt(numIndices * indicesSize)

if (indices.length != numIndices * indicesSize || proofs.length != numProofs * proofsSize) {
throw new Error("Deserialization error, invalid input.")
}

BatchMerkleProof(
indicesFromBytes(indices),
proofsFromBytes(proofs)
)
}

private[serialization] def indicesToBytes(indices: Seq[(Int, Digest)]): Array[Byte] = {
Bytes.concat(
indices.map(i => (Ints.toByteArray(i._1), i._2)).flatten{case (a, b) => Bytes.concat(a, b)}.toArray
)
}

private[serialization] def proofsToBytes(proofs: Seq[(Digest, Side)]): Array[Byte] = {
Bytes.concat(
proofs.map(p => (p._1, Array(p._2.toByte))).flatten{
case (a, b) if a.isEmpty => Bytes.concat(Array.ofDim[Byte](32), b)
case (a, b) => Bytes.concat(a, b)
}.toArray
)
}

private[serialization] def indicesFromBytes(bytes: Array[Byte]): Seq[(Int, Digest)] = {
bytes.grouped(indicesSize)
.map(b => {
val index = Ints.fromByteArray(b.slice(0, indexSize))
val hash = b.slice(indexSize, indicesSize).asInstanceOf[Digest]
(index,hash)
})
.toSeq
}

private[serialization] def proofsFromBytes(bytes: Array[Byte]): Seq[(Digest, Side)] = {
bytes.grouped(proofsSize)
.map(b => {
val hashBytes = b.slice(0, digestSize)
val hash = (if (hashBytes.forall(0.toByte.equals)) EmptyByteArray else hashBytes).asInstanceOf[Digest]
val side = b.apply(digestSize).asInstanceOf[Side]
(hash, side)
})
.toSeq
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package scorex.crypto.authds.merkle.serialization

import org.scalatest.TryValues
import org.scalatest.propspec.AnyPropSpec
import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks
import scorex.crypto.authds.merkle.{BatchMerkleProof, Leaf, MerkleTree}
import scorex.crypto.authds.{LeafData, Side, TwoPartyTests}
import scorex.crypto.hash.{Digest, Digest32, Keccak256}

import scala.util.Random

class BatchMerkleProofSerializerSpecification extends AnyPropSpec
with ScalaCheckDrivenPropertyChecks
with TwoPartyTests
with TryValues {

type D = Digest32
type HF = Keccak256.type
implicit val hf: HF = Keccak256
private val LeafSize = 32

property("Batch proof serialization + deserialization") {
val r = new Random()
val serializer = new BatchMerkleProofSerializer[D, HF]
forAll(smallInt) { N: Int =>
whenever(N > 0) {
val d = (0 until N).map(_ => LeafData @@ scorex.utils.Random.randomBytes(LeafSize))
val tree = MerkleTree(d)
val randIndices = (0 until r.nextInt(N + 1) + 1)
.map(_ => r.nextInt(N))
.distinct
.sorted

val compactMultiproof = tree.proofByIndices(randIndices).get
val serializedBytes = serializer.serialize(compactMultiproof)
val rebuiltMultiproof = serializer.deserialize(serializedBytes).get

serializedBytes.length shouldEqual
(8 + (compactMultiproof.proofs.size * 33) + (compactMultiproof.indices.size * 36))
compactMultiproof.indices.zipWithIndex.foreach { case ((index, hash), i) =>
val res = rebuiltMultiproof.indices.apply(i)
index shouldEqual res._1
hash shouldEqual res._2
}
compactMultiproof.proofs.zipWithIndex.foreach { case ((digest, side), i) =>
val res = rebuiltMultiproof.proofs.apply(i)
digest shouldEqual res._1
side shouldEqual res._2
}
}
}
}

property(testName = "empty deserialization input") {
val serializer = new BatchMerkleProofSerializer[D, HF]
val res = serializer.deserialize(scorex.utils.Random.randomBytes(2))
res.failure.exception should have message "Deserialization error, empty input."
}

property(testName = "invalid deserialization input") {
val serializer = new BatchMerkleProofSerializer[D, HF]
val res = serializer.deserialize(scorex.utils.Random.randomBytes(9))
res.failure.exception should have message "Deserialization error, invalid input."
}

property("indices serialization + deserialization") {
val r = new Random()
val serializer = new BatchMerkleProofSerializer[D, HF]
forAll(smallInt) { N: Int =>
whenever(N > 0) {

val d = (0 until N).map(_ => LeafData @@ scorex.utils.Random.randomBytes(LeafSize))
val randIndices = (0 until r.nextInt(N + 1) + 1)
.map(_ => r.nextInt(N))
.sorted
.distinct
val indices = randIndices zip randIndices.map(i => Leaf(d.apply(i)).hash)

val serializedIndices: Array[Byte] = serializer.indicesToBytes(indices)
val deserializedIndices: Seq[(Int, Digest)] = serializer.indicesFromBytes(serializedIndices)

indices.zipWithIndex.foreach { case ((index, hash), i) =>
val res = deserializedIndices.apply(i)
index shouldEqual res._1
hash shouldEqual res._2
}
}
}
}

property("proofs serialization + deserialization") {
val r = new Random()
val serializer = new BatchMerkleProofSerializer[D, HF]
forAll(smallInt) { N: Int =>
whenever(N > 0) {

val proofs: Seq[(Digest, Side)] = (0 until N)
.map(_ => LeafData @@ scorex.utils.Random.randomBytes(LeafSize))
.map(l => (Leaf(l).hash, Side @@ r.nextInt(2).toByte))

val serializedProofs: Array[Byte] = serializer.proofsToBytes(proofs)
val deserializedProofs: Seq[(Digest, Side)] = serializer.proofsFromBytes(serializedProofs)

proofs.zipWithIndex.foreach { case ((digest, side), i) =>
val res = deserializedProofs.apply(i)
digest shouldEqual res._1
side shouldEqual res._2
}
}
}
}
}

0 comments on commit f08152c

Please sign in to comment.