Skip to content

Commit

Permalink
Add IsClosed method (dgraph-io#1475)
Browse files Browse the repository at this point in the history
Add an IsClosed method which denotes if badger instance is closed or not.
  • Loading branch information
Ibrahim Jarif authored Aug 26, 2020
1 parent 431aee1 commit 1e21a94
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 2 deletions.
12 changes: 12 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ type DB struct {
logRotates int32

blockWrites int32
isClosed uint32

orc *oracle

Expand Down Expand Up @@ -475,6 +476,12 @@ func (db *DB) Close() error {
return err
}

// IsClosed denotes if the badger DB is closed or not. A DB instance should not
// be used after closing it.
func (db *DB) IsClosed() bool {
return atomic.LoadUint32(&db.isClosed) == 1
}

func (db *DB) close() (err error) {
db.opt.Debugf("Closing database")

Expand Down Expand Up @@ -556,6 +563,8 @@ func (db *DB) close() (err error) {
db.blockCache.Close()
db.indexCache.Close()

atomic.StoreUint32(&db.isClosed, 1)

if db.opt.InMemory {
return
}
Expand Down Expand Up @@ -647,6 +656,9 @@ func (db *DB) getMemTables() ([]*skl.Skiplist, func()) {
// been moved, then for the corresponding movekey, we'll look through all the levels of the tree
// to ensure that we pick the highest version of the movekey present.
func (db *DB) get(key []byte) (y.ValueStruct, error) {
if db.IsClosed() {
return y.ValueStruct{}, ErrDBClosed
}
tables, decr := db.getMemTables() // Lock should be released.
defer decr()

Expand Down
30 changes: 30 additions & 0 deletions db2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -828,3 +828,33 @@ func TestDropAllDropPrefix(t *testing.T) {
wg.Wait()
})
}

func TestIsClosed(t *testing.T) {
test := func(inMemory bool) {
opt := DefaultOptions("")
if inMemory {
opt.InMemory = true
} else {
dir, err := ioutil.TempDir("", "badger-test")
require.NoError(t, err)
defer removeDir(dir)

opt.Dir = dir
opt.ValueDir = dir
}

db, err := Open(opt)
require.NoError(t, err)
require.False(t, db.IsClosed())
require.NoError(t, db.Close())
require.True(t, db.IsClosed())
}

t.Run("normal", func(t *testing.T) {
test(false)
})
t.Run("in-memory", func(t *testing.T) {
test(true)
})

}
6 changes: 5 additions & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,12 @@ var (
// ErrInvalidDataKeyID is returned if the datakey id is invalid.
ErrInvalidDataKeyID = errors.New("Invalid datakey id")

// ErrInvalidEncryptionKey is returned if length of encryption keys is invalid.
ErrInvalidEncryptionKey = errors.New("Encryption key's length should be" +
"either 16, 24, or 32 bytes")

// ErrGCInMemoryMode is returned when db.RunValueLogGC is called in in-memory mode.
ErrGCInMemoryMode = errors.New("Cannot run value log GC when DB is opened in InMemory mode")

// ErrDBClosed is returned when a get operation is performed after closing the DB.
ErrDBClosed = errors.New("DB Closed")
)
3 changes: 3 additions & 0 deletions iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,9 @@ func (txn *Txn) NewIterator(opt IteratorOptions) *Iterator {
if txn.discarded {
panic("Transaction has already been discarded")
}
if txn.db.IsClosed() {
panic(ErrDBClosed.Error())
}

// Keep track of the number of active iterators.
atomic.AddInt32(&txn.numIterators, 1)
Expand Down
3 changes: 3 additions & 0 deletions levels.go
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,9 @@ func (s *levelsController) close() error {
// get returns the found value if any. If not found, we return nil.
func (s *levelsController) get(key []byte, maxVs *y.ValueStruct, startLevel int) (
y.ValueStruct, error) {
if s.kv.IsClosed() {
return y.ValueStruct{}, ErrDBClosed
}
// It's important that we iterate the levels from 0 on upward. The reason is, if we iterated
// in opposite order, or in parallel (naively calling all the h.RLock() in some order) we could
// read level L's tables post-compaction and level L+1's tables pre-compaction. (If we do
Expand Down
2 changes: 1 addition & 1 deletion managed_db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ func TestWriteAfterClose(t *testing.T) {
err = db.Update(func(txn *Txn) error {
return txn.SetEntry(NewEntry([]byte("a"), []byte("b")))
})
require.Equal(t, ErrBlockedWrites, err)
require.Equal(t, ErrDBClosed, err)
}

func TestDropAllRace(t *testing.T) {
Expand Down
6 changes: 6 additions & 0 deletions txn.go
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,9 @@ func (db *DB) newTransaction(update, isManaged bool) *Txn {
// returned by the function is relayed by the View method.
// If View is used with managed transactions, it would assume a read timestamp of MaxUint64.
func (db *DB) View(fn func(txn *Txn) error) error {
if db.IsClosed() {
return ErrDBClosed
}
var txn *Txn
if db.opt.managedTxns {
txn = db.NewTransactionAt(math.MaxUint64, false)
Expand All @@ -803,6 +806,9 @@ func (db *DB) View(fn func(txn *Txn) error) error {
// for the user. Error returned by the function is relayed by the Update method.
// Update cannot be used with managed transactions.
func (db *DB) Update(fn func(txn *Txn) error) error {
if db.IsClosed() {
return ErrDBClosed
}
if db.opt.managedTxns {
panic("Update can only be used with managedDB=false.")
}
Expand Down

0 comments on commit 1e21a94

Please sign in to comment.