Skip to content

Commit

Permalink
Tweak option names
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed May 8, 2021
1 parent 80d8a21 commit 9b4a0bf
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 22 deletions.
90 changes: 70 additions & 20 deletions fixture/fixture.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,39 @@ import (
"gopkg.in/yaml.v3"
)

type ConfigOption func(f *Loader)
type LoaderOption func(f *Loader)

func WithTruncate() ConfigOption {
func WithDrop() LoaderOption {
return func(f *Loader) {
f.truncateTable = true
f.truncatedTables = make(map[string]struct{})
if f.truncateTables {
panic("don't use WithDrop together with WithTruncate")
}
f.dropTables = true
f.seenTables = make(map[string]struct{})
}
}

func WithTruncate() LoaderOption {
return func(f *Loader) {
if f.truncateTables {
panic("don't use WithTruncate together with WithDrop")
}
f.truncateTables = true
f.seenTables = make(map[string]struct{})
}
}

type Loader struct {
db *bun.DB

truncateTable bool
truncatedTables map[string]struct{}
dropTables bool
truncateTables bool
seenTables map[string]struct{}

modelRows map[string]map[string]interface{}
}

func NewLoader(db *bun.DB, opts ...ConfigOption) *Loader {
func NewLoader(db *bun.DB, opts ...LoaderOption) *Loader {
f := &Loader{
db: db,

Expand All @@ -53,7 +67,7 @@ func (f *Loader) Get(model, rowID string) (interface{}, error) {

row, ok := rows[rowID]
if !ok {
return nil, fmt.Errorf("fixture: unknown row=%q in model=%q", row, model)
return nil, fmt.Errorf("fixture: unknown row=%q for model=%q", row, model)
}

return row, nil
Expand Down Expand Up @@ -101,7 +115,17 @@ func (f *Loader) load(ctx context.Context, fsys fs.FS, name string) error {
func (f *Loader) addFixture(ctx context.Context, fixture *Fixture) error {
table := f.db.Dialect().Tables().ByModel(fixture.Model)
if table == nil {
return fmt.Errorf("fixture: can't find model=%q", fixture.Model)
return fmt.Errorf("fixture: can't find table=%q (use db.RegisterTable)", fixture.Model)
}

if f.dropTables {
if err := f.dropTable(ctx, table); err != nil {
return err
}
} else if f.truncateTables {
if err := f.truncateTable(ctx, table); err != nil {
return err
}
}

for _, row := range fixture.Rows {
Expand Down Expand Up @@ -150,17 +174,6 @@ func (f *Loader) addRow(ctx context.Context, table *schema.Table, row row) error
}
}

if f.truncateTable {
if _, ok := f.truncatedTables[table.Name]; !ok {
if _, err := f.db.NewTruncateTable().
TableExpr(string(table.SQLName)).
Exec(ctx); err != nil {
return err
}
f.truncatedTables[table.Name] = struct{}{}
}
}

iface := strct.Addr().Interface()
if _, err := f.db.NewInsert().
Model(iface).
Expand All @@ -186,6 +199,43 @@ func (f *Loader) addRow(ctx context.Context, table *schema.Table, row row) error
return nil
}

func (f *Loader) dropTable(ctx context.Context, table *schema.Table) error {
if _, ok := f.seenTables[table.Name]; ok {
return nil
}
f.seenTables[table.Name] = struct{}{}

if _, err := f.db.NewDropTable().
Model(table.ZeroIface).
IfExists().
Exec(ctx); err != nil {
return err
}

if _, err := f.db.NewCreateTable().
Model(table.ZeroIface).
Exec(ctx); err != nil {
return err
}

return nil
}

func (f *Loader) truncateTable(ctx context.Context, table *schema.Table) error {
if _, ok := f.seenTables[table.Name]; ok {
return nil
}
f.seenTables[table.Name] = struct{}{}

if _, err := f.db.NewTruncateTable().
Model(table.ZeroIface).
Exec(ctx); err != nil {
return err
}

return nil
}

func (f *Loader) eval(templ string) (string, error) {
tpl, err := template.New("").Parse(templ)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/dbtest/orm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func createTestSchema(t *testing.T, db *bun.DB) {
}

func loadTestData(t *testing.T, db *bun.DB) {
loader := fixture.NewLoader(db, fixture.WithTruncate())
loader := fixture.NewLoader(db)
err := loader.Load(context.TODO(), os.DirFS("testdata"), "fixture.yaml")
require.NoError(t, err)
}
2 changes: 1 addition & 1 deletion migrate/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type Command struct {

type MigratorOption func(m *Migrator)

func AutoDiscover() MigratorOption {
func WithAutoDiscover() MigratorOption {
return func(m *Migrator) {
m.autoDiscover = true
}
Expand Down
5 changes: 5 additions & 0 deletions query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,11 @@ func (q *SelectQuery) joinOn(cond string, args []interface{}, sep string) *Selec
// - RelationName.column_name,
// - RelationName._ to join relation without selecting relation columns.
func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQuery) *SelectQuery {
if q.tableModel == nil {
q.setErr(errModelNil)
return q
}

var fn func(*SelectQuery) *SelectQuery

if len(apply) == 1 {
Expand Down

0 comments on commit 9b4a0bf

Please sign in to comment.