Skip to content

Commit

Permalink
Add callback create, delete
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Jan 26, 2014
1 parent 192ed06 commit 973acd6
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 68 deletions.
24 changes: 13 additions & 11 deletions callback.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package gorm

import "fmt"
import (
"fmt"
)

type callback struct {
creates []*func()
updates []*func()
deletes []*func()
queries []*func()
creates []*func(scope *Scope)
updates []*func(scope *Scope)
deletes []*func(scope *Scope)
queries []*func(scope *Scope)
processors []*callback_processor
}

Expand All @@ -17,7 +19,7 @@ type callback_processor struct {
replace bool
remove bool
typ string
processor *func()
processor *func(scope *Scope)
callback *callback
}

Expand Down Expand Up @@ -53,7 +55,7 @@ func (cp *callback_processor) After(name string) *callback_processor {
return cp
}

func (cp *callback_processor) Register(name string, fc func()) {
func (cp *callback_processor) Register(name string, fc func(scope *Scope)) {
cp.name = name
cp.processor = &fc
cp.callback.sort()
Expand All @@ -65,7 +67,7 @@ func (cp *callback_processor) Remove(name string) {
cp.callback.sort()
}

func (cp *callback_processor) Replace(name string, fc func()) {
func (cp *callback_processor) Replace(name string, fc func(scope *Scope)) {
cp.name = name
cp.processor = &fc
cp.replace = true
Expand All @@ -81,7 +83,7 @@ func getRIndex(strs []string, str string) int {
return -1
}

func sortProcessors(cps []*callback_processor) []*func() {
func sortProcessors(cps []*callback_processor) []*func(scope *Scope) {
var sortCallbackProcessor func(c *callback_processor, force bool)
var names, sortedNames = []string{}, []string{}

Expand Down Expand Up @@ -137,8 +139,8 @@ func sortProcessors(cps []*callback_processor) []*func() {
sortCallbackProcessor(cp, false)
}

var funcs = []*func(){}
var sortedFuncs = []*func(){}
var funcs = []*func(scope *Scope){}
var sortedFuncs = []*func(scope *Scope){}
for _, name := range sortedNames {
index := getRIndex(names, name)
if !cps[index].remove {
Expand Down
41 changes: 41 additions & 0 deletions callback_create.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package gorm

func BeforeCreate(scope *Scope) {
scope.CallMethod("BeforeSave")
scope.CallMethod("BeforeCreate")
}

func SaveBeforeAssociations(scope *Scope) {
}

func Create(scope *Scope) {
if !scope.HasError() {
var id interface{}
if scope.Dialect().SupportLastInsertId() {
if sql_result, err := scope.DB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
id, err = sql_result.LastInsertId()
scope.Err(err)
}
} else {
scope.Err(scope.DB().QueryRow(scope.Sql, scope.SqlVars...).Scan(&id))
}

scope.SetColumn(scope.PrimaryKey(), id)
}
}

func AfterCreate(scope *Scope) {
scope.CallMethod("AfterCreate")
scope.CallMethod("AfterSave")
}

func SaveAfterAssociations(scope *Scope) {
}

func init() {
DefaultCallback.Create().Register("before_create", BeforeCreate)
DefaultCallback.Create().Register("save_before_associations", SaveBeforeAssociations)
DefaultCallback.Create().Register("create", Create)
DefaultCallback.Create().Register("save_after_associations", SaveAfterAssociations)
DefaultCallback.Create().Register("after_create", AfterCreate)
}
33 changes: 33 additions & 0 deletions callback_delete.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package gorm

import (
"fmt"
"time"
)

func BeforeDelete(scope *Scope) {
scope.CallMethod("BeforeDelete")
}

func Delete(scope *Scope) {
if scope.HasError() {
return
}

if !scope.Search.unscope && scope.HasColumn("DeletedAt") {
scope.Raw(fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", scope.TableName(), scope.AddToVars(time.Now()), scope.CombinedConditionSql()))
} else {
scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.TableName(), scope.CombinedConditionSql()))
}
scope.Exec()
}

func AfterDelete(scope *Scope) {
scope.CallMethod("AfterDelete")
}

func init() {
DefaultCallback.Delete().Register("before_delete", BeforeDelete)
DefaultCallback.Delete().Register("delete", Delete)
DefaultCallback.Delete().Register("after_delete", AfterDelete)
}
14 changes: 7 additions & 7 deletions callback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"testing"
)

func equalFuncs(funcs []*func(), fnames []string) bool {
func equalFuncs(funcs []*func(s *Scope), fnames []string) bool {
var names []string
for _, f := range funcs {
fnames := strings.Split(runtime.FuncForPC(reflect.ValueOf(*f).Pointer()).Name(), ".")
Expand All @@ -16,11 +16,11 @@ func equalFuncs(funcs []*func(), fnames []string) bool {
return reflect.DeepEqual(names, fnames)
}

func create() {}
func before_create1() {}
func before_create2() {}
func after_create1() {}
func after_create2() {}
func create(s *Scope) {}
func before_create1(s *Scope) {}
func before_create2(s *Scope) {}
func after_create1(s *Scope) {}
func after_create2(s *Scope) {}

func TestRegisterCallback(t *testing.T) {
var callback = &callback{processors: []*callback_processor{}}
Expand Down Expand Up @@ -76,7 +76,7 @@ func TestRegisterCallbackWithComplexOrder2(t *testing.T) {
}
}

func replace_create() {}
func replace_create(s *Scope) {}

func TestReplaceCallback(t *testing.T) {
var callback = &callback{processors: []*callback_processor{}}
Expand Down
36 changes: 2 additions & 34 deletions callbacks/create.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
package callback
package callbacks

import (
"fmt"

"github.com/jinzhu/gorm"

"time"
)
import "github.com/jinzhu/gorm"

func Create(scope *gorm.Scope) {
}
Expand All @@ -15,32 +9,6 @@ func init() {
gorm.DefaultCallback.Create().Before().Register(Create)
}

func query(db *DB) {
}

func save(db *DB) {
}

func create(db *DB) {
}

func update(db *DB) {
}

func Delete(scope *Scope) {
scope.CallMethod("BeforeDelete")

if !scope.HasError() {
if !scope.Search.unscope && scope.HasColumn("DeletedAt") {
scope.Raw(fmt.Sprintf("UPDATE %v SET deleted_at=%v %v", scope.Table(), scope.AddToVars(time.Now()), scope.CombinedSql()))
} else {
scope.Raw(fmt.Sprintf("DELETE FROM %v %v", scope.Table(), scope.CombinedSql()))
}
scope.Exec()
scope.CallMethod("AfterDelete")
}
}

func init() {
DefaultCallback.Create().Before("Delete").After("Lalala").Register("delete", Delete)
DefaultCallback.Update().Before("Delete").After("Lalala").Remove("replace", Delete)
Expand Down
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

type DB struct {
Value interface{}
Callbacks *callback
callback *callback
Error error
db sqlCommon
parent *DB
Expand All @@ -22,7 +22,7 @@ type DB struct {

func Open(driver, source string) (DB, error) {
var err error
db := DB{dialect: dialect.New(driver), tagIdentifier: "sql", logger: defaultLogger}
db := DB{dialect: dialect.New(driver), tagIdentifier: "sql", logger: defaultLogger, callback: DefaultCallback}
db.db, err = sql.Open(driver, source)
db.parent = &db
return db, err
Expand Down
14 changes: 0 additions & 14 deletions private.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
package gorm

import (
"fmt"
"os"
"regexp"
"runtime"
"strings"
"time"
)

Expand Down Expand Up @@ -55,16 +51,6 @@ func (s *DB) hasError() bool {
return s.Error != nil
}

func fileWithLineNum() string {
for i := 1; i < 15; i++ {
_, file, line, ok := runtime.Caller(i)
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line)
}
}
return ""
}

func (s *DB) print(v ...interface{}) {
s.parent.logger.(logger).Print(v...)
}
Expand Down
61 changes: 61 additions & 0 deletions scope.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package gorm

import "github.com/jinzhu/gorm/dialect"

type Scope struct {
Search *search
Sql string
SqlVars []interface{}
db *DB
}

func (scope *Scope) DB() sqlCommon {
return scope.db.db
}

func (scope *Scope) Dialect() dialect.Dialect {
return scope.db.parent.dialect
}

func (scope *Scope) Err(err error) error {
if err != nil {
scope.db.err(err)
}
return err
}

func (scope *Scope) HasError() bool {
return true
}

func (scope *Scope) PrimaryKey() string {
return ""
}

func (scope *Scope) HasColumn(name string) bool {
return false
}

func (scope *Scope) SetColumn(column string, value interface{}) {
}

func (scope *Scope) CallMethod(name string) {
}

func (scope *Scope) CombinedConditionSql() string {
return ""
}

func (scope *Scope) AddToVars(value interface{}) string {
return ""
}

func (scope *Scope) TableName() string {
return ""
}

func (scope *Scope) Raw(sql string, values ...interface{}) {
}

func (scope *Scope) Exec() {
}
14 changes: 14 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ package gorm
import (
"bytes"
"database/sql"
"fmt"
"os"
"reflect"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -86,6 +90,16 @@ func toSearchableMap(attrs ...interface{}) (result interface{}) {
return
}

func fileWithLineNum() string {
for i := 1; i < 15; i++ {
_, file, line, ok := runtime.Caller(i)
if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) {
return fmt.Sprintf("%v:%v", strings.TrimPrefix(file, os.Getenv("GOPATH")+"src/"), line)
}
}
return ""
}

func setFieldValue(field reflect.Value, value interface{}) bool {
if field.IsValid() && field.CanAddr() {
switch field.Kind() {
Expand Down

0 comments on commit 973acd6

Please sign in to comment.