disttask: subtasks rebalance during task execution (pingcap#48306)
ywqzzy authored Nov 30, 2023
1 parent 6ec3a10 commit ef02d72
Showing 12 changed files with 834 additions and 114 deletions.
1 change: 0 additions & 1 deletion pkg/ddl/backfilling_dispatcher.go
@@ -289,7 +289,6 @@ func generateNonPartitionPlan(
if err != nil {
return nil, err
}

regionBatch := calculateRegionBatch(len(recordRegionMetas), instanceCnt, !useCloud)

subTaskMetas := make([][]byte, 0, 4)
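For context: generateNonPartitionPlan now always sizes subtasks through calculateRegionBatch. The helper's body is not part of this hunk; the sketch below is only a rough, hypothetical approximation of what such a batch calculation does (the function name suffix, cap value, and parameter naming are assumptions, not code from this commit): split the regions evenly across instances and cap the per-subtask batch when data is staged through cloud storage.

package main

import "fmt"

// calculateRegionBatchSketch is a hypothetical stand-in for calculateRegionBatch:
// spread the region count evenly over the instances, and cap the per-subtask
// batch when data goes through cloud storage. The cap value is an assumption.
func calculateRegionBatchSketch(regionCnt, instanceCnt int, useLocalDisk bool) int {
	batch := regionCnt / instanceCnt
	if !useLocalDisk && batch > 100 {
		batch = 100
	}
	if batch < 1 {
		batch = 1
	}
	return batch
}

func main() {
	// 1000 regions over 4 instances, staged through cloud storage.
	fmt.Println(calculateRegionBatchSketch(1000, 4, false)) // 100
}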
3 changes: 2 additions & 1 deletion pkg/disttask/framework/dispatcher/BUILD.bazel
@@ -36,11 +36,12 @@ go_test(
"dispatcher_manager_test.go",
"dispatcher_test.go",
"main_test.go",
"rebalance_test.go",
],
embed = [":dispatcher"],
flaky = True,
race = "off",
shard_count = 16,
shard_count = 19,
deps = [
"//pkg/disttask/framework/dispatcher/mock",
"//pkg/disttask/framework/mock",
165 changes: 124 additions & 41 deletions pkg/disttask/framework/dispatcher/dispatcher.go
@@ -92,14 +92,15 @@ type BaseDispatcher struct {
// when RegisterDispatcherFactory, the factory MUST initialize this field.
Extension

// for HA
// liveNodes will fetch and store all live nodes every liveNodeInterval ticks.
liveNodes []*infosync.ServerInfo
// For subtasks rebalance.
// LiveNodes will fetch and store all live nodes every liveNodeInterval ticks.
LiveNodes []*infosync.ServerInfo
liveNodeFetchInterval int
// liveNodeFetchTick is the tick variable.
liveNodeFetchTick int
// taskNodes stores the id of current scheduler nodes.
taskNodes []string
// TaskNodes stores the id of current scheduler nodes.
TaskNodes []string

// rand is for generating random selection of nodes.
rand *rand.Rand
}
@@ -117,10 +118,10 @@ func NewBaseDispatcher(ctx context.Context, taskMgr TaskManager, serverID string
Task: task,
logCtx: logCtx,
serverID: serverID,
liveNodes: nil,
LiveNodes: nil,
liveNodeFetchInterval: DefaultLiveNodesCheckInterval,
liveNodeFetchTick: 0,
taskNodes: nil,
TaskNodes: nil,
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
@@ -264,7 +265,7 @@ func (d *BaseDispatcher) onPausing() error {
// MockDMLExecutionOnPausedState is used to mock DML execution when tasks paused.
var MockDMLExecutionOnPausedState func(task *proto.Task)

// handle task in paused state
// handle task in paused state.
func (d *BaseDispatcher) onPaused() error {
logutil.Logger(d.logCtx).Info("on paused state", zap.Stringer("state", d.Task.State), zap.Int64("stage", int64(d.Task.Step)))
failpoint.Inject("mockDMLExecutionOnPausedState", func(val failpoint.Value) {
@@ -278,7 +279,7 @@ func (d *BaseDispatcher) onPaused() error {
// TestSyncChan is used to sync the test.
var TestSyncChan = make(chan struct{})

// handle task in resuming state
// handle task in resuming state.
func (d *BaseDispatcher) onResuming() error {
logutil.Logger(d.logCtx).Info("on resuming state", zap.Stringer("state", d.Task.State), zap.Int64("stage", int64(d.Task.Step)))
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.ctx, d.Task.ID, proto.TaskStatePaused)
@@ -348,8 +349,8 @@ func (d *BaseDispatcher) onRunning() error {
if cnt == 0 {
return d.onNextStage()
}
// Check if any node are down.
if err := d.replaceDeadNodesIfAny(); err != nil {

if err := d.BalanceSubtasks(); err != nil {
return err
}
// Wait all subtasks in this stage finished.
@@ -364,16 +365,19 @@ func (d *BaseDispatcher) onFinished() error {
return d.taskMgr.TransferSubTasks2History(d.ctx, d.Task.ID)
}

func (d *BaseDispatcher) replaceDeadNodesIfAny() error {
if len(d.taskNodes) == 0 {
// BalanceSubtasks checks the live node count every liveNodeFetchInterval ticks and then rebalances subtasks.
func (d *BaseDispatcher) BalanceSubtasks() error {
// 1. init TaskNodes if needed.
if len(d.TaskNodes) == 0 {
var err error
d.taskNodes, err = d.taskMgr.GetSchedulerIDsByTaskIDAndStep(d.ctx, d.Task.ID, d.Task.Step)
d.TaskNodes, err = d.taskMgr.GetSchedulerIDsByTaskIDAndStep(d.ctx, d.Task.ID, d.Task.Step)
if err != nil {
return err
}
}
d.liveNodeFetchTick++
if d.liveNodeFetchTick == d.liveNodeFetchInterval {
// 2. update LiveNodes.
d.liveNodeFetchTick = 0
serverInfos, err := GenerateSchedulerNodes(d.ctx)
if err != nil {
@@ -403,37 +407,116 @@
newInfos = append(newInfos, m)
}
}
d.liveNodes = newInfos
d.LiveNodes = newInfos
// 3. balance subtasks.
if len(d.LiveNodes) > 0 {
return d.ReDispatchSubtasks()
}
return nil
}
if len(d.liveNodes) > 0 {
replaceNodes := make(map[string]string)
cleanNodes := make([]string, 0)
for _, nodeID := range d.taskNodes {
if ok := disttaskutil.MatchServerInfo(d.liveNodes, nodeID); !ok {
n := d.liveNodes[d.rand.Int()%len(d.liveNodes)] //nolint:gosec
replaceNodes[nodeID] = disttaskutil.GenerateExecID(n.IP, n.Port)
cleanNodes = append(cleanNodes, nodeID)
}
return nil
}

func (d *BaseDispatcher) replaceTaskNodes() {
d.TaskNodes = d.TaskNodes[:0]
for _, serverInfo := range d.LiveNodes {
d.TaskNodes = append(d.TaskNodes, disttaskutil.GenerateExecID(serverInfo.IP, serverInfo.Port))
}
}

// ReDispatchSubtasks balances the number of subtasks across LiveNodes and cleans up subtasks on dead nodes.
// TODO(ywqzzy): refine to make it easier for testing.
func (d *BaseDispatcher) ReDispatchSubtasks() error {
// 1. find out the dead nodes whose subtasks need to be cleaned up.
deadNodes := make([]string, 0)
deadNodesMap := make(map[string]bool, 0)
for _, node := range d.TaskNodes {
if !disttaskutil.MatchServerInfo(d.LiveNodes, node) {
deadNodes = append(deadNodes, node)
deadNodesMap[node] = true
}
if len(replaceNodes) > 0 {
logutil.Logger(d.logCtx).Info("reschedule subtasks to other nodes", zap.Int("node-cnt", len(replaceNodes)))
if err := d.taskMgr.UpdateFailedSchedulerIDs(d.ctx, d.Task.ID, replaceNodes); err != nil {
return err
}
if err := d.taskMgr.CleanUpMeta(d.ctx, cleanNodes); err != nil {
return err
}
// 2. get subtasks for each node before rebalance.
subtasks, err := d.taskMgr.GetSubtasksByStepAndState(d.ctx, d.Task.ID, d.Task.Step, proto.TaskStatePending)
if err != nil {
return err
}
if len(deadNodes) != 0 {
// get subtasks from deadNodes, since there might be some running subtasks on deadNodes.
// In this case, all subtasks on deadNodes are in running/pending state.
subtasksOnDeadNodes, err := d.taskMgr.GetSubtasksByExecIdsAndStepAndState(d.ctx, deadNodes, d.Task.ID, d.Task.Step, proto.TaskStateRunning)
if err != nil {
return err
}
subtasks = append(subtasks, subtasksOnDeadNodes...)
}
// 3. group subtasks for each scheduler.
subtasksOnScheduler := make(map[string][]*proto.Subtask, len(d.LiveNodes)+len(deadNodes))
for _, node := range d.LiveNodes {
execID := disttaskutil.GenerateExecID(node.IP, node.Port)
subtasksOnScheduler[execID] = make([]*proto.Subtask, 0)
}
for _, subtask := range subtasks {
subtasksOnScheduler[subtask.SchedulerID] = append(
subtasksOnScheduler[subtask.SchedulerID],
subtask)
}
// 4. prepare subtasks that need to be rebalanced to other nodes.
averageSubtaskCnt := len(subtasks) / len(d.LiveNodes)
rebalanceSubtasks := make([]*proto.Subtask, 0)
for k, v := range subtasksOnScheduler {
if ok := deadNodesMap[k]; ok {
rebalanceSubtasks = append(rebalanceSubtasks, v...)
continue
}
// When there is no tidb scale-in/out and averageSubtaskCnt*len(d.LiveNodes) < len(subtasks),
// there is no need to send subtasks to other nodes.
// e.g. tidb1 with 3 subtasks and tidb2 with 2 subtasks are already balanced.
if averageSubtaskCnt*len(d.LiveNodes) < len(subtasks) && len(d.TaskNodes) == len(d.LiveNodes) {
if len(v) > averageSubtaskCnt+1 {
rebalanceSubtasks = append(rebalanceSubtasks, v[0:len(v)-averageSubtaskCnt]...)
}
// replace local cache.
for k, v := range replaceNodes {
for m, n := range d.taskNodes {
if n == k {
d.taskNodes[m] = v
break
}
continue
}
if len(v) > averageSubtaskCnt {
rebalanceSubtasks = append(rebalanceSubtasks, v[0:len(v)-averageSubtaskCnt]...)
}
}
// 5. skip rebalance.
if len(rebalanceSubtasks) == 0 {
return nil
}
// 6. rebalance subtasks to other nodes.
rebalanceIdx := 0
for k, v := range subtasksOnScheduler {
if ok := deadNodesMap[k]; !ok {
if len(v) < averageSubtaskCnt {
for i := 0; i < averageSubtaskCnt-len(v) && rebalanceIdx < len(rebalanceSubtasks); i++ {
rebalanceSubtasks[rebalanceIdx].SchedulerID = k
rebalanceIdx++
}
}
}
}
// 7. rebalance the remaining subtasks evenly across liveNodes.
liveNodeIdx := 0
for rebalanceIdx < len(rebalanceSubtasks) {
node := d.LiveNodes[liveNodeIdx]
rebalanceSubtasks[rebalanceIdx].SchedulerID = disttaskutil.GenerateExecID(node.IP, node.Port)
rebalanceIdx++
liveNodeIdx++
}

// 8. update subtasks and do the clean-up logic.
if err = d.taskMgr.UpdateSubtasksSchedulerIDs(d.ctx, d.Task.ID, subtasks); err != nil {
return err
}
logutil.Logger(d.logCtx).Info("rebalance subtasks",
zap.Stringers("subtasks-rebalanced", subtasks))
if err = d.taskMgr.CleanUpMeta(d.ctx, deadNodes); err != nil {
return err
}
d.replaceTaskNodes()
return nil
}

@@ -605,9 +688,9 @@ func (d *BaseDispatcher) dispatchSubTask(
metas [][]byte,
serverNodes []*infosync.ServerInfo) error {
logutil.Logger(d.logCtx).Info("dispatch subtasks", zap.Stringer("state", d.Task.State), zap.Int64("step", int64(d.Task.Step)), zap.Uint64("concurrency", d.Task.Concurrency), zap.Int("subtasks", len(metas)))
d.taskNodes = make([]string, len(serverNodes))
d.TaskNodes = make([]string, len(serverNodes))
for i := range serverNodes {
d.taskNodes[i] = disttaskutil.GenerateExecID(serverNodes[i].IP, serverNodes[i].Port)
d.TaskNodes[i] = disttaskutil.GenerateExecID(serverNodes[i].IP, serverNodes[i].Port)
}
subTasks := make([]*proto.Subtask, 0, len(metas))
for i, meta := range metas {
@@ -718,7 +801,7 @@ func (d *BaseDispatcher) GetAllSchedulerIDs(ctx context.Context, task *proto.Tas

// GetPreviousSubtaskMetas get subtask metas from specific step.
func (d *BaseDispatcher) GetPreviousSubtaskMetas(taskID int64, step proto.Step) ([][]byte, error) {
previousSubtasks, err := d.taskMgr.GetSucceedSubtasksByStep(d.ctx, taskID, step)
previousSubtasks, err := d.taskMgr.GetSubtasksByStepAndState(d.ctx, taskID, step, proto.TaskStateSucceed)
if err != nil {
logutil.Logger(d.logCtx).Warn("get previous succeed subtask failed", zap.Int64("step", int64(step)))
return nil, err
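Taken together, BalanceSubtasks and ReDispatchSubtasks implement a periodic balancing pass: group subtasks by executor, move everything that sits on a dead node, trim executors holding more than their share, and top up the ones holding less. The standalone Go sketch below illustrates that redistribution idea under simplified assumptions (a toy Subtask struct and plain node-ID strings); it is not the dispatcher code, and it uses exact per-node targets instead of the average-count heuristic in the hunk above.

package main

import "fmt"

// Subtask is a simplified stand-in for proto.Subtask, keeping only the field
// the balancing step rewrites.
type Subtask struct {
	ID          int64
	SchedulerID string
}

// rebalance reassigns subtasks so that every live node ends up with an exact
// target count and nothing stays on a dead node. It mirrors the idea of
// ReDispatchSubtasks, not its exact heuristics.
func rebalance(subtasks []*Subtask, liveNodes []string) {
	if len(liveNodes) == 0 {
		return
	}
	live := make(map[string]bool, len(liveNodes))
	for _, n := range liveNodes {
		live[n] = true
	}

	// Group subtasks by the executor they are currently assigned to.
	byNode := make(map[string][]*Subtask)
	for _, st := range subtasks {
		byNode[st.SchedulerID] = append(byNode[st.SchedulerID], st)
	}

	// Per-node targets: the first len(subtasks)%len(liveNodes) nodes take one extra.
	base, extra := len(subtasks)/len(liveNodes), len(subtasks)%len(liveNodes)
	target := make(map[string]int, len(liveNodes))
	for i, n := range liveNodes {
		target[n] = base
		if i < extra {
			target[n]++
		}
	}

	// Everything on a dead node moves; live nodes give up anything above target.
	var pending []*Subtask
	for node, sts := range byNode {
		if !live[node] {
			pending = append(pending, sts...)
		} else if len(sts) > target[node] {
			pending = append(pending, sts[target[node]:]...)
			byNode[node] = sts[:target[node]]
		}
	}

	// Top up nodes that are below their target.
	idx := 0
	for _, node := range liveNodes {
		for len(byNode[node]) < target[node] && idx < len(pending) {
			pending[idx].SchedulerID = node
			byNode[node] = append(byNode[node], pending[idx])
			idx++
		}
	}
}

func main() {
	subtasks := []*Subtask{
		{1, "tidb1:4000"}, {2, "tidb1:4000"}, {3, "tidb1:4000"},
		{4, "tidb2:4000"}, {5, "tidb3:4000"},
	}
	// tidb3 went down and tidb4 joined: expect roughly 2/2/1 across live nodes.
	rebalance(subtasks, []string{"tidb1:4000", "tidb2:4000", "tidb4:4000"})
	for _, st := range subtasks {
		fmt.Println(st.ID, "->", st.SchedulerID)
	}
}

With exact targets the top-up loop is guaranteed to place every pending subtask, since the targets sum to len(subtasks); the patch achieves a similar effect with an average count plus a final round-robin pass over liveNodes.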
5 changes: 3 additions & 2 deletions pkg/disttask/framework/dispatcher/interface.go
@@ -38,10 +38,11 @@ type TaskManager interface {
ResumeSubtasks(ctx context.Context, taskID int64) error
CollectSubTaskError(ctx context.Context, taskID int64) ([]error, error)
TransferSubTasks2History(ctx context.Context, taskID int64) error
UpdateFailedSchedulerIDs(ctx context.Context, taskID int64, replaceNodes map[string]string) error
UpdateSubtasksSchedulerIDs(ctx context.Context, taskID int64, subtasks []*proto.Subtask) error
GetNodesByRole(ctx context.Context, role string) (map[string]bool, error)
GetSchedulerIDsByTaskID(ctx context.Context, taskID int64) ([]string, error)
GetSucceedSubtasksByStep(ctx context.Context, taskID int64, step proto.Step) ([]*proto.Subtask, error)
GetSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error)
GetSubtasksByExecIdsAndStepAndState(ctx context.Context, tidbIDs []string, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error)
GetSchedulerIDsByTaskIDAndStep(ctx context.Context, taskID int64, step proto.Step) ([]string, error)

WithNewSession(fn func(se sessionctx.Context) error) error
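The two new read methods differ only in their filter (step and state vs. a set of executor IDs plus step and state), and UpdateSubtasksSchedulerIDs is the bulk write the rebalance pass relies on. The toy in-memory sketch below only illustrates the shape of this slice of the interface; the types are simplified stand-ins, not the real proto definitions or the SQL-backed implementation in the patch.

package main

import (
	"context"
	"fmt"
)

// Simplified stand-ins for proto.Step / proto.TaskState / proto.Subtask.
type (
	Step      int64
	TaskState string
)

type Subtask struct {
	TaskID      int64
	Step        Step
	State       TaskState
	SchedulerID string
}

// memTaskMgr is a toy in-memory table of subtasks.
type memTaskMgr struct {
	subtasks []*Subtask
}

// GetSubtasksByStepAndState returns the subtasks of one task step in a given state.
func (m *memTaskMgr) GetSubtasksByStepAndState(_ context.Context, taskID int64, step Step, state TaskState) ([]*Subtask, error) {
	var res []*Subtask
	for _, st := range m.subtasks {
		if st.TaskID == taskID && st.Step == step && st.State == state {
			res = append(res, st)
		}
	}
	return res, nil
}

// GetSubtasksByExecIdsAndStepAndState additionally restricts the result to a
// set of executor (tidb) IDs, which the rebalance pass uses to pick up the
// running subtasks left on dead nodes.
func (m *memTaskMgr) GetSubtasksByExecIdsAndStepAndState(_ context.Context, tidbIDs []string, taskID int64, step Step, state TaskState) ([]*Subtask, error) {
	ids := make(map[string]bool, len(tidbIDs))
	for _, id := range tidbIDs {
		ids[id] = true
	}
	var res []*Subtask
	for _, st := range m.subtasks {
		if ids[st.SchedulerID] && st.TaskID == taskID && st.Step == step && st.State == state {
			res = append(res, st)
		}
	}
	return res, nil
}

// UpdateSubtasksSchedulerIDs persists executor reassignments; in memory the
// updated pointers are already authoritative, so there is nothing more to do.
func (m *memTaskMgr) UpdateSubtasksSchedulerIDs(_ context.Context, taskID int64, subtasks []*Subtask) error {
	return nil
}

func main() {
	m := &memTaskMgr{subtasks: []*Subtask{
		{TaskID: 1, Step: 1, State: "running", SchedulerID: "tidb3:4000"},
	}}
	got, _ := m.GetSubtasksByExecIdsAndStepAndState(context.Background(), []string{"tidb3:4000"}, 1, 1, "running")
	fmt.Println(len(got)) // 1
}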