diff --git a/main.go b/main.go index ca1d24bb7..f205b6c39 100644 --- a/main.go +++ b/main.go @@ -28,10 +28,22 @@ func Open(driver, source string) (DB, error) { return db, err } +// Return the underlying sql.DB instance. +// +// If called inside a transaction, it will panic. +// Use Tx() instead in this case. func (s *DB) DB() *sql.DB { return s.db.(*sql.DB) } +// Return the underlying sql.Tx instance. +// +// If called outside of a transaction, it will panic. +// Use DB() instead in this case. +func (s *DB) Tx() *sql.Tx { + return s.db.(*sql.Tx) +} + func (s *DB) Callback() *callback { s.parent.callback = s.parent.callback.clone() return s.parent.callback diff --git a/main_test.go b/main_test.go index 47611908b..14de67010 100644 --- a/main_test.go +++ b/main_test.go @@ -1542,6 +1542,11 @@ func TestTransaction(t *testing.T) { t.Errorf("Should find saved record, but got", err) } + sql_tx := tx.Tx() // This shouldn't panic. + if sql_tx == nil { + t.Errorf("Should return the underlying sql.Tx, but got nil") + } + tx.Rollback() if err := tx.First(&User{}, "name = ?", "transcation").Error; err == nil {