Skip to content

Commit

Permalink
trie: ensure resolved nodes stay loaded
Browse files Browse the repository at this point in the history
Commit 40cdcf1 broke the optimisation which kept nodes resolved
during Get in the trie. The decoder assigned cache generation 0
unconditionally, causing resolved nodes to get flushed on Commit.

This commit fixes it and adds two tests.
  • Loading branch information
fjl committed Oct 18, 2016
1 parent 187d6a6 commit 177cab5
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 43 deletions.
2 changes: 1 addition & 1 deletion trie/hasher.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error)
return hash, n, nil
}
if n.canUnload(h.cachegen, h.cachelimit) {
// Evict the node from cache. All of its subnodes will have a lower or equal
// Unload the node from cache. All of its subnodes will have a lower or equal
// cache generation number.
return hash, hash, nil
}
Expand Down
26 changes: 13 additions & 13 deletions trie/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,16 @@ func (n valueNode) fstring(ind string) string {
return fmt.Sprintf("%x ", []byte(n))
}

func mustDecodeNode(hash, buf []byte) node {
n, err := decodeNode(hash, buf)
func mustDecodeNode(hash, buf []byte, cachegen uint16) node {
n, err := decodeNode(hash, buf, cachegen)
if err != nil {
panic(fmt.Sprintf("node %x: %v", hash, err))
}
return n
}

// decodeNode parses the RLP encoding of a trie node.
func decodeNode(hash, buf []byte) (node, error) {
func decodeNode(hash, buf []byte, cachegen uint16) (node, error) {
if len(buf) == 0 {
return nil, io.ErrUnexpectedEOF
}
Expand All @@ -123,22 +123,22 @@ func decodeNode(hash, buf []byte) (node, error) {
}
switch c, _ := rlp.CountValues(elems); c {
case 2:
n, err := decodeShort(hash, buf, elems)
n, err := decodeShort(hash, buf, elems, cachegen)
return n, wrapError(err, "short")
case 17:
n, err := decodeFull(hash, buf, elems)
n, err := decodeFull(hash, buf, elems, cachegen)
return n, wrapError(err, "full")
default:
return nil, fmt.Errorf("invalid number of list elements: %v", c)
}
}

func decodeShort(hash, buf, elems []byte) (node, error) {
func decodeShort(hash, buf, elems []byte, cachegen uint16) (node, error) {
kbuf, rest, err := rlp.SplitString(elems)
if err != nil {
return nil, err
}
flag := nodeFlag{hash: hash}
flag := nodeFlag{hash: hash, gen: cachegen}
key := compactDecode(kbuf)
if key[len(key)-1] == 16 {
// value node
Expand All @@ -148,17 +148,17 @@ func decodeShort(hash, buf, elems []byte) (node, error) {
}
return &shortNode{key, append(valueNode{}, val...), flag}, nil
}
r, _, err := decodeRef(rest)
r, _, err := decodeRef(rest, cachegen)
if err != nil {
return nil, wrapError(err, "val")
}
return &shortNode{key, r, flag}, nil
}

func decodeFull(hash, buf, elems []byte) (*fullNode, error) {
n := &fullNode{flags: nodeFlag{hash: hash}}
func decodeFull(hash, buf, elems []byte, cachegen uint16) (*fullNode, error) {
n := &fullNode{flags: nodeFlag{hash: hash, gen: cachegen}}
for i := 0; i < 16; i++ {
cld, rest, err := decodeRef(elems)
cld, rest, err := decodeRef(elems, cachegen)
if err != nil {
return n, wrapError(err, fmt.Sprintf("[%d]", i))
}
Expand All @@ -176,7 +176,7 @@ func decodeFull(hash, buf, elems []byte) (*fullNode, error) {

const hashLen = len(common.Hash{})

func decodeRef(buf []byte) (node, []byte, error) {
func decodeRef(buf []byte, cachegen uint16) (node, []byte, error) {
kind, val, rest, err := rlp.Split(buf)
if err != nil {
return nil, buf, err
Expand All @@ -189,7 +189,7 @@ func decodeRef(buf []byte) (node, []byte, error) {
err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen)
return nil, buf, err
}
n, err := decodeNode(nil, buf)
n, err := decodeNode(nil, buf, cachegen)
return n, rest, err
case kind == rlp.String && len(val) == 0:
// empty node
Expand Down
2 changes: 1 addition & 1 deletion trie/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value
if !bytes.Equal(sha.Sum(nil), wantHash) {
return nil, fmt.Errorf("bad proof node %d: hash mismatch", i)
}
n, err := decodeNode(wantHash, buf)
n, err := decodeNode(wantHash, buf, 0)
if err != nil {
return nil, fmt.Errorf("bad proof node %d: %v", i, err)
}
Expand Down
6 changes: 3 additions & 3 deletions trie/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, c
}
key := root.Bytes()
blob, _ := s.database.Get(key)
if local, err := decodeNode(key, blob); local != nil && err == nil {
if local, err := decodeNode(key, blob, 0); local != nil && err == nil {
return
}
// Assemble the new sub-trie sync request
Expand Down Expand Up @@ -158,7 +158,7 @@ func (s *TrieSync) Process(results []SyncResult) (int, error) {
continue
}
// Decode the node data content and update the request
node, err := decodeNode(item.Hash[:], item.Data)
node, err := decodeNode(item.Hash[:], item.Data, 0)
if err != nil {
return i, err
}
Expand Down Expand Up @@ -246,7 +246,7 @@ func (s *TrieSync) children(req *request) ([]*request, error) {
if node, ok := (*child.node).(hashNode); ok {
// Try to resolve the node from the local database
blob, _ := s.database.Get(node)
if local, err := decodeNode(node[:], blob); local != nil && err == nil {
if local, err := decodeNode(node[:], blob, 0); local != nil && err == nil {
*child.node = local
continue
}
Expand Down
11 changes: 7 additions & 4 deletions trie/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,15 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
if err == nil && didResolve {
n = n.copy()
n.Val = newnode
n.flags.gen = t.cachegen
}
return value, n, didResolve, err
case *fullNode:
value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1)
if err == nil && didResolve {
n = n.copy()
n.flags.gen = t.cachegen
n.Children[key[pos]] = newnode

}
return value, n, didResolve, err
case hashNode:
Expand Down Expand Up @@ -247,7 +248,8 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
return false, n, err
}
n = n.copy()
n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true
n.flags = t.newFlag()
n.Children[key[0]] = nn
return true, n, nil

case nil:
Expand Down Expand Up @@ -331,7 +333,8 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
return false, n, err
}
n = n.copy()
n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true
n.flags = t.newFlag()
n.Children[key[0]] = nn

// Check how many non-nil entries are left after deleting and
// reduce the full node to a short node if only one entry is
Expand Down Expand Up @@ -427,7 +430,7 @@ func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) {
SuffixLen: len(suffix),
}
}
dec := mustDecodeNode(n, enc)
dec := mustDecodeNode(n, enc, t.cachegen)
return dec, nil
}

Expand Down
91 changes: 70 additions & 21 deletions trie/trie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,40 +300,63 @@ func TestReplication(t *testing.T) {
}
}

// TestOutput is a debugging aid, not a real test: it skips itself
// unconditionally. With the skip removed it would fill a trie with 50
// long-keyed entries, print the in-memory root, commit, reopen the trie
// at the committed hash, read a single key and print the root again.
func TestOutput(t *testing.T) {
	t.Skip()

	const (
		keyPrefix = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
		payload   = "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"
	)

	// Build the trie fully in memory and dump it.
	full := newEmpty()
	for idx := 0; idx < 50; idx++ {
		name := fmt.Sprintf("%s%d", keyPrefix, idx)
		updateString(full, name, payload)
	}
	fmt.Println("############################## FULL ################################")
	fmt.Println(full.root)

	// Commit, reopen at the committed root and touch one key before
	// dumping the reopened trie.
	full.Commit()
	fmt.Println("############################## SMALL ################################")
	reopened, _ := New(full.Hash(), full.db)
	getString(reopened, keyPrefix+"20")
	fmt.Println(reopened.root)
}

// TestLargeValue stores a 4-byte value and a 32-byte value under two keys
// of a fresh trie and then hashes it. It makes no assertions of its own,
// so it can only fail if one of these calls panics.
func TestLargeValue(t *testing.T) {
	shortVal := []byte{99, 99, 99, 99}
	longVal := bytes.Repeat([]byte{1}, 32)

	tr := newEmpty()
	tr.Update([]byte("key1"), shortVal)
	tr.Update([]byte("key2"), longVal)
	tr.Hash()
}

// countingDB wraps a Database and keeps a per-key tally of Get calls so
// that tests can check how often each key was read.
type countingDB struct {
	// Database is the wrapped store; lookups are forwarded to it.
	Database
	// gets holds the number of Get calls seen per key, indexed by the
	// key bytes converted to a string.
	gets map[string]int
}

// Get records one more read of key and then forwards the lookup to the
// embedded Database, returning its result unchanged.
func (db *countingDB) Get(key []byte) ([]byte, error) {
	// Assigning to a nil map panics, so allocate the tally on first use.
	// This keeps a countingDB that was built without gets usable.
	if db.gets == nil {
		db.gets = make(map[string]int)
	}
	db.gets[string(key)]++
	return db.Database.Get(key)
}

// TestCacheUnload checks that decoded nodes are unloaded after a
// certain number of commit operations.
//
// It reopens a committed two-branch trie on top of a countingDB, sets the
// cache limit to 5, then reads key1 and commits twelve times in a row.
// Every database key that was read must have been read exactly twice.
func TestCacheUnload(t *testing.T) {
	// Create test trie with two branches.
	trie := newEmpty()
	key1 := "---------------------------------"
	key2 := "---some other branch"
	updateString(trie, key1, "this is the branch of key1.")
	updateString(trie, key2, "this is the branch of key2.")
	root, err := trie.Commit()
	if err != nil {
		t.Fatalf("initial commit failed: %v", err)
	}

	// Commit the trie repeatedly and access key1.
	// The branch containing it is loaded from DB exactly two times:
	// in the 0th and 6th iteration.
	db := &countingDB{Database: trie.db, gets: make(map[string]int)}
	trie, err = New(root, db)
	if err != nil {
		// Without this check a failed New would surface later as an
		// unrelated failure on the trie it returned.
		t.Fatalf("can't reopen trie at root %x: %v", root, err)
	}
	trie.SetCacheLimit(5)
	for i := 0; i < 12; i++ {
		getString(trie, key1)
		if _, err := trie.Commit(); err != nil {
			t.Fatalf("commit %d failed: %v", i, err)
		}
	}

	// The loop below only inspects keys that were read at least once, so
	// an empty tally would let the test pass without checking anything.
	if len(db.gets) == 0 {
		t.Fatal("no database reads were recorded")
	}
	// Check that it got loaded two times.
	for dbkey, count := range db.gets {
		if count != 2 {
			t.Errorf("db key %x loaded %d times, want %d times", []byte(dbkey), count, 2)
		}
	}
}

// randTest performs random trie operations.
// Instances of this test are created by Generate.
type randTest []randTestStep

type randTestStep struct {
op int
key []byte // for opUpdate, opDelete, opGet
value []byte // for opUpdate
}

type randTest []randTestStep

const (
opUpdate = iota
opDelete
Expand All @@ -342,6 +365,7 @@ const (
opHash
opReset
opItercheckhash
opCheckCacheInvariant
opMax // boundary value, not an actual op
)

Expand Down Expand Up @@ -437,7 +461,32 @@ func runRandTest(rt randTest) bool {
fmt.Println("hashes not equal")
return false
}
case opCheckCacheInvariant:
return checkCacheInvariant(tr.root, tr.cachegen, 0)
}
}
return true
}

// checkCacheInvariant walks the in-memory nodes reachable from n and
// reports whether every short node and full node carries a cache
// generation that is not greater than parentCachegen, the generation of
// the node above it (callers pass the trie's own generation for the root).
// On the first violation it prints the offending node together with its
// depth and returns false. Nodes of any other kind are accepted as-is and
// are not descended into.
func checkCacheInvariant(n node, parentCachegen uint16, depth int) bool {
	var (
		gen  uint16
		kids []node
	)
	// Pull out the generation and the children of the node kinds that
	// carry flags; everything else trivially satisfies the invariant.
	switch nd := n.(type) {
	case *shortNode:
		gen, kids = nd.flags.gen, []node{nd.Val}
	case *fullNode:
		gen, kids = nd.flags.gen, nd.Children[:]
	default:
		return true
	}

	if gen > parentCachegen {
		fmt.Printf("cache invariant violation: %d > %d\nat depth %d node %s", gen, parentCachegen, depth, spew.Sdump(n))
		return false
	}
	// This node's generation becomes the bound for its children.
	for _, kid := range kids {
		if !checkCacheInvariant(kid, gen, depth+1) {
			return false
		}
	}
	return true
}
Expand Down

0 comments on commit 177cab5

Please sign in to comment.