Skip to content

Commit

Permalink
Merge pull request livepeer#1135 from livepeer/nv/datadir-reuse
Browse files Browse the repository at this point in the history
common, db: add network id check to prevent datadir collisions
  • Loading branch information
kyriediculous authored Oct 25, 2019
2 parents 053db68 + 2f6f811 commit e4bdc19
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 3 deletions.
43 changes: 41 additions & 2 deletions cmd/livepeer/livepeer.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,9 @@ func main() {
if *ethController == "" {
*ethController = netw.ethController
}
glog.Infof("***Livepeer is running on the %v*** network: %v***", *network, *ethController)
glog.Infof("***Livepeer is running on the %v network: %v***", *network, *ethController)
} else {
glog.Infof("***Livepeer is running on the %v*** network", *network)
glog.Infof("***Livepeer is running on the %v network***", *network)
}

if *datadir == "" {
Expand Down Expand Up @@ -288,6 +288,12 @@ func main() {
watcherErr := make(chan error)
if *network == "offchain" {
glog.Infof("***Livepeer is in off-chain mode***")

if err := checkOrStoreChainID(dbh, big.NewInt(0)); err != nil {
glog.Error(err)
return
}

} else {
var keystoreDir string
if _, err := os.Stat(*ethKeystorePath); !os.IsNotExist(err) {
Expand All @@ -314,6 +320,16 @@ func main() {
return
}

chainID, err := backend.ChainID(context.Background())
if err != nil {
glog.Errorf("failed to get chain ID from remote ethereum node: %v", err)
return
}
if err := checkOrStoreChainID(dbh, chainID); err != nil {
glog.Error(err)
return
}

client, err := eth.NewClient(ethcommon.HexToAddress(*ethAcctAddr), keystoreDir, backend, ethcommon.HexToAddress(*ethController), EthTxTimeout)
if err != nil {
glog.Errorf("Failed to create client: %v", err)
Expand Down Expand Up @@ -815,3 +831,26 @@ func defaultAddr(addr, defaultHost, defaultPort string) string {
}
return addr
}

func checkOrStoreChainID(dbh *common.DB, chainID *big.Int) error {
expectedChainID, err := dbh.ChainID()
if err != nil {
return err
}

if expectedChainID == nil {
// No chainID stored yet
// Store the provided chainID and skip the check
if err := dbh.SetChainID(chainID); err != nil {
return err
}

return nil
}

if expectedChainID.Cmp(chainID) != 0 {
return fmt.Errorf("expecting chainID of %v, but got %v. Did you change networks without changing network name or datadir?", expectedChainID, chainID)
}

return nil
}
61 changes: 60 additions & 1 deletion common/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type DB struct {
// prepared statements
selectOrchs *sql.Stmt
updateOrch *sql.Stmt
selectKV *sql.Stmt
updateKV *sql.Stmt
insertUnbondingLock *sql.Stmt
deleteUnbondingLock *sql.Stmt
Expand Down Expand Up @@ -169,8 +170,17 @@ func InitDB(dbPath string) (*DB, error) {
}
d.updateOrch = stmt

// selectKV prepared statement
stmt, err = db.Prepare("SELECT value FROM kv WHERE key=?")
if err != nil {
glog.Error("Unable to prepare selectKV stmt", err)
d.Close()
return nil, err
}
d.selectKV = stmt

// updateKV prepared statement
stmt, err = db.Prepare("UPDATE kv SET value=?, updatedAt = datetime() WHERE key=?")
stmt, err = db.Prepare("INSERT OR REPLACE INTO kv(key, value, updatedAt) VALUES(?1, ?2, datetime())")
if err != nil {
glog.Error("Unable to prepare updatekv stmt ", err)
d.Close()
Expand Down Expand Up @@ -272,6 +282,9 @@ func (db *DB) Close() {
if db.selectOrchs != nil {
db.selectOrchs.Close()
}
if db.selectKV != nil {
db.selectKV.Close()
}
if db.updateKV != nil {
db.updateKV.Close()
}
Expand Down Expand Up @@ -323,6 +336,52 @@ func (db *DB) LastSeenBlock() (*big.Int, error) {
return header.Number, nil
}

func (db *DB) ChainID() (*big.Int, error) {
idString, err := db.selectKVStore("chainID")
if err != nil {
return nil, err
}

if idString == "" {
return nil, nil
}

id, ok := new(big.Int).SetString(idString, 10)
if !ok {
return nil, fmt.Errorf("unable to convert chainID string to big.Int")
}

return id, nil
}

func (db *DB) SetChainID(id *big.Int) error {
if err := db.updateKVStore("chainID", id.String()); err != nil {
return err
}
return nil
}

func (db *DB) selectKVStore(key string) (string, error) {
row := db.selectKV.QueryRow(key)
var valueString string
if err := row.Scan(&valueString); err != nil {
if err.Error() != "sql: no rows in result set" {
return "", fmt.Errorf("could not retrieve key from database: %v", err)
}
// If there is no result return no error, just zero value
return "", nil
}
return valueString, nil
}

func (db *DB) updateKVStore(key, value string) error {
_, err := db.updateKV.Exec(key, value)
if err != nil {
glog.Errorf("db: Unable to update %v in database: %v", key, err)
}
return err
}

func (db *DB) UpdateOrch(orch *DBOrch) error {
if db == nil || orch == nil || orch.ServiceURI == "" || orch.EthereumAddr == "" {
return nil
Expand Down
83 changes: 83 additions & 0 deletions common/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,89 @@ import (
"github.com/stretchr/testify/require"
)

func TestUpdateKVStore(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
expectedChainID := "1337"

dbh, dbraw, err := TempDB(t)
require.Nil(err)

defer dbh.Close()
defer dbraw.Close()

var chainID string
row := dbraw.QueryRow("SELECT value FROM kv WHERE key = 'chainID'")
err = row.Scan(&chainID)
assert.EqualError(err, "sql: no rows in result set")
assert.Equal("", chainID)

dbh.updateKVStore("chainID", expectedChainID)
row = dbraw.QueryRow("SELECT value FROM kv WHERE key = 'chainID'")
err = row.Scan(&chainID)
assert.Nil(err)
assert.Equal(expectedChainID, chainID)
}

func TestSelectKVStore(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
key := "foo"
value := "bar"

dbh, dbraw, err := TempDB(t)
require.Nil(err)
defer dbh.Close()
defer dbraw.Close()

err = dbh.updateKVStore(key, value)
require.Nil(err)

val, err := dbh.selectKVStore(key)
assert.Nil(err)
assert.Equal(val, value)
}

func TestChainID(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
expectedChainID := "1337"

dbh, dbraw, err := TempDB(t)
require.Nil(err)

defer dbh.Close()
defer dbraw.Close()

expectedChainIDInt, ok := new(big.Int).SetString(expectedChainID, 10)
require.True(ok)
dbh.SetChainID(expectedChainIDInt)

chainID, err := dbh.ChainID()
assert.Nil(err)
assert.Equal(chainID.String(), expectedChainID)
}

func TestSetChainID(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
expectedChainID := "1337"

dbh, dbraw, err := TempDB(t)
require.Nil(err)

defer dbh.Close()
defer dbraw.Close()

expectedChainIDInt, ok := new(big.Int).SetString(expectedChainID, 10)
require.True(ok)
dbh.SetChainID(expectedChainIDInt)

chainID, err := dbh.ChainID()
assert.Nil(err)
assert.Equal(chainID, expectedChainIDInt)
}

func TestDBLastSeenBlock(t *testing.T) {
dbh, dbraw, err := TempDB(t)
if err != nil {
Expand Down
5 changes: 5 additions & 0 deletions test_args.sh
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ run_lp -broadcaster -datadir "$CUSTOM_DATADIR"
[ ! -d "$CUSTOM_DATADIR"/offchain ] # sanity check that network isn't included
kill $pid

CUSTOM_DATADIR="$TMPDIR"/customDatadir2

# sanity check that custom datadir does not exist
[ ! -d "$CUSTOM_DATADIR" ]

# check custom datadir with a network
run_lp -broadcaster -datadir "$CUSTOM_DATADIR" -network rinkeby $ETH_ARGS
[ ! -d "$CUSTOM_DATADIR"/rinkeby ] # sanity check that network isn't included
Expand Down

0 comments on commit e4bdc19

Please sign in to comment.