Skip to content

Commit

Permalink
feat: 集成starter初始化与关闭
Browse files Browse the repository at this point in the history
  • Loading branch information
lirui2 committed Jul 5, 2024
1 parent 9cb9292 commit 1574f6d
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 43 deletions.
6 changes: 3 additions & 3 deletions kit/db_builder/option.go → builder/common.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package db_builder
package builder

type DBDriver string

Expand All @@ -7,7 +7,7 @@ const (
DBDriverPostgresL = "postgres"
)

type Options struct {
type DBOptions struct {
Driver DBDriver
User string
Password string
Expand All @@ -21,7 +21,7 @@ type Options struct {
MaxLifetime int // 最大连接时长
}

func (o *Options) SetDefault() {
func (o *DBOptions) SetDefault() {
if o.MaxIdleConn <= 0 {
o.MaxIdleConn = 10
}
Expand Down
10 changes: 2 additions & 8 deletions kit/db_builder/creator.go → builder/db.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package db_builder
package builder

import (
"fmt"
sparrow "github.com/lrayt/small-sparrow"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"time"
)

func CreateGormDB(options *Options) (*gorm.DB, error) {
func CreateGormDB(options *DBOptions) (*gorm.DB, error) {
options.SetDefault() // 设置默认值
var dialect gorm.Dialector
switch options.Driver {
Expand Down Expand Up @@ -43,11 +42,6 @@ func CreateGormDB(options *Options) (*gorm.DB, error) {
if dbErr != nil {
return nil, dbErr
}
sparrow.GRunEnv()
// 测试连接
if err := db.Ping(); err != nil {
return nil, err
}
db.SetMaxIdleConns(options.MaxIdleConn) // 设置空闲连接池中连接的最大数量
db.SetMaxOpenConns(options.MaxOpenConn) // 设置打开数据库连接的最大数量。
db.SetConnMaxLifetime(time.Minute * time.Duration(options.MaxLifetime)) // 设置了连接可复用的最大时间。
Expand Down
3 changes: 1 addition & 2 deletions core/abstract/starter.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package abstract

type Starter interface {
Init() error // 初始化
Run() error // 运行
Init() error
Close() error // 关闭
}
23 changes: 19 additions & 4 deletions core/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"log"
"os"
"os/signal"
"reflect"
"syscall"
)

Expand Down Expand Up @@ -54,7 +55,7 @@ func WithConfigurator(provider abstract.ConfigProvider) Option {
}
}

func WithStarter(starters []abstract.Starter) Option {
func WithStarter(starters ...abstract.Starter) Option {
return func(app *Application) {
app.Starters = starters
}
Expand Down Expand Up @@ -95,7 +96,6 @@ func InitApp(appName, version string, options ...Option) error {
WithConfigurator(provider)(app)
}
}

// default logger
if app.LoggerProvider == nil {
if provider, err := log_manager.NewLocalFileLogProvider(app.Env); err != nil {
Expand All @@ -112,7 +112,13 @@ func SetupApp() {
errChan = make(chan error, 1)
signalChan = make(chan os.Signal, 1)
)

for _, starter := range app.Starters {
if err := starter.Init(); err != nil {
log.Fatalf("启动失败,err:%s\n", err.Error())
} else {
log.Printf("%s初始化成功\n", reflect.TypeOf(starter).String())
}
}
for _, provider := range app.Handlers {
if provider == nil {
continue
Expand All @@ -129,9 +135,18 @@ func SetupApp() {
case err := <-errChan:
log.Fatalf("服务启动异常,err:%v", err)
case <-signalChan:
// shutdown handler
for _, handler := range app.Handlers {
if err := handler.Shutdown(); err != nil {
log.Printf("shutdown server:%s\n", err.Error())
log.Printf("shutdown handler:%s\n", err.Error())
}
}
// close starter
for _, starter := range app.Starters {
if err := starter.Close(); err != nil {
log.Printf("%s close err: %s\n", reflect.TypeOf(starter).String(), err.Error())
} else {
log.Printf("%s closed\n", reflect.TypeOf(starter).String())
}
}
}
Expand Down
16 changes: 12 additions & 4 deletions core/global.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
package core

import "github.com/lrayt/small-sparrow/core/abstract"
import (
"github.com/lrayt/small-sparrow/core/abstract"
"github.com/lrayt/small-sparrow/core/runtime"
)

// GConfigs 全局配置
func GConfigs() abstract.ConfigProvider {
return app.ConfigProvider
}

// GRunEnv 运行环境
func GRunEnv() string {
return string(app.Env.RunEnv)
func IsProdEnv() bool {
return app.Env.RunEnv == runtime.RunProdEnv
}
func IsTestEnv() bool {
return app.Env.RunEnv == runtime.RunTestEnv
}
func IsLocalEnv() bool {
return app.Env.RunEnv == runtime.RunLocalEnv
}

func GWorkDir() {
Expand Down
54 changes: 42 additions & 12 deletions example/Internal/database/db_provider.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,53 @@
package database

import (
sparrow "github.com/lrayt/small-sparrow"
"github.com/lrayt/small-sparrow/kit/db_builder"
"context"
"database/sql"
"github.com/lrayt/small-sparrow/builder"
"github.com/lrayt/small-sparrow/core"
"gorm.io/gorm"
"time"
)

type DBProvider struct {
DB *gorm.DB
type DBManager struct {
GormDB *gorm.DB
}

func NewDBProvider() (*DBProvider, error) {
options := new(db_builder.Options)
if err := sparrow.GConfigs().PackConf("database.scp-db", options); err != nil {
return nil, err
func (p *DBManager) Init() error {
var (
err error
options = new(builder.DBOptions)
)
// cfg
err = core.GConfigs().PackConf("database.scp-db", options)
if err != nil {
return err
}
if db, err := db_builder.CreateGormDB(options); err != nil {
return nil, err
} else {
return &DBProvider{DB: db}, nil
// gorm db
p.GormDB, err = builder.CreateGormDB(options)
if err != nil {
return err
}
// sql db
var sqlDB *sql.DB
sqlDB, err = p.GormDB.DB()
if err != nil {
return err
}
// ping
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel()
return sqlDB.PingContext(ctx)
}

func (p DBManager) Close() error {
db, err := p.GormDB.DB()
if err != nil {
return err
}
return db.Close()
}

func NewDBManager() *DBManager {
return new(DBManager)
}
6 changes: 4 additions & 2 deletions example/cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"github.com/lrayt/small-sparrow/core"
"github.com/lrayt/small-sparrow/example/Internal/database"
"github.com/lrayt/small-sparrow/example/app/handler"
"log"
"path/filepath"
Expand All @@ -14,15 +15,16 @@ var (

type ExampleServer struct {
HttpHandler *handler.HttpHandler
dbm *database.DBManager
}

func NewExampleServer(httpHandler *handler.HttpHandler) (*ExampleServer, error) {
func NewExampleServer(httpHandler *handler.HttpHandler, dbm *database.DBManager) (*ExampleServer, error) {
rootPath, pathErr := filepath.Abs("")
if pathErr != nil {
log.Fatalf("获取项目工作路径失败,err:%s\n", pathErr.Error())
}
rootPath = filepath.Join(rootPath, "example")
if err := core.InitApp(AppName, Version, core.WithHandler(httpHandler), core.WithWorkerDir(rootPath)); err != nil {
if err := core.InitApp(AppName, Version, core.WithHandler(httpHandler), core.WithWorkerDir(rootPath), core.WithStarter(dbm)); err != nil {
return nil, err
}
return &ExampleServer{HttpHandler: httpHandler}, nil
Expand Down
13 changes: 6 additions & 7 deletions example/cmd/server/wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@ package main

import (
"github.com/google/wire"
"github.com/lrayt/small-sparrow/example/Internal/database"
"github.com/lrayt/small-sparrow/example/app/handler"
)

//var InternalProvider = wire.NewSet(
// database.NewDBProvider,
// database.NewCacheProvider,
// http_manager.NewGinHttpProvider,
// message.NewMQProvider,
//)
var InternalProvider = wire.NewSet(
database.NewDBManager,
)

//
//// DaoProvider 数据库操作
//var DaoProvider = wire.NewSet(
Expand Down Expand Up @@ -64,5 +63,5 @@ var HandlerProvider = wire.NewSet(
)

func InitExampleServer() (*ExampleServer, func(), error) {
panic(wire.Build(HandlerProvider, NewExampleServer))
panic(wire.Build(InternalProvider, HandlerProvider, NewExampleServer))
}
6 changes: 5 additions & 1 deletion example/cmd/server/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 1574f6d

Please sign in to comment.