Skip to content

Commit

Permalink
cmd: Generate packages in parallel (sqlc-dev#2026)
Browse files Browse the repository at this point in the history
* compiler: Speed up generate

* wasm: Load serialized modules using singleflight
  • Loading branch information
kyleconroy authored Jan 19, 2023
1 parent c4ceb0e commit d64a68b
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 70 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ require (
github.com/pganalyze/pg_query_go/v2 v2.2.0
github.com/spf13/cobra v1.6.1
github.com/spf13/pflag v1.0.5
golang.org/x/sync v0.1.0
google.golang.org/protobuf v1.28.1
gopkg.in/yaml.v3 v3.0.1
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
Expand Down
3 changes: 1 addition & 2 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int
rootCmd.SetIn(stdin)
rootCmd.SetOut(stdout)
rootCmd.SetErr(stderr)
rootCmd.SilenceErrors = true

ctx := context.Background()
if debug.Debug.Trace != "" {
Expand All @@ -55,9 +56,7 @@ func Do(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer) int
ctx = tracectx
defer cleanup()
}

if err := rootCmd.ExecuteContext(ctx); err != nil {
fmt.Fprintf(stderr, "%v\n", err)
if exitError, ok := err.(*exec.ExitError); ok {
return exitError.ExitCode()
} else {
Expand Down
135 changes: 81 additions & 54 deletions internal/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ import (
"io"
"os"
"path/filepath"
"runtime"
"runtime/trace"
"strings"
"sync"

"golang.org/x/sync/errgroup"

"github.com/kyleconroy/sqlc/internal/codegen/golang"
"github.com/kyleconroy/sqlc/internal/codegen/json"
Expand Down Expand Up @@ -159,71 +163,94 @@ func Generate(ctx context.Context, e Env, dir, filename string, stderr io.Writer
}
}

for _, sql := range pairs {
combo := config.Combine(*conf, sql.SQL)
if sql.Plugin != nil {
combo.Codegen = *sql.Plugin
}
var m sync.Mutex
grp, gctx := errgroup.WithContext(ctx)
grp.SetLimit(runtime.GOMAXPROCS(0))

// TODO: This feels like a hack that will bite us later
joined := make([]string, 0, len(sql.Schema))
for _, s := range sql.Schema {
joined = append(joined, filepath.Join(dir, s))
}
sql.Schema = joined
stderrs := make([]bytes.Buffer, len(pairs))

joined = make([]string, 0, len(sql.Queries))
for _, q := range sql.Queries {
joined = append(joined, filepath.Join(dir, q))
}
sql.Queries = joined
for i, pair := range pairs {
sql := pair
errout := &stderrs[i]

var name, lang string
parseOpts := opts.Parser{
Debug: debug.Debug,
}
grp.Go(func() error {
combo := config.Combine(*conf, sql.SQL)
if sql.Plugin != nil {
combo.Codegen = *sql.Plugin
}

switch {
case sql.Gen.Go != nil:
name = combo.Go.Package
lang = "golang"
// TODO: This feels like a hack that will bite us later
joined := make([]string, 0, len(sql.Schema))
for _, s := range sql.Schema {
joined = append(joined, filepath.Join(dir, s))
}
sql.Schema = joined

case sql.Plugin != nil:
lang = fmt.Sprintf("process:%s", sql.Plugin.Plugin)
name = sql.Plugin.Plugin
}
joined = make([]string, 0, len(sql.Queries))
for _, q := range sql.Queries {
joined = append(joined, filepath.Join(dir, q))
}
sql.Queries = joined

packageRegion := trace.StartRegion(ctx, "package")
trace.Logf(ctx, "", "name=%s dir=%s plugin=%s", name, dir, lang)
var name, lang string
parseOpts := opts.Parser{
Debug: debug.Debug,
}

result, failed := parse(ctx, name, dir, sql.SQL, combo, parseOpts, stderr)
if failed {
packageRegion.End()
errored = true
break
}
switch {
case sql.Gen.Go != nil:
name = combo.Go.Package
lang = "golang"

out, resp, err := codegen(ctx, combo, sql, result)
if err != nil {
fmt.Fprintf(stderr, "# package %s\n", name)
fmt.Fprintf(stderr, "error generating code: %s\n", err)
errored = true
packageRegion.End()
continue
}
case sql.Plugin != nil:
lang = fmt.Sprintf("process:%s", sql.Plugin.Plugin)
name = sql.Plugin.Plugin
}

files := map[string]string{}
for _, file := range resp.Files {
files[file.Name] = string(file.Contents)
}
for n, source := range files {
filename := filepath.Join(dir, out, n)
output[filename] = source
}
packageRegion.End()
}
packageRegion := trace.StartRegion(gctx, "package")
trace.Logf(gctx, "", "name=%s dir=%s plugin=%s", name, dir, lang)

result, failed := parse(gctx, name, dir, sql.SQL, combo, parseOpts, errout)
if failed {
packageRegion.End()
errored = true
return nil
}

out, resp, err := codegen(gctx, combo, sql, result)
if err != nil {
fmt.Fprintf(errout, "# package %s\n", name)
fmt.Fprintf(errout, "error generating code: %s\n", err)
errored = true
packageRegion.End()
return nil
}

files := map[string]string{}
for _, file := range resp.Files {
files[file.Name] = string(file.Contents)
}

m.Lock()
for n, source := range files {
filename := filepath.Join(dir, out, n)
output[filename] = source
}
m.Unlock()

packageRegion.End()
return nil
})
}
if err := grp.Wait(); err != nil {
return nil, err
}
if errored {
for i, _ := range stderrs {
if _, err := io.Copy(stderr, &stderrs[i]); err != nil {
return nil, err
}
}
return nil, fmt.Errorf("errored")
}
return output, nil
Expand Down
47 changes: 33 additions & 14 deletions internal/ext/wasm/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"strings"

wasmtime "github.com/bytecodealliance/wasmtime-go/v3"
"golang.org/x/sync/singleflight"

"github.com/kyleconroy/sqlc/internal/info"
"github.com/kyleconroy/sqlc/internal/plugin"
Expand Down Expand Up @@ -49,6 +50,8 @@ type Runner struct {
SHA256 string
}

var flight singleflight.Group

// Verify the provided sha256 is valid.
func (r *Runner) parseChecksum() (string, error) {
if r.SHA256 == "" {
Expand All @@ -58,6 +61,24 @@ func (r *Runner) parseChecksum() (string, error) {
}

func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) {
expected, err := r.parseChecksum()
if err != nil {
return nil, err
}
value, err, _ := flight.Do(expected, func() (interface{}, error) {
return r.loadSerializedModule(ctx, engine)
})
if err != nil {
return nil, err
}
data, ok := value.([]byte)
if !ok {
return nil, fmt.Errorf("returned value was not a byte slice")
}
return wasmtime.NewModuleDeserialize(engine, data)
}

func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine) ([]byte, error) {
expected, err := r.parseChecksum()
if err != nil {
return nil, err
Expand All @@ -80,7 +101,7 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm
if err != nil {
return nil, err
}
return wasmtime.NewModuleDeserialize(engine, data)
return data, nil
}

wmod, err := r.loadWASM(ctx, cache, expected)
Expand All @@ -95,21 +116,19 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm
return nil, fmt.Errorf("define wasi: %w", err)
}

if staterr != nil {
err := os.Mkdir(pluginDir, 0755)
if err != nil && !os.IsExist(err) {
return nil, fmt.Errorf("mkdirall: %w", err)
}
out, err := module.Serialize()
if err != nil {
return nil, fmt.Errorf("serialize: %w", err)
}
if err := os.WriteFile(modPath, out, 0444); err != nil {
return nil, fmt.Errorf("cache wasm: %w", err)
}
err = os.Mkdir(pluginDir, 0755)
if err != nil && !os.IsExist(err) {
return nil, fmt.Errorf("mkdirall: %w", err)
}
out, err := module.Serialize()
if err != nil {
return nil, fmt.Errorf("serialize: %w", err)
}
if err := os.WriteFile(modPath, out, 0444); err != nil {
return nil, fmt.Errorf("cache wasm: %w", err)
}

return module, nil
return out, nil
}

func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) {
Expand Down

0 comments on commit d64a68b

Please sign in to comment.