Skip to content

Commit

Permalink
Pass the axis and axis index when create a tree, to allow for paralle…
Browse files Browse the repository at this point in the history
…l nmt caches (celestiaorg#119)
  • Loading branch information
musalbas committed Sep 15, 2022
1 parent fa3361f commit 4b3ad26
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions datasquare.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ func (ds *dataSquare) getRowRoot(x uint) []byte {
return ds.rowRoots[x]
}

tree := ds.createTreeFn()
tree := ds.createTreeFn(Row, x)
for i, d := range ds.row(x) {
tree.Push(d, SquareIndex{Cell: uint(i), Axis: x})
}
Expand All @@ -245,7 +245,7 @@ func (ds *dataSquare) getColRoot(y uint) []byte {
return ds.colRoots[y]
}

tree := ds.createTreeFn()
tree := ds.createTreeFn(Col, y)
for i, d := range ds.col(y) {
tree.Push(d, SquareIndex{Axis: y, Cell: uint(i)})
}
Expand Down
4 changes: 2 additions & 2 deletions datasquare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func BenchmarkRoots(b *testing.B) {
}

func computeRowProof(ds *dataSquare, x uint, y uint) ([]byte, [][]byte, uint, uint, error) {
tree := ds.createTreeFn()
tree := ds.createTreeFn(Row, x)
data := ds.row(x)

for i := uint(0); i < ds.width; i++ {
Expand All @@ -226,7 +226,7 @@ func computeRowProof(ds *dataSquare, x uint, y uint) ([]byte, [][]byte, uint, ui
}

func computeColProof(ds *dataSquare, x uint, y uint) ([]byte, [][]byte, uint, uint, error) {
tree := ds.createTreeFn()
tree := ds.createTreeFn(Col, y)
data := ds.col(y)

for i := uint(0); i < ds.width; i++ {
Expand Down
8 changes: 4 additions & 4 deletions extendeddatacrossword.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func (eds *ExtendedDataSquare) verifyAgainstRowRoots(
r uint,
shares [][]byte,
) error {
root := eds.computeSharesRoot(shares, r)
root := eds.computeSharesRoot(shares, Row, r)

if !bytes.Equal(root, rowRoots[r]) {
return &ErrByzantineData{Row, r, shares}
Expand All @@ -283,7 +283,7 @@ func (eds *ExtendedDataSquare) verifyAgainstColRoots(
c uint,
shares [][]byte,
) error {
root := eds.computeSharesRoot(shares, c)
root := eds.computeSharesRoot(shares, Col, c)

if !bytes.Equal(root, colRoots[c]) {
return &ErrByzantineData{Col, c, shares}
Expand Down Expand Up @@ -349,8 +349,8 @@ func noMissingData(input [][]byte) bool {
return true
}

func (eds *ExtendedDataSquare) computeSharesRoot(shares [][]byte, i uint) []byte {
tree := eds.createTreeFn()
func (eds *ExtendedDataSquare) computeSharesRoot(shares [][]byte, axis Axis, i uint) []byte {
tree := eds.createTreeFn(axis, i)
for cell, d := range shares {
tree.Push(d, SquareIndex{Cell: uint(cell), Axis: i})
}
Expand Down
2 changes: 1 addition & 1 deletion extendeddatacrossword_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func TestRepairExtendedDataSquare(t *testing.T) {
if err != nil {
t.Errorf("could not decode fraud proof shares; got: %v", err)
}
root := corrupted.computeSharesRoot(rebuiltShares, fraudProof.Index)
root := corrupted.computeSharesRoot(rebuiltShares, byzData.Axis, fraudProof.Index)
if bytes.Equal(root, corrupted.getRowRoot(fraudProof.Index)) {
// If the roots match, then the fraud proof should be for invalid erasure coding.
parityShares, err := codec.Encode(rebuiltShares[0:corrupted.originalDataWidth])
Expand Down
4 changes: 2 additions & 2 deletions tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

// TreeConstructorFn creates a fresh Tree instance to be used as the Merkle inside of rsmt2d.
type TreeConstructorFn = func() Tree
type TreeConstructorFn = func(axis Axis, index uint) Tree

// SquareIndex contains all information needed to identify the cell that is being
// pushed
Expand All @@ -29,7 +29,7 @@ type DefaultTree struct {
root []byte
}

func NewDefaultTree() Tree {
func NewDefaultTree(axis Axis, index uint) Tree {
return &DefaultTree{
Tree: merkletree.New(sha256.New()),
leaves: make([][]byte, 0, 128),
Expand Down

0 comments on commit 4b3ad26

Please sign in to comment.