diff --git a/tools/gentool/gentool.go b/tools/gentool/gentool.go index 5fdf21c4..9f7178ce 100644 --- a/tools/gentool/gentool.go +++ b/tools/gentool/gentool.go @@ -28,6 +28,9 @@ const ( dbSQLServer DBType = "sqlserver" dbClickHouse DBType = "clickhouse" ) +const ( + defaultQueryPath = "./dao/query" +) // CmdParams is command line parameters type CmdParams struct { @@ -45,6 +48,32 @@ type CmdParams struct { FieldSignable bool `yaml:"fieldSignable"` // detect integer field's unsigned type, adjust generated data type } +func (c *CmdParams) revise() *CmdParams { + if c == nil { + return c + } + if c.DB == "" { + c.DB = string(dbMySQL) + } + if c.OutPath == "" { + c.OutPath = defaultQueryPath + } + if len(c.Tables) == 0 { + return c + } + + tableList := make([]string, 0, len(c.Tables)) + for _, tableName := range c.Tables { + _tableName := strings.TrimSpace(tableName) // trim leading and trailing space in tableName + if _tableName == "" { // skip empty tableName + continue + } + tableList = append(tableList, _tableName) + } + c.Tables = tableList + return c +} + // YamlConfig is yaml config struct type YamlConfig struct { Version string `yaml:"version"` // @@ -75,37 +104,36 @@ func connectDB(t DBType, dsn string) (*gorm.DB, error) { // genModels is gorm/gen generated models func genModels(g *gen.Generator, db *gorm.DB, tables []string) (models []interface{}, err error) { - var tablesList []string if len(tables) == 0 { // Execute tasks for all tables in the database - tablesList, err = db.Migrator().GetTables() + tables, err = db.Migrator().GetTables() if err != nil { return nil, fmt.Errorf("GORM migrator get all tables fail: %w", err) } - } else { - tablesList = tables } // Execute some data table tasks - models = make([]interface{}, len(tablesList)) - for i, tableName := range tablesList { + models = make([]interface{}, len(tables)) + for i, tableName := range tables { models[i] = g.GenerateModel(tableName) } return models, nil } -// loadConfigFile load config file from path -func loadConfigFile(path string) (*CmdParams, error) { +// parseCmdFromYaml parse cmd param from yaml +func parseCmdFromYaml(path string) *CmdParams { file, err := os.Open(path) if err != nil { - return nil, err + log.Fatalf("parseCmdFromYaml fail %s", err.Error()) + return nil } defer file.Close() // nolint var yamlConfig YamlConfig - if cmdErr := yaml.NewDecoder(file).Decode(&yamlConfig); cmdErr != nil { - return nil, cmdErr + if err = yaml.NewDecoder(file).Decode(&yamlConfig); err != nil { + log.Fatalf("parseCmdFromYaml fail %s", err.Error()) + return nil } - return yamlConfig.Database, nil + return yamlConfig.Database } // argParse is parser for cmd @@ -113,10 +141,10 @@ func argParse() *CmdParams { // choose is file or flag genPath := flag.String("c", "", "is path for gen.yml") dsn := flag.String("dsn", "", "consult[https://gorm.io/docs/connecting_to_the_database.html]") - db := flag.String("db", "mysql", "input mysql|postgres|sqlite|sqlserver|clickhouse. consult[https://gorm.io/docs/connecting_to_the_database.html]") + db := flag.String("db", string(dbMySQL), "input mysql|postgres|sqlite|sqlserver|clickhouse. consult[https://gorm.io/docs/connecting_to_the_database.html]") tableList := flag.String("tables", "", "enter the required data table or leave it blank") onlyModel := flag.Bool("onlyModel", false, "only generate models (without query file)") - outPath := flag.String("outPath", "./dao/query", "specify a directory for output") + outPath := flag.String("outPath", defaultQueryPath, "specify a directory for output") outFile := flag.String("outFile", "", "query code file name, default: gen.go") withUnitTest := flag.Bool("withUnitTest", false, "generate unit test for query code") modelPkgName := flag.String("modelPkgName", "", "generated model code's package name") @@ -125,14 +153,12 @@ func argParse() *CmdParams { fieldWithTypeTag := flag.Bool("fieldWithTypeTag", false, "generate field with gorm column type tag") fieldSignable := flag.Bool("fieldSignable", false, "detect integer field's unsigned type, adjust generated data type") flag.Parse() - var cmdParse CmdParams - if *genPath != "" { - if configFileParams, err := loadConfigFile(*genPath); err == nil && configFileParams != nil { - cmdParse = *configFileParams - } else if err != nil { - log.Fatalf("loadConfigFile fail %s", err.Error()) - } + + if *genPath != "" { //use yml config + return parseCmdFromYaml(*genPath) } + + var cmdParse CmdParams // cmd first if *dsn != "" { cmdParse.DSN = *dsn @@ -141,13 +167,7 @@ func argParse() *CmdParams { cmdParse.DB = *db } if *tableList != "" { - for _, tableName := range strings.Split(*tableList, ",") { - _tableName := strings.TrimSpace(tableName) // trim leading and trailing space in tableName - if _tableName == "" { // skip empty tableName - continue - } - cmdParse.Tables = append(cmdParse.Tables, _tableName) - } + cmdParse.Tables = strings.Split(*tableList, ",") } if *onlyModel { cmdParse.OnlyModel = true @@ -181,7 +201,7 @@ func argParse() *CmdParams { func main() { // cmdParse - config := argParse() + config := argParse().revise() if config == nil { log.Fatalln("parse config fail") }