Skip to content

Commit

Permalink
Add DB.Tx() method to provice access to the underlying sql.Tx instance.
Browse files Browse the repository at this point in the history
  • Loading branch information
Timothy Stranex authored and Timothy Stranex committed Mar 16, 2014
1 parent d232c69 commit a336f51
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
12 changes: 12 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,22 @@ func Open(driver, source string) (DB, error) {
return db, err
}

// Return the underlying sql.DB instance.
//
// If called inside a transaction, it will panic.
// Use Tx() instead in this case.
func (s *DB) DB() *sql.DB {
return s.db.(*sql.DB)
}

// Return the underlying sql.Tx instance.
//
// If called outside of a transaction, it will panic.
// Use DB() instead in this case.
func (s *DB) Tx() *sql.Tx {
return s.db.(*sql.Tx)
}

func (s *DB) Callback() *callback {
s.parent.callback = s.parent.callback.clone()
return s.parent.callback
Expand Down
5 changes: 5 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1542,6 +1542,11 @@ func TestTransaction(t *testing.T) {
t.Errorf("Should find saved record, but got", err)
}

sql_tx := tx.Tx() // This shouldn't panic.
if sql_tx == nil {
t.Errorf("Should return the underlying sql.Tx, but got nil")
}

tx.Rollback()

if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {
Expand Down

0 comments on commit a336f51

Please sign in to comment.