Skip to content

Commit

Permalink
[ML] Remove mention of models in inference actions (elastic#107704)
Browse files Browse the repository at this point in the history
Renames the GET, PUT and DELETE inference APIs removing the model parts:
inference.delete_model -> inference.delete
inference.get_model -> inference.get
inference.put_model -> inference.put
The GET response now has an endpoints field instead of models
  • Loading branch information
davidkyle authored May 1, 2024
1 parent b9a5f89 commit 53252c6
Show file tree
Hide file tree
Showing 20 changed files with 138 additions and 121 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{
"inference.delete_model":{
"inference.delete":{
"documentation":{
"url":"https://www.elastic.co/guide/en/elasticsearch/reference/master/delete-inference-api.html",
"description":"Delete model in the Inference API"
"description":"Delete an inference endpoint"
},
"stability":"experimental",
"visibility":"public",
Expand Down Expand Up @@ -35,7 +35,7 @@
},
"inference_id":{
"type":"string",
"description":"The model Id"
"description":"The inference Id"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{
"inference.get_model":{
"inference.get":{
"documentation":{
"url":"https://www.elastic.co/guide/en/elasticsearch/reference/master/get-inference-api.html",
"description":"Get a model in the Inference API"
"description":"Get an inference endpoint"
},
"stability":"experimental",
"visibility":"public",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"inference.inference":{
"documentation":{
"url":"https://www.elastic.co/guide/en/elasticsearch/reference/master/post-inference-api.html",
"description":"Perform inference on a model"
"description":"Perform inference"
},
"stability":"experimental",
"visibility":"public",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{
"inference.put_model":{
"inference.put":{
"documentation":{
"url":"https://www.elastic.co/guide/en/elasticsearch/reference/master/put-inference-api.html",
"description":"Configure a model for use in the Inference API"
"description":"Configure an inference endpoint for use in the Inference API"
},
"stability":"experimental",
"visibility":"public",
Expand Down Expand Up @@ -43,7 +43,7 @@
]
},
"body":{
"description":"The model's task and service settings"
"description":"The inference endpoint's task and service settings"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,42 +79,42 @@ public int hashCode() {

public static class Response extends ActionResponse implements ToXContentObject {

private final List<ModelConfigurations> models;
private final List<ModelConfigurations> endpoints;

public Response(List<ModelConfigurations> models) {
this.models = models;
public Response(List<ModelConfigurations> endpoints) {
this.endpoints = endpoints;
}

public Response(StreamInput in) throws IOException {
super(in);
if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
models = in.readCollectionAsList(ModelConfigurations::new);
endpoints = in.readCollectionAsList(ModelConfigurations::new);
} else {
models = new ArrayList<>();
models.add(new ModelConfigurations(in));
endpoints = new ArrayList<>();
endpoints.add(new ModelConfigurations(in));
}
}

public List<ModelConfigurations> getModels() {
return models;
public List<ModelConfigurations> getEndpoints() {
return endpoints;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
out.writeCollection(models);
out.writeCollection(endpoints);
} else {
models.get(0).writeTo(out);
endpoints.get(0).writeTo(out);
}
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.startArray("models");
for (var model : models) {
if (model != null) {
model.toFilteredXContent(builder, params);
builder.startArray("endpoints");
for (var endpoint : endpoints) {
if (endpoint != null) {
endpoint.toFilteredXContent(builder, params);
}
}
builder.endArray();
Expand All @@ -127,12 +127,12 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
GetInferenceModelAction.Response response = (GetInferenceModelAction.Response) o;
return Objects.equals(models, response.models);
return Objects.equals(endpoints, response.endpoints);
}

@Override
public int hashCode() {
return Objects.hash(models);
return Objects.hash(endpoints);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -173,22 +173,25 @@ protected Map<String, Object> deployE5TrainedModels() throws IOException {
return entityAsMap(response);
}

@SuppressWarnings("unchecked")
protected Map<String, Object> getModel(String modelId) throws IOException {
var endpoint = Strings.format("_inference/%s", modelId);
return getAllModelInternal(endpoint);
return ((List<Map<String, Object>>) getInternal(endpoint).get("endpoints")).get(0);
}

protected Map<String, Object> getModels(String modelId, TaskType taskType) throws IOException {
@SuppressWarnings("unchecked")
protected List<Map<String, Object>> getModels(String modelId, TaskType taskType) throws IOException {
var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
return getAllModelInternal(endpoint);
return (List<Map<String, Object>>) getInternal(endpoint).get("endpoints");
}

protected Map<String, Object> getAllModels() throws IOException {
@SuppressWarnings("unchecked")
protected List<Map<String, Object>> getAllModels() throws IOException {
var endpoint = Strings.format("_inference/_all");
return getAllModelInternal("_inference/_all");
return (List<Map<String, Object>>) getInternal("_inference/_all").get("endpoints");
}

private Map<String, Object> getAllModelInternal(String endpoint) throws IOException {
private Map<String, Object> getInternal(String endpoint) throws IOException {
var request = new Request("GET", endpoint);
var response = client().performRequest(request);
assertOkOrCreated(response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasSize;
Expand All @@ -31,22 +30,22 @@ public void testGet() throws IOException {
putModel("te_model_" + i, mockSparseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
}

var getAllModels = (List<Map<String, Object>>) getAllModels().get("models");
var getAllModels = getAllModels();
assertThat(getAllModels, hasSize(9));

var getSparseModels = (List<Map<String, Object>>) getModels("_all", TaskType.SPARSE_EMBEDDING).get("models");
var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
assertThat(getSparseModels, hasSize(5));
for (var sparseModel : getSparseModels) {
assertEquals("sparse_embedding", sparseModel.get("task_type"));
}

var getDenseModels = (List<Map<String, Object>>) getModels("_all", TaskType.TEXT_EMBEDDING).get("models");
var getDenseModels = getModels("_all", TaskType.TEXT_EMBEDDING);
assertThat(getDenseModels, hasSize(4));
for (var denseModel : getDenseModels) {
assertEquals("text_embedding", denseModel.get("task_type"));
}

var singleModel = (List<Map<String, Object>>) getModels("se_model_1", TaskType.SPARSE_EMBEDDING).get("models");
var singleModel = getModels("se_model_1", TaskType.SPARSE_EMBEDDING);
assertThat(singleModel, hasSize(1));
assertEquals("se_model_1", singleModel.get(0).get("model_id"));

Expand All @@ -63,7 +62,7 @@ public void testGetModelWithWrongTaskType() throws IOException {
var e = expectThrows(ResponseException.class, () -> getModels("sparse_embedding_model", TaskType.TEXT_EMBEDDING));
assertThat(
e.getMessage(),
containsString("Requested task type [text_embedding] does not match the model's task type [sparse_embedding]")
containsString("Requested task type [text_embedding] does not match the inference endpoint's task type [sparse_embedding]")
);
}

Expand All @@ -72,15 +71,15 @@ public void testDeleteModelWithWrongTaskType() throws IOException {
var e = expectThrows(ResponseException.class, () -> deleteModel("sparse_embedding_model", TaskType.TEXT_EMBEDDING));
assertThat(
e.getMessage(),
containsString("Requested task type [text_embedding] does not match the model's task type [sparse_embedding]")
containsString("Requested task type [text_embedding] does not match the inference endpoint's task type [sparse_embedding]")
);
}

@SuppressWarnings("unchecked")
public void testGetModelWithAnyTaskType() throws IOException {
String inferenceEntityId = "sparse_embedding_model";
putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var singleModel = (List<Map<String, Object>>) getModels(inferenceEntityId, TaskType.ANY).get("models");
var singleModel = getModels(inferenceEntityId, TaskType.ANY);
assertEquals(inferenceEntityId, singleModel.get(0).get("model_id"));
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get(0).get("task_type"));
}
Expand All @@ -89,9 +88,9 @@ public void testGetModelWithAnyTaskType() throws IOException {
public void testApisWithoutTaskType() throws IOException {
String modelId = "no_task_type_in_url";
putModel(modelId, mockSparseServiceModelConfig(TaskType.SPARSE_EMBEDDING));
var singleModel = (List<Map<String, Object>>) getModel(modelId).get("models");
assertEquals(modelId, singleModel.get(0).get("model_id"));
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get(0).get("task_type"));
var singleModel = getModel(modelId);
assertEquals(modelId, singleModel.get("model_id"));
assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type"));

var inference = inferOnMockService(modelId, List.of(randomAlphaOfLength(10)));
assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ public class MockDenseInferenceServiceIT extends InferenceBaseRestTest {
public void testMockService() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
var getModels = getModels(inferenceEntityId, TaskType.TEXT_EMBEDDING);
var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);
var model = getModels(inferenceEntityId, TaskType.TEXT_EMBEDDING).get(0);

for (var modelMap : List.of(putModel, model)) {
assertEquals(inferenceEntityId, modelMap.get("model_id"));
Expand Down Expand Up @@ -51,8 +50,7 @@ public void testMockServiceWithMultipleInputs() throws IOException {
public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
var getModels = getModels(inferenceEntityId, TaskType.TEXT_EMBEDDING);
var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);
var model = getModels(inferenceEntityId, TaskType.TEXT_EMBEDDING).get(0);

var serviceSettings = (Map<String, Object>) model.get("service_settings");
assertNull(serviceSettings.get("api_key"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ public class MockSparseInferenceServiceIT extends InferenceBaseRestTest {
public void testMockService() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);
var model = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING).get(0);

for (var modelMap : List.of(putModel, model)) {
assertEquals(inferenceEntityId, modelMap.get("model_id"));
Expand Down Expand Up @@ -53,8 +52,7 @@ public void testMockServiceWithMultipleInputs() throws IOException {
public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);
var model = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING).get(0);

var serviceSettings = (Map<String, Object>) model.get("service_settings");
assertNull(serviceSettings.get("api_key"));
Expand All @@ -69,8 +67,7 @@ public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOExcepti
public void testMockService_DoesNotReturnHiddenField_InModelResponses() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);
var model = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING).get(0);

for (var modelMap : List.of(putModel, model)) {
assertEquals(inferenceEntityId, modelMap.get("model_id"));
Expand All @@ -88,8 +85,7 @@ public void testMockService_DoesNotReturnHiddenField_InModelResponses() throws I
public void testMockService_DoesReturnHiddenField_InModelResponses() throws IOException {
String inferenceEntityId = "test-mock";
var putModel = putModel(inferenceEntityId, mockSparseServiceModelConfig(null, true), TaskType.SPARSE_EMBEDDING);
var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);
var model = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING).get(0);

for (var modelMap : List.of(putModel, model)) {
assertEquals(inferenceEntityId, modelMap.get("model_id"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public void testStoreModelWithUnknownFields() throws Exception {
statusException.getRootCause().getMessage(),
containsString("mapping set to strict, dynamic introduction of [unknown_field] within [_doc] is not allowed")
);
assertThat(exceptionHolder.get().getMessage(), containsString("Failed to store inference model [" + inferenceEntityId + "]"));
assertThat(exceptionHolder.get().getMessage(), containsString("Failed to store inference endpoint [" + inferenceEntityId + "]"));
}

public void testGetModel() throws Exception {
Expand Down Expand Up @@ -144,7 +144,7 @@ public void testStoreModelFailsWhenModelExists() throws Exception {
assertThat(exceptionHolder.get(), not(nullValue()));
assertThat(
exceptionHolder.get().getMessage(),
containsString("Inference model [test-put-trained-model-config-exists] already exists")
containsString("Inference endpoint [test-put-trained-model-config-exists] already exists")
);
}

Expand All @@ -171,7 +171,7 @@ public void testDeleteModel() throws Exception {

assertThat(exceptionHolder.get(), not(nullValue()));
assertFalse(deleteResponseHolder.get());
assertThat(exceptionHolder.get().getMessage(), containsString("Model not found [model1]"));
assertThat(exceptionHolder.get().getMessage(), containsString("Inference endpoint not found [model1]"));
}

public void testGetModelsByTaskType() throws InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.DeleteInferenceModelAction;
import org.elasticsearch.xpack.inference.common.InferenceExceptions;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;

public class TransportDeleteInferenceModelAction extends AcknowledgedTransportMasterNodeAction<DeleteInferenceModelAction.Request> {
Expand Down Expand Up @@ -70,14 +71,7 @@ protected void masterOperation(

if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
// specific task type in request does not match the models
l1.onFailure(
new ElasticsearchStatusException(
"Requested task type [{}] does not match the model's task type [{}]",
RestStatus.BAD_REQUEST,
request.getTaskType(),
unparsedModel.taskType()
)
);
l1.onFailure(InferenceExceptions.mismatchedTaskTypeException(request.getTaskType(), unparsedModel.taskType()));
return;
}
var service = serviceRegistry.getService(unparsedModel.service());
Expand Down
Loading

0 comments on commit 53252c6

Please sign in to comment.