Skip to content

Commit

Permalink
feat(plugin): add new task-batching plugin (ovh#495)
Browse files Browse the repository at this point in the history
  • Loading branch information
DrRebus authored Oct 28, 2024
1 parent 4ad19c6 commit dd6039e
Show file tree
Hide file tree
Showing 13 changed files with 984 additions and 38 deletions.
52 changes: 15 additions & 37 deletions api/handler/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@ package handler

import (
"github.com/gin-gonic/gin"
"github.com/juju/errors"
"github.com/loopfz/gadgeto/zesty"

"github.com/ovh/utask"
"github.com/ovh/utask/models/task"
"github.com/ovh/utask/models/tasktemplate"
"github.com/ovh/utask/pkg/batch"
"github.com/ovh/utask/pkg/metadata"
"github.com/ovh/utask/pkg/taskutils"
"github.com/ovh/utask/pkg/utils"
)

Expand All @@ -34,11 +32,6 @@ func CreateBatch(c *gin.Context, in *createBatchIn) (*task.Batch, error) {

metadata.AddActionMetadata(c, metadata.TemplateName, in.TemplateName)

tt, err := tasktemplate.LoadFromName(dbp, in.TemplateName)
if err != nil {
return nil, err
}

if err := utils.ValidateTags(in.Tags); err != nil {
return nil, err
}
Expand All @@ -49,45 +42,30 @@ func CreateBatch(c *gin.Context, in *createBatchIn) (*task.Batch, error) {

b, err := task.CreateBatch(dbp)
if err != nil {
dbp.Rollback()
_ = dbp.Rollback()
return nil, err
}

metadata.AddActionMetadata(c, metadata.BatchID, b.PublicID)

for _, inp := range in.Inputs {
input, err := conjMap(in.CommonInput, inp)
if err != nil {
dbp.Rollback()
return nil, err
}

_, err = taskutils.CreateTask(c, dbp, tt, in.WatcherUsernames, in.WatcherGroups, []string{}, []string{}, input, b, in.Comment, nil, in.Tags)
if err != nil {
dbp.Rollback()
return nil, err
}
_, err = batch.Populate(c, b, dbp, batch.TaskArgs{
TemplateName: in.TemplateName,
Inputs: in.Inputs,
CommonInput: in.CommonInput,
Comment: in.Comment,
WatcherUsernames: in.WatcherUsernames,
WatcherGroups: in.WatcherGroups,
Tags: in.Tags,
})
if err != nil {
_ = dbp.Rollback()
return nil, err
}

if err := dbp.Commit(); err != nil {
dbp.Rollback()
_ = dbp.Rollback()
return nil, err
}

return b, nil
}

func conjMap(common, particular map[string]interface{}) (map[string]interface{}, error) {
conj := make(map[string]interface{})
for key, value := range particular {
conj[key] = value
}

for key, value := range common {
if _, ok := conj[key]; ok {
return nil, errors.NewBadRequest(nil, "Conflicting keys in input maps")
}
conj[key] = value
}
return conj, nil
}
5 changes: 4 additions & 1 deletion engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/ovh/utask/pkg/jsonschema"
"github.com/ovh/utask/pkg/metadata"
"github.com/ovh/utask/pkg/now"
pluginbatch "github.com/ovh/utask/pkg/plugins/builtin/batch"
"github.com/ovh/utask/pkg/taskutils"
"github.com/ovh/utask/pkg/utils"
)
Expand Down Expand Up @@ -524,7 +525,9 @@ forLoop:
if mapStatus[status] {
if status == resolution.StateWaiting && recheckWaiting {
for name, s := range res.Steps {
if s.State == step.StateWaiting {
// Steps using the batch plugin shouldn't be run again when WAITING. Running them second time
// may lead to a race condition when the last task of a sub-batch tries to resume its parent
if s.State == step.StateWaiting && s.Action.Type != pluginbatch.Plugin.PluginName() {
delete(executedSteps, name)
}
}
Expand Down
123 changes: 123 additions & 0 deletions engine/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"

"github.com/Masterminds/squirrel"
"github.com/juju/errors"
"github.com/loopfz/gadgeto/zesty"
"github.com/maxatome/go-testdeep/td"
Expand All @@ -23,6 +25,7 @@ import (
"github.com/ovh/utask/api"
"github.com/ovh/utask/db"
"github.com/ovh/utask/db/pgjuju"
"github.com/ovh/utask/db/sqlgenerator"
"github.com/ovh/utask/engine"
"github.com/ovh/utask/engine/functions"
functionrunner "github.com/ovh/utask/engine/functions/runner"
Expand All @@ -36,6 +39,7 @@ import (
compress "github.com/ovh/utask/pkg/compress/init"
"github.com/ovh/utask/pkg/now"
"github.com/ovh/utask/pkg/plugins"
pluginbatch "github.com/ovh/utask/pkg/plugins/builtin/batch"
plugincallback "github.com/ovh/utask/pkg/plugins/builtin/callback"
"github.com/ovh/utask/pkg/plugins/builtin/echo"
"github.com/ovh/utask/pkg/plugins/builtin/script"
Expand Down Expand Up @@ -91,6 +95,7 @@ func TestMain(m *testing.M) {
step.RegisterRunner(echo.Plugin.PluginName(), echo.Plugin)
step.RegisterRunner(script.Plugin.PluginName(), script.Plugin)
step.RegisterRunner(pluginsubtask.Plugin.PluginName(), pluginsubtask.Plugin)
step.RegisterRunner(pluginbatch.Plugin.PluginName(), pluginbatch.Plugin)
step.RegisterRunner(plugincallback.Plugin.PluginName(), plugincallback.Plugin)

os.Exit(m.Run())
Expand Down Expand Up @@ -194,6 +199,21 @@ func templateFromYAML(dbp zesty.DBProvider, filename string) (*tasktemplate.Task
return tasktemplate.LoadFromName(dbp, tmpl.Name)
}

func listBatchTasks(dbp zesty.DBProvider, batchID int64) ([]string, error) {
query, params, err := sqlgenerator.PGsql.
Select("public_id").
From("task").
Where(squirrel.Eq{"id_batch": batchID}).
ToSql()
if err != nil {
return nil, err
}

var taskIDs []string
_, err = dbp.DB().Select(&taskIDs, query, params...)
return taskIDs, err
}

func TestSimpleTemplate(t *testing.T) {
input := map[string]interface{}{
"foo": "bar",
Expand Down Expand Up @@ -1370,3 +1390,106 @@ func TestB64RawEncodeDecode(t *testing.T) {
assert.Equal(t, "cmF3IG1lc3NhZ2U", output["a"])
assert.Equal(t, "raw message", output["b"])
}

func TestBatch(t *testing.T) {
dbp, err := zesty.NewDBProvider(utask.DBName)
require.Nil(t, err)

_, err = templateFromYAML(dbp, "batchedTask.yaml")
require.Nil(t, err)

_, err = templateFromYAML(dbp, "batch.yaml")
require.Nil(t, err)

res, err := createResolution("batch.yaml", map[string]interface{}{}, nil)
require.Nil(t, err, "failed to create resolution: %s", err)

res, err = runResolution(res)
require.Nil(t, err)
require.NotNil(t, res)
assert.Equal(t, resolution.StateWaiting, res.State)

for _, batchStepName := range []string{"batchJsonInputs", "batchYamlInputs"} {
batchStepMetadataRaw, ok := res.Steps[batchStepName].Metadata.(string)
assert.True(t, ok, "wrong type of metadata for step '%s'", batchStepName)

assert.Nil(t, res.Steps[batchStepName].Output, "output nil for step '%s'", batchStepName)

// The plugin formats Metadata in a special way that we need to revert before unmarshalling them
batchStepMetadataRaw = strings.ReplaceAll(batchStepMetadataRaw, `\"`, `"`)
var batchStepMetadata map[string]any
err := json.Unmarshal([]byte(batchStepMetadataRaw), &batchStepMetadata)
require.Nil(t, err, "metadata unmarshalling of step '%s'", batchStepName)

batchPublicID := batchStepMetadata["batch_id"].(string)
assert.NotEqual(t, "", batchPublicID, "wrong batch ID '%s'", batchPublicID)

b, err := task.LoadBatchFromPublicID(dbp, batchPublicID)
require.Nil(t, err)

taskIDs, err := listBatchTasks(dbp, b.ID)
require.Nil(t, err)
assert.Len(t, taskIDs, 2)

for i, publicID := range taskIDs {
child, err := task.LoadFromPublicID(dbp, publicID)
require.Nil(t, err)
assert.Equal(t, task.StateTODO, child.State)

childResolution, err := resolution.Create(dbp, child, nil, "", false, nil)
require.Nil(t, err)

childResolution, err = runResolution(childResolution)
require.Nil(t, err)
assert.Equal(t, resolution.StateDone, childResolution.State)

for k, v := range childResolution.Steps {
assert.Equal(t, step.StateDone, v.State, "not valid state for step %s", k)
}

child, err = task.LoadFromPublicID(dbp, child.PublicID)
require.Nil(t, err)
assert.Equal(t, task.StateDone, child.State)

parentTaskToResume, err := taskutils.ShouldResumeParentTask(dbp, child)
require.Nil(t, err)
if i == len(taskIDs)-1 {
// Only the last child task should resume the parent
require.NotNil(t, parentTaskToResume)
assert.Equal(t, res.TaskID, parentTaskToResume.ID)
} else {
require.Nil(t, parentTaskToResume)
}
}
}

// checking if the parent task is picked up after that the subtask is resolved.
// need to sleep a bit because the parent task is resumed asynchronously
ti := time.Second
i := time.Duration(0)
for i < ti {
res, err = resolution.LoadFromPublicID(dbp, res.PublicID)
require.Nil(t, err)
if res.State != resolution.StateWaiting {
break
}

time.Sleep(time.Millisecond * 10)
i += time.Millisecond * 10
}

ti = time.Second
i = time.Duration(0)
for i < ti {
res, err = resolution.LoadFromPublicID(dbp, res.PublicID)
require.Nil(t, err)
if res.State != resolution.StateRunning {
break
}

time.Sleep(time.Millisecond * 10)
i += time.Millisecond * 10

}
assert.Equal(t, resolution.StateDone, res.State)
}
26 changes: 26 additions & 0 deletions engine/templates_tests/batch.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: batchTemplate
description: Template to test the batch plugin
title_format: "[test] batch template test"

steps:
batchJsonInputs:
description: Batching tasks JSON
action:
type: batch
configuration:
template_name: batchedtasktemplate
json_inputs: '[{"specific_string": "specific-1"}, {"specific_string": "specific-2"}]'
common_json_inputs: '{"common_string": "common"}'
sub_batch_size: 2
batchYamlInputs:
description: Batching tasks YAML
action:
type: batch
configuration:
template_name: batchedtasktemplate
inputs:
- specific_string: specific-1
- specific_string: specific-2
common_inputs:
common_string: common
sub_batch_size: 2
23 changes: 23 additions & 0 deletions engine/templates_tests/batchedTask.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: batchedTaskTemplate
description: Template made to be spawned by the testing batch plugin
title_format: "[test] batched task template"

inputs:
- name: specific_string
description: A string specific to this task
type: string
- name: common_string
description: A string common to all tasks in the same batch
type: string

steps:
simpleStep:
description: Simple step
action:
type: echo
configuration:
output: >-
{
"specific": "{{.input.specific_string}}",
"common": "{{.input.common_string}}"
}
Loading

0 comments on commit dd6039e

Please sign in to comment.