From 32916a97aa2fc7fb57f77a78df2202560d78b2e4 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Fri, 7 May 2021 09:47:02 +0300 Subject: [PATCH] Restore hooks --- bun.go | 50 +++++---- internal/dbtest/model_hook_test.go | 164 ++++++++++++++--------------- model_hook.go | 50 ++++----- model_table_slice.go | 12 +-- model_table_struct.go | 12 +-- query_delete.go | 34 ++++-- query_insert.go | 34 ++++-- query_select.go | 31 ++++-- query_update.go | 34 ++++-- schema/hook.go | 6 +- 10 files changed, 232 insertions(+), 195 deletions(-) diff --git a/bun.go b/bun.go index 1962e604e..5994e8757 100644 --- a/bun.go +++ b/bun.go @@ -1,8 +1,6 @@ package bun import ( - "context" - "github.com/uptrace/bun/schema" "github.com/uptrace/bun/sqlfmt" ) @@ -28,36 +26,36 @@ type ( AfterDeleteHook = schema.AfterDeleteHook ) -type BeforeSelectQueryHook interface { - BeforeSelectQuery(ctx context.Context, query *SelectQuery) error -} +// type BeforeSelectQueryHook interface { +// BeforeSelectQuery(ctx context.Context, query *SelectQuery) error +// } -type AfterSelectQueryHook interface { - AfterSelectQuery(ctx context.Context, query *SelectQuery) error -} +// type AfterSelectQueryHook interface { +// AfterSelectQuery(ctx context.Context, query *SelectQuery) error +// } -type BeforeInsertQueryHook interface { - BeforeInsertQuery(ctx context.Context, query *InsertQuery) error -} +// type BeforeInsertQueryHook interface { +// BeforeInsertQuery(ctx context.Context, query *InsertQuery) error +// } -type AfterInsertQueryHook interface { - AfterInsertQuery(ctx context.Context, query *InsertQuery) error -} +// type AfterInsertQueryHook interface { +// AfterInsertQuery(ctx context.Context, query *InsertQuery) error +// } -type BeforeUpdateQueryHook interface { - BeforeUpdateQuery(ctx context.Context, query *UpdateQuery) error -} +// type BeforeUpdateQueryHook interface { +// BeforeUpdateQuery(ctx context.Context, query *UpdateQuery) error +// } -type AfterUpdateQueryHook interface { - AfterUpdateQuery(ctx context.Context, query *UpdateQuery) error -} +// type AfterUpdateQueryHook interface { +// AfterUpdateQuery(ctx context.Context, query *UpdateQuery) error +// } -type BeforeDeleteQueryHook interface { - BeforeDeleteQuery(ctx context.Context, query *DeleteQuery) error -} +// type BeforeDeleteQueryHook interface { +// BeforeDeleteQuery(ctx context.Context, query *DeleteQuery) error +// } -type AfterDeleteQueryHook interface { - AfterDeleteQuery(ctx context.Context, query *DeleteQuery) error -} +// type AfterDeleteQueryHook interface { +// AfterDeleteQuery(ctx context.Context, query *DeleteQuery) error +// } type BaseTable struct{} diff --git a/internal/dbtest/model_hook_test.go b/internal/dbtest/model_hook_test.go index aa007d8e8..0e02f1c85 100644 --- a/internal/dbtest/model_hook_test.go +++ b/internal/dbtest/model_hook_test.go @@ -53,7 +53,7 @@ func testModelHook(t *testing.T, db *bun.DB) { hook := &ModelHookTest{ID: 1} _, err := db.NewInsert().Model(hook).Exec(ctx) require.NoError(t, err) - require.Equal(t, []string{"BeforeInsertQuery", "AfterInsertQuery"}, events.Flush()) + require.Equal(t, []string{"BeforeInsert", "AfterInsert"}, events.Flush()) } { @@ -61,10 +61,9 @@ func testModelHook(t *testing.T, db *bun.DB) { err := db.NewSelect().Model(hook).Scan(ctx) require.NoError(t, err) require.Equal(t, []string{ - "BeforeSelectQuery", "BeforeScan", "AfterScan", - "AfterSelectQuery", + "AfterSelect", }, events.Flush()) } @@ -73,10 +72,9 @@ func testModelHook(t *testing.T, db *bun.DB) { err := db.NewSelect().Model(&hooks).Scan(ctx) require.NoError(t, err) require.Equal(t, []string{ - "BeforeSelectQuery", "BeforeScan", "AfterScan", - "AfterSelectQuery", + "AfterSelect", }, events.Flush()) } @@ -84,20 +82,20 @@ func testModelHook(t *testing.T, db *bun.DB) { hook := &ModelHookTest{ID: 1} _, err := db.NewUpdate().Model(hook).WherePK().Exec(ctx) require.NoError(t, err) - require.Equal(t, []string{"BeforeUpdateQuery", "AfterUpdateQuery"}, events.Flush()) + require.Equal(t, []string{"BeforeUpdate", "AfterUpdate"}, events.Flush()) } { hook := &ModelHookTest{ID: 1} _, err := db.NewDelete().Model(hook).WherePK().Exec(ctx) require.NoError(t, err) - require.Equal(t, []string{"BeforeDeleteQuery", "AfterDeleteQuery"}, events.Flush()) + require.Equal(t, []string{"BeforeDelete", "AfterDelete"}, events.Flush()) } { _, err := db.NewDelete().Model((*ModelHookTest)(nil)).Where("TRUE").Exec(ctx) require.NoError(t, err) - require.Equal(t, []string{"BeforeDeleteQuery", "AfterDeleteQuery"}, events.Flush()) + require.Nil(t, events.Flush()) } } @@ -120,107 +118,107 @@ func (t *ModelHookTest) AfterScan(c context.Context) error { return nil } -var _ bun.BeforeSelectQueryHook = (*ModelHookTest)(nil) +// var _ bun.BeforeSelectQueryHook = (*ModelHookTest)(nil) -func (t *ModelHookTest) BeforeSelectQuery(ctx context.Context, query *bun.SelectQuery) error { - events.Add("BeforeSelectQuery") - return nil -} - -var _ bun.AfterSelectQueryHook = (*ModelHookTest)(nil) - -func (t *ModelHookTest) AfterSelectQuery(ctx context.Context, query *bun.SelectQuery) error { - events.Add("AfterSelectQuery") - return nil -} +// func (t *ModelHookTest) BeforeSelectQuery(ctx context.Context, query *bun.SelectQuery) error { +// events.Add("BeforeSelectQuery") +// return nil +// } -var _ bun.BeforeUpdateQueryHook = (*ModelHookTest)(nil) +// var _ bun.AfterSelectQueryHook = (*ModelHookTest)(nil) -func (t *ModelHookTest) BeforeUpdateQuery(ctx context.Context, query *bun.UpdateQuery) error { - events.Add("BeforeUpdateQuery") - return nil -} +// func (t *ModelHookTest) AfterSelectQuery(ctx context.Context, query *bun.SelectQuery) error { +// events.Add("AfterSelectQuery") +// return nil +// } -var _ bun.AfterUpdateQueryHook = (*ModelHookTest)(nil) +// var _ bun.BeforeUpdateQueryHook = (*ModelHookTest)(nil) -func (t *ModelHookTest) AfterUpdateQuery(ctx context.Context, query *bun.UpdateQuery) error { - events.Add("AfterUpdateQuery") - return nil -} +// func (t *ModelHookTest) BeforeUpdateQuery(ctx context.Context, query *bun.UpdateQuery) error { +// events.Add("BeforeUpdateQuery") +// return nil +// } -var _ bun.BeforeInsertQueryHook = (*ModelHookTest)(nil) +// var _ bun.AfterUpdateQueryHook = (*ModelHookTest)(nil) -func (t *ModelHookTest) BeforeInsertQuery(ctx context.Context, query *bun.InsertQuery) error { - events.Add("BeforeInsertQuery") - return nil -} +// func (t *ModelHookTest) AfterUpdateQuery(ctx context.Context, query *bun.UpdateQuery) error { +// events.Add("AfterUpdateQuery") +// return nil +// } -var _ bun.AfterInsertQueryHook = (*ModelHookTest)(nil) +// var _ bun.BeforeInsertQueryHook = (*ModelHookTest)(nil) -func (t *ModelHookTest) AfterInsertQuery(ctx context.Context, query *bun.InsertQuery) error { - events.Add("AfterInsertQuery") - return nil -} +// func (t *ModelHookTest) BeforeInsertQuery(ctx context.Context, query *bun.InsertQuery) error { +// events.Add("BeforeInsertQuery") +// return nil +// } -var _ bun.BeforeDeleteQueryHook = (*ModelHookTest)(nil) +// var _ bun.AfterInsertQueryHook = (*ModelHookTest)(nil) -func (t *ModelHookTest) BeforeDeleteQuery(ctx context.Context, query *bun.DeleteQuery) error { - events.Add("BeforeDeleteQuery") - return nil -} +// func (t *ModelHookTest) AfterInsertQuery(ctx context.Context, query *bun.InsertQuery) error { +// events.Add("AfterInsertQuery") +// return nil +// } -var _ bun.AfterDeleteQueryHook = (*ModelHookTest)(nil) +// var _ bun.BeforeDeleteQueryHook = (*ModelHookTest)(nil) -func (t *ModelHookTest) AfterDeleteQuery(ctx context.Context, query *bun.DeleteQuery) error { - events.Add("AfterDeleteQuery") - return nil -} +// func (t *ModelHookTest) BeforeDeleteQuery(ctx context.Context, query *bun.DeleteQuery) error { +// events.Add("BeforeDeleteQuery") +// return nil +// } -// var _ bun.AfterSelectHook = (*ModelHookTest)(nil) +// var _ bun.AfterDeleteQueryHook = (*ModelHookTest)(nil) -// func (t *ModelHookTest) AfterSelect(c context.Context) error { -// t.events = append(t.events, "AfterSelect") +// func (t *ModelHookTest) AfterDeleteQuery(ctx context.Context, query *bun.DeleteQuery) error { +// events.Add("AfterDeleteQuery") // return nil // } -// var _ bun.BeforeInsertHook = (*ModelHookTest)(nil) +var _ bun.AfterSelectHook = (*ModelHookTest)(nil) -// func (t *ModelHookTest) BeforeInsert(c context.Context) (context.Context, error) { -// t.events = append(t.events, "BeforeInsert") -// return c, nil -// } +func (t *ModelHookTest) AfterSelect(c context.Context) error { + events.Add("AfterSelect") + return nil +} -// var _ bun.AfterInsertHook = (*ModelHookTest)(nil) +var _ bun.BeforeInsertHook = (*ModelHookTest)(nil) -// func (t *ModelHookTest) AfterInsert(c context.Context) error { -// t.events = append(t.events, "AfterInsert") -// return nil -// } +func (t *ModelHookTest) BeforeInsert(ctx context.Context) error { + events.Add("BeforeInsert") + return nil +} -// var _ bun.BeforeUpdateHook = (*ModelHookTest)(nil) +var _ bun.AfterInsertHook = (*ModelHookTest)(nil) -// func (t *ModelHookTest) BeforeUpdate(c context.Context) (context.Context, error) { -// t.events = append(t.events, "BeforeUpdate") -// return c, nil -// } +func (t *ModelHookTest) AfterInsert(c context.Context) error { + events.Add("AfterInsert") + return nil +} -// var _ bun.AfterUpdateHook = (*ModelHookTest)(nil) +var _ bun.BeforeUpdateHook = (*ModelHookTest)(nil) -// func (t *ModelHookTest) AfterUpdate(c context.Context) error { -// t.events = append(t.events, "AfterUpdate") -// return nil -// } +func (t *ModelHookTest) BeforeUpdate(ctx context.Context) error { + events.Add("BeforeUpdate") + return nil +} -// var _ bun.BeforeDeleteHook = (*ModelHookTest)(nil) +var _ bun.AfterUpdateHook = (*ModelHookTest)(nil) -// func (t *ModelHookTest) BeforeDelete(c context.Context) (context.Context, error) { -// t.events = append(t.events, "BeforeDelete") -// return c, nil -// } +func (t *ModelHookTest) AfterUpdate(c context.Context) error { + events.Add("AfterUpdate") + return nil +} -// var _ bun.AfterDeleteHook = (*ModelHookTest)(nil) +var _ bun.BeforeDeleteHook = (*ModelHookTest)(nil) -// func (t *ModelHookTest) AfterDelete(c context.Context) error { -// t.events = append(t.events, "AfterDelete") -// return nil -// } +func (t *ModelHookTest) BeforeDelete(ctx context.Context) error { + events.Add("BeforeDelete") + return nil +} + +var _ bun.AfterDeleteHook = (*ModelHookTest)(nil) + +func (t *ModelHookTest) AfterDelete(c context.Context) error { + events.Add("AfterDelete") + return nil +} diff --git a/model_hook.go b/model_hook.go index dabd61296..4182167b1 100644 --- a/model_hook.go +++ b/model_hook.go @@ -19,20 +19,20 @@ var ( _ AfterDeleteHook = (*hookStubs)(nil) ) -func (hookStubs) AfterSelect(ctx context.Context) error { return nil } -func (hookStubs) BeforeInsert(ctx context.Context) (context.Context, error) { return ctx, nil } -func (hookStubs) AfterInsert(ctx context.Context) error { return nil } -func (hookStubs) BeforeUpdate(ctx context.Context) (context.Context, error) { return ctx, nil } -func (hookStubs) AfterUpdate(ctx context.Context) error { return nil } -func (hookStubs) BeforeDelete(ctx context.Context) (context.Context, error) { return ctx, nil } -func (hookStubs) AfterDelete(ctx context.Context) error { return nil } +func (hookStubs) AfterSelect(ctx context.Context) error { return nil } +func (hookStubs) BeforeInsert(ctx context.Context) error { return nil } +func (hookStubs) AfterInsert(ctx context.Context) error { return nil } +func (hookStubs) BeforeUpdate(ctx context.Context) error { return nil } +func (hookStubs) AfterUpdate(ctx context.Context) error { return nil } +func (hookStubs) BeforeDelete(ctx context.Context) error { return nil } +func (hookStubs) AfterDelete(ctx context.Context) error { return nil } func callHookSlice( ctx context.Context, slice reflect.Value, ptr bool, - hook func(context.Context, reflect.Value) (context.Context, error), -) (context.Context, error) { + hook func(context.Context, reflect.Value) error, +) error { var firstErr error sliceLen := slice.Len() for i := 0; i < sliceLen; i++ { @@ -41,13 +41,11 @@ func callHookSlice( v = v.Addr() } - var err error - ctx, err = hook(ctx, v) - if err != nil && firstErr == nil { + if err := hook(ctx, v); err != nil && firstErr == nil { firstErr = err } } - return ctx, firstErr + return firstErr } func callHookSlice2( @@ -94,13 +92,11 @@ func callAfterSelectHookSlice( return callHookSlice2(ctx, slice, ptr, callAfterSelectHook) } -func callBeforeInsertHook(ctx context.Context, v reflect.Value) (context.Context, error) { +func callBeforeInsertHook(ctx context.Context, v reflect.Value) error { return v.Interface().(schema.BeforeInsertHook).BeforeInsert(ctx) } -func callBeforeInsertHookSlice( - ctx context.Context, slice reflect.Value, ptr bool, -) (context.Context, error) { +func callBeforeInsertHookSlice(ctx context.Context, slice reflect.Value, ptr bool) error { return callHookSlice(ctx, slice, ptr, callBeforeInsertHook) } @@ -108,19 +104,15 @@ func callAfterInsertHook(ctx context.Context, v reflect.Value) error { return v.Interface().(schema.AfterInsertHook).AfterInsert(ctx) } -func callAfterInsertHookSlice( - ctx context.Context, slice reflect.Value, ptr bool, -) error { +func callAfterInsertHookSlice(ctx context.Context, slice reflect.Value, ptr bool) error { return callHookSlice2(ctx, slice, ptr, callAfterInsertHook) } -func callBeforeUpdateHook(ctx context.Context, v reflect.Value) (context.Context, error) { +func callBeforeUpdateHook(ctx context.Context, v reflect.Value) error { return v.Interface().(schema.BeforeUpdateHook).BeforeUpdate(ctx) } -func callBeforeUpdateHookSlice( - ctx context.Context, slice reflect.Value, ptr bool, -) (context.Context, error) { +func callBeforeUpdateHookSlice(ctx context.Context, slice reflect.Value, ptr bool) error { return callHookSlice(ctx, slice, ptr, callBeforeUpdateHook) } @@ -128,19 +120,15 @@ func callAfterUpdateHook(ctx context.Context, v reflect.Value) error { return v.Interface().(schema.AfterUpdateHook).AfterUpdate(ctx) } -func callAfterUpdateHookSlice( - ctx context.Context, slice reflect.Value, ptr bool, -) error { +func callAfterUpdateHookSlice(ctx context.Context, slice reflect.Value, ptr bool) error { return callHookSlice2(ctx, slice, ptr, callAfterUpdateHook) } -func callBeforeDeleteHook(ctx context.Context, v reflect.Value) (context.Context, error) { +func callBeforeDeleteHook(ctx context.Context, v reflect.Value) error { return v.Interface().(schema.BeforeDeleteHook).BeforeDelete(ctx) } -func callBeforeDeleteHookSlice( - ctx context.Context, slice reflect.Value, ptr bool, -) (context.Context, error) { +func callBeforeDeleteHookSlice(ctx context.Context, slice reflect.Value, ptr bool) error { return callHookSlice(ctx, slice, ptr, callBeforeDeleteHook) } diff --git a/model_table_slice.go b/model_table_slice.go index 3ae8ee2eb..d06c86045 100644 --- a/model_table_slice.go +++ b/model_table_slice.go @@ -106,11 +106,11 @@ func (m *sliceTableModel) AfterSelect(ctx context.Context) error { return nil } -func (m *sliceTableModel) BeforeInsert(ctx context.Context) (context.Context, error) { +func (m *sliceTableModel) BeforeInsert(ctx context.Context) error { if m.table.HasBeforeInsertHook() { return callBeforeInsertHookSlice(ctx, m.slice, m.sliceOfPtr) } - return ctx, nil + return nil } func (m *sliceTableModel) AfterInsert(ctx context.Context) error { @@ -120,11 +120,11 @@ func (m *sliceTableModel) AfterInsert(ctx context.Context) error { return nil } -func (m *sliceTableModel) BeforeUpdate(ctx context.Context) (context.Context, error) { +func (m *sliceTableModel) BeforeUpdate(ctx context.Context) error { if m.table.HasBeforeUpdateHook() && !m.IsNil() { return callBeforeUpdateHookSlice(ctx, m.slice, m.sliceOfPtr) } - return ctx, nil + return nil } func (m *sliceTableModel) AfterUpdate(ctx context.Context) error { @@ -134,11 +134,11 @@ func (m *sliceTableModel) AfterUpdate(ctx context.Context) error { return nil } -func (m *sliceTableModel) BeforeDelete(ctx context.Context) (context.Context, error) { +func (m *sliceTableModel) BeforeDelete(ctx context.Context) error { if m.table.HasBeforeDeleteHook() && !m.IsNil() { return callBeforeDeleteHookSlice(ctx, m.slice, m.sliceOfPtr) } - return ctx, nil + return nil } func (m *sliceTableModel) AfterDelete(ctx context.Context) error { diff --git a/model_table_struct.go b/model_table_struct.go index 1a0c35169..065247fb5 100644 --- a/model_table_struct.go +++ b/model_table_struct.go @@ -165,11 +165,11 @@ func (m *structTableModel) AfterSelect(ctx context.Context) error { return nil } -func (m *structTableModel) BeforeInsert(ctx context.Context) (context.Context, error) { +func (m *structTableModel) BeforeInsert(ctx context.Context) error { if m.table.HasBeforeInsertHook() { return callBeforeInsertHook(ctx, m.strct.Addr()) } - return ctx, nil + return nil } func (m *structTableModel) AfterInsert(ctx context.Context) error { @@ -179,11 +179,11 @@ func (m *structTableModel) AfterInsert(ctx context.Context) error { return nil } -func (m *structTableModel) BeforeUpdate(ctx context.Context) (context.Context, error) { +func (m *structTableModel) BeforeUpdate(ctx context.Context) error { if m.table.HasBeforeUpdateHook() && !m.IsNil() { return callBeforeUpdateHook(ctx, m.strct.Addr()) } - return ctx, nil + return nil } func (m *structTableModel) AfterUpdate(ctx context.Context) error { @@ -193,11 +193,11 @@ func (m *structTableModel) AfterUpdate(ctx context.Context) error { return nil } -func (m *structTableModel) BeforeDelete(ctx context.Context) (context.Context, error) { +func (m *structTableModel) BeforeDelete(ctx context.Context) error { if m.table.HasBeforeDeleteHook() && !m.IsNil() { return callBeforeDeleteHook(ctx, m.strct.Addr()) } - return ctx, nil + return nil } func (m *structTableModel) AfterDelete(ctx context.Context) error { diff --git a/query_delete.go b/query_delete.go index a00c0de42..77bc8bcb1 100644 --- a/query_delete.go +++ b/query_delete.go @@ -207,23 +207,37 @@ func (q *DeleteQuery) ForceDelete(ctx context.Context, dest ...interface{}) (res } func (q *DeleteQuery) beforeDeleteQueryHook(ctx context.Context) error { - if q.table == nil { + if q.tableModel == nil { return nil } - hook, ok := q.table.ZeroIface.(BeforeDeleteQueryHook) - if !ok { - return nil + + if err := q.tableModel.BeforeDelete(ctx); err != nil { + return err } - return hook.BeforeDeleteQuery(ctx, q) + + // if hook, ok := q.table.ZeroIface.(BeforeDeleteQueryHook); ok { + // if err := hook.BeforeDeleteQuery(ctx, q); err != nil { + // return err + // } + // } + + return nil } func (q *DeleteQuery) afterDeleteQueryHook(ctx context.Context) error { - if q.table == nil { + if q.tableModel == nil { return nil } - hook, ok := q.table.ZeroIface.(AfterDeleteQueryHook) - if !ok { - return nil + + if err := q.tableModel.AfterDelete(ctx); err != nil { + return err } - return hook.AfterDeleteQuery(ctx, q) + + // if hook, ok := q.table.ZeroIface.(AfterDeleteQueryHook); ok { + // if err := hook.AfterDeleteQuery(ctx, q); err != nil { + // return err + // } + // } + + return nil } diff --git a/query_insert.go b/query_insert.go index 6bf501540..f609d9556 100644 --- a/query_insert.go +++ b/query_insert.go @@ -464,25 +464,39 @@ func (q *InsertQuery) Exec(ctx context.Context, dest ...interface{}) (res Result } func (q *InsertQuery) beforeInsertQueryHook(ctx context.Context) error { - if q.table == nil { + if q.tableModel == nil { return nil } - hook, ok := q.table.ZeroIface.(BeforeInsertQueryHook) - if !ok { - return nil + + if err := q.tableModel.BeforeInsert(ctx); err != nil { + return err } - return hook.BeforeInsertQuery(ctx, q) + + // if hook, ok := q.table.ZeroIface.(BeforeInsertQueryHook); ok { + // if err := hook.BeforeInsertQuery(ctx, q); err != nil { + // return err + // } + // } + + return nil } func (q *InsertQuery) afterInsertQueryHook(ctx context.Context) error { - if q.table == nil { + if q.tableModel == nil { return nil } - hook, ok := q.table.ZeroIface.(AfterInsertQueryHook) - if !ok { - return nil + + if err := q.tableModel.AfterInsert(ctx); err != nil { + return err } - return hook.AfterInsertQuery(ctx, q) + + // if hook, ok := q.table.ZeroIface.(AfterInsertQueryHook); ok { + // if err := hook.AfterInsertQuery(ctx, q); err != nil { + // return err + // } + // } + + return nil } func (q *InsertQuery) tryLastInsertID(res sql.Result, dest []interface{}) error { diff --git a/query_select.go b/query_select.go index 2d6ff64a1..1db929ee5 100644 --- a/query_select.go +++ b/query_select.go @@ -692,22 +692,33 @@ func (q *SelectQuery) beforeSelectQueryHook(ctx context.Context) error { if q.table == nil { return nil } - hook, ok := q.table.ZeroIface.(BeforeSelectQueryHook) - if !ok { - return nil - } - return hook.BeforeSelectQuery(ctx, q) + + // hook, ok := q.table.ZeroIface.(BeforeSelectQueryHook) + // if ok { + // if err := hook.BeforeSelectQuery(ctx, q); err != nil { + // return err + // } + // } + + return nil } func (q *SelectQuery) afterSelectQueryHook(ctx context.Context) error { - if q.table == nil { + if q.tableModel == nil { return nil } - hook, ok := q.table.ZeroIface.(AfterSelectQueryHook) - if !ok { - return nil + + if err := q.tableModel.AfterSelect(ctx); err != nil { + return err } - return hook.AfterSelectQuery(ctx, q) + + // if hook, ok := q.table.ZeroIface.(AfterSelectQueryHook); ok { + // if err := hook.AfterSelectQuery(ctx, q); err != nil { + // return err + // } + // } + + return nil } func (q *SelectQuery) Count(ctx context.Context) (int, error) { diff --git a/query_update.go b/query_update.go index a051f7128..1fbe58052 100644 --- a/query_update.go +++ b/query_update.go @@ -321,23 +321,37 @@ func (q *UpdateQuery) Exec(ctx context.Context, dest ...interface{}) (res Result } func (q *UpdateQuery) beforeUpdateQueryHook(ctx context.Context) error { - if q.table == nil { + if q.tableModel == nil { return nil } - hook, ok := q.table.ZeroIface.(BeforeUpdateQueryHook) - if !ok { - return nil + + if err := q.tableModel.BeforeUpdate(ctx); err != nil { + return err } - return hook.BeforeUpdateQuery(ctx, q) + + // if hook, ok := q.table.ZeroIface.(BeforeUpdateQueryHook); ok { + // if err := hook.BeforeUpdateQuery(ctx, q); err != nil { + // return err + // } + // } + + return nil } func (q *UpdateQuery) afterUpdateQueryHook(ctx context.Context) error { - if q.table == nil { + if q.tableModel == nil { return nil } - hook, ok := q.table.ZeroIface.(AfterUpdateQueryHook) - if !ok { - return nil + + if err := q.tableModel.AfterUpdate(ctx); err != nil { + return err } - return hook.AfterUpdateQuery(ctx, q) + + // if hook, ok := q.table.ZeroIface.(AfterUpdateQueryHook); ok { + // if err := hook.AfterUpdateQuery(ctx, q); err != nil { + // return err + // } + // } + + return nil } diff --git a/schema/hook.go b/schema/hook.go index a363fac13..f2e2f7cbe 100644 --- a/schema/hook.go +++ b/schema/hook.go @@ -30,7 +30,7 @@ var afterSelectHookType = reflect.TypeOf((*AfterSelectHook)(nil)).Elem() //------------------------------------------------------------------------------ type BeforeInsertHook interface { - BeforeInsert(context.Context) (context.Context, error) + BeforeInsert(context.Context) error } var beforeInsertHookType = reflect.TypeOf((*BeforeInsertHook)(nil)).Elem() @@ -46,7 +46,7 @@ var afterInsertHookType = reflect.TypeOf((*AfterInsertHook)(nil)).Elem() //------------------------------------------------------------------------------ type BeforeUpdateHook interface { - BeforeUpdate(context.Context) (context.Context, error) + BeforeUpdate(context.Context) error } var beforeUpdateHookType = reflect.TypeOf((*BeforeUpdateHook)(nil)).Elem() @@ -62,7 +62,7 @@ var afterUpdateHookType = reflect.TypeOf((*AfterUpdateHook)(nil)).Elem() //------------------------------------------------------------------------------ type BeforeDeleteHook interface { - BeforeDelete(context.Context) (context.Context, error) + BeforeDelete(context.Context) error } var beforeDeleteHookType = reflect.TypeOf((*BeforeDeleteHook)(nil)).Elem()