Skip to content

Commit

Permalink
set model state as unloaded when call unload model API (opensearch-pr…
Browse files Browse the repository at this point in the history
…oject#580) (opensearch-project#619)

Signed-off-by: Yaliang Wu <[email protected]>

Signed-off-by: Yaliang Wu <[email protected]>
Signed-off-by: Sicheng Song <[email protected]>

Signed-off-by: Yaliang Wu <[email protected]>
Signed-off-by: Sicheng Song <[email protected]>
Co-authored-by: Yaliang Wu <[email protected]>
  • Loading branch information
b4sjoo and ylwu-amzn authored Dec 5, 2022
1 parent 4c31511 commit ad065d4
Showing 1 changed file with 40 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

package org.opensearch.ml.action.unload;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.CommonValue.NOT_FOUND;
import static org.opensearch.ml.common.CommonValue.UNLOADED;
import static org.opensearch.ml.common.MLModel.MODEL_STATE_FIELD;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -21,14 +23,18 @@

import org.opensearch.action.ActionListener;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
Expand All @@ -43,6 +49,8 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import com.google.common.collect.ImmutableMap;

@Log4j2
public class TransportUnloadModelAction extends
TransportNodesAction<UnloadModelNodesRequest, UnloadModelNodesResponse, UnloadModelNodeRequest, UnloadModelNodeResponse> {
Expand Down Expand Up @@ -116,18 +124,43 @@ protected UnloadModelNodesResponse newResponse(

MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(nodeFilter.getAllNodes(), syncUpInput);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client
.execute(
MLSyncUpAction.INSTANCE,
syncUpRequest,
ActionListener
.wrap(r -> log.debug("sync up removed nodes successfully"), e -> log.error("failed to sync up removed node", e))
);
if (removedNodeMap.size() > 0) {
BulkRequest bulkRequest = new BulkRequest();
for (String modelId : removedNodeMap.keySet()) {
UpdateRequest updateRequest = new UpdateRequest();
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(ImmutableMap.of(MODEL_STATE_FIELD, MLModelState.UNLOADED));
bulkRequest.add(updateRequest);
}
ActionListener<BulkResponse> actionListenr = ActionListener
.wrap(
r -> {
log
.debug(
"updated model state as unloaded for : {}",
Arrays.toString(removedNodeMap.keySet().toArray(new String[0]))
);
},
e -> { log.error("Failed to update model state as unloaded", e); }
);
client.bulk(bulkRequest, ActionListener.runAfter(actionListenr, () -> { syncUpUnloadedModels(syncUpRequest); }));
} else {
syncUpUnloadedModels(syncUpRequest);
}
}
}
return new UnloadModelNodesResponse(clusterService.getClusterName(), responses, failures);
}

private void syncUpUnloadedModels(MLSyncUpNodesRequest syncUpRequest) {
client
.execute(
MLSyncUpAction.INSTANCE,
syncUpRequest,
ActionListener
.wrap(r -> log.debug("sync up removed nodes successfully"), e -> log.error("failed to sync up removed node", e))
);
}

@Override
protected UnloadModelNodeRequest newNodeRequest(UnloadModelNodesRequest request) {
return new UnloadModelNodeRequest(request);
Expand Down

0 comments on commit ad065d4

Please sign in to comment.