Skip to content

Commit

Permalink
Add MarkDone and MarkUndone (uptrace#70)
Browse files Browse the repository at this point in the history
* Add MarkDone and MarkUndone

* Rename WithMigrationDryRun to WithoutMigrationFunc

* Add MigrationsWithStatus

* Rename WithoutMigrationFunc to WithNopMigration
  • Loading branch information
vmihailenco authored Jul 12, 2021
1 parent b33abf8 commit 1287e6f
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 165 deletions.
26 changes: 13 additions & 13 deletions example/migrate/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ func newDBCommand(migrator *migrate.Migrator) *cli.Command {
if err != nil {
return err
}
if group.ID == 0 {
fmt.Printf("there are no new migrations to run\n")
if group.IsZero() {
fmt.Printf("there are no new migrations to run (database is up to date)\n")
return nil
}
fmt.Printf("migrated to %s\n", group)
Expand All @@ -74,7 +74,7 @@ func newDBCommand(migrator *migrate.Migrator) *cli.Command {
if err != nil {
return err
}
if group.ID == 0 {
if group.IsZero() {
fmt.Printf("there are no groups to roll back\n")
return nil
}
Expand Down Expand Up @@ -124,29 +124,29 @@ func newDBCommand(migrator *migrate.Migrator) *cli.Command {
Name: "status",
Usage: "print migrations status",
Action: func(c *cli.Context) error {
status, err := migrator.Status(c.Context)
ms, err := migrator.MigrationsWithStatus(c.Context)
if err != nil {
return err
}
fmt.Printf("migrations: %s\n", status.Migrations)
fmt.Printf("new migrations: %s\n", status.NewMigrations)
fmt.Printf("last group: %s\n", status.LastGroup)
fmt.Printf("migrations: %s\n", ms)
fmt.Printf("unapplied migrations: %s\n", ms.Unapplied())
fmt.Printf("last migration group: %s\n", ms.LastGroup())
return nil
},
},
{
Name: "mark_completed",
Usage: "mark migrations as completed without actually running them",
Name: "mark_applied",
Usage: "mark migrations as applied without actually running them",
Action: func(c *cli.Context) error {
group, err := migrator.MarkCompleted(c.Context)
group, err := migrator.Migrate(c.Context, migrate.WithNopMigration())
if err != nil {
return err
}
if group.ID == 0 {
fmt.Printf("there are no new migrations to mark as completed\n")
if group.IsZero() {
fmt.Printf("there are no new migrations to mark as applied\n")
return nil
}
fmt.Printf("marked as completed %s\n", group)
fmt.Printf("marked as applied %s\n", group)
return nil
},
},
Expand Down
102 changes: 101 additions & 1 deletion migrate/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"fmt"
"io/fs"
"sort"
"strings"
"time"

Expand All @@ -28,6 +29,10 @@ func (m *Migration) String() string {
return m.Name
}

func (m *Migration) IsApplied() bool {
return m.ID > 0
}

type MigrationFunc func(ctx context.Context, db *bun.DB) error

func NewSQLMigrationFunc(fsys fs.FS, name string) MigrationFunc {
Expand Down Expand Up @@ -136,13 +141,72 @@ func (ms MigrationSlice) String() string {
return sb.String()
}

// Applied returns applied migrations in descending order
// (the order is important and is used in Rollback).
func (ms MigrationSlice) Applied() MigrationSlice {
var applied MigrationSlice
for i := range ms {
if ms[i].IsApplied() {
applied = append(applied, ms[i])
}
}
sortDesc(applied)
return applied
}

// Unapplied returns unapplied migrations in ascending order
// (the order is important and is used in Migrate).
func (ms MigrationSlice) Unapplied() MigrationSlice {
var unapplied MigrationSlice
for i := range ms {
if !ms[i].IsApplied() {
unapplied = append(unapplied, ms[i])
}
}
sortAsc(unapplied)
return unapplied
}

// LastGroupID returns the last applied migration group id.
// The id is 0 when there are no migration groups.
func (ms MigrationSlice) LastGroupID() int64 {
var lastGroupID int64
for i := range ms {
groupID := ms[i].GroupID
if groupID != 0 && groupID > lastGroupID {
lastGroupID = groupID
}
}
return lastGroupID
}

// LastGroup returns the last applied migration group.
func (ms MigrationSlice) LastGroup() *MigrationGroup {
group := &MigrationGroup{
ID: ms.LastGroupID(),
}
if group.ID == 0 {
return group
}
for i := range ms {
if ms[i].GroupID == group.ID {
group.Migrations = append(group.Migrations, ms[i])
}
}
return group
}

type MigrationGroup struct {
ID int64
Migrations MigrationSlice
}

func (g *MigrationGroup) IsZero() bool {
return g.ID == 0 && len(g.Migrations) == 0
}

func (g *MigrationGroup) String() string {
if g.ID == 0 && len(g.Migrations) == 0 {
if g.IsZero() {
return "nil"
}
return fmt.Sprintf("group #%d (%s)", g.ID, g.Migrations)
Expand All @@ -153,3 +217,39 @@ type MigrationFile struct {
FilePath string
Content string
}

//------------------------------------------------------------------------------

type migrationConfig struct {
nop bool
}

func newMigrationConfig(opts []MigrationOption) *migrationConfig {
cfg := new(migrationConfig)
for _, opt := range opts {
opt(cfg)
}
return cfg
}

type MigrationOption func(cfg *migrationConfig)

func WithNopMigration() MigrationOption {
return func(cfg *migrationConfig) {
cfg.nop = true
}
}

//------------------------------------------------------------------------------

func sortAsc(ms MigrationSlice) {
sort.Slice(ms, func(i, j int) bool {
return ms[i].Name < ms[j].Name
})
}

func sortDesc(ms MigrationSlice) {
sort.Slice(ms, func(i, j int) bool {
return ms[i].Name > ms[j].Name
})
}
7 changes: 1 addition & 6 deletions migrate/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"path/filepath"
"regexp"
"runtime"
"sort"
"strings"
)

Expand Down Expand Up @@ -37,11 +36,7 @@ func NewMigrations(opts ...MigrationsOption) *Migrations {
func (m *Migrations) Sorted() MigrationSlice {
migrations := make(MigrationSlice, len(m.ms))
copy(migrations, m.ms)

sort.Slice(migrations, func(i, j int) bool {
return migrations[i].Name < migrations[j].Name
})

sortAsc(migrations)
return migrations
}

Expand Down
Loading

0 comments on commit 1287e6f

Please sign in to comment.