Skip to content

Commit

Permalink
[Inference API] Prevent inference endpoints from being deleted if the…
Browse files Browse the repository at this point in the history
…y are referenced by semantic text (#110399)

Following on #109123
  • Loading branch information
maxhniebergall authored Jul 3, 2024
1 parent b6e9860 commit 89a1bd9
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 52 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/110399.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 110399
summary: "[Inference API] Prevent inference endpoints from being deleted if they are\
\ referenced by semantic text"
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ static TransportVersion def(int id) {
public static final TransportVersion TEXT_SIMILARITY_RERANKER_RETRIEVER = def(8_699_00_0);
public static final TransportVersion ML_INFERENCE_GOOGLE_VERTEX_AI_RERANKING_ADDED = def(8_700_00_0);
public static final TransportVersion VERSIONED_MASTER_NODE_REQUESTS = def(8_701_00_0);
public static final TransportVersion ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS = def(8_702_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.XContentBuilder;

Expand Down Expand Up @@ -105,10 +106,16 @@ public static class Response extends AcknowledgedResponse {

private final String PIPELINE_IDS = "pipelines";
Set<String> pipelineIds;
private final String REFERENCED_INDEXES = "indexes";
Set<String> indexes;
private final String DRY_RUN_MESSAGE = "error_message"; // error message only returned in response for dry_run
String dryRunMessage;

public Response(boolean acknowledged, Set<String> pipelineIds) {
public Response(boolean acknowledged, Set<String> pipelineIds, Set<String> semanticTextIndexes, @Nullable String dryRunMessage) {
super(acknowledged);
this.pipelineIds = pipelineIds;
this.indexes = semanticTextIndexes;
this.dryRunMessage = dryRunMessage;
}

public Response(StreamInput in) throws IOException {
Expand All @@ -118,6 +125,15 @@ public Response(StreamInput in) throws IOException {
} else {
pipelineIds = Set.of();
}

if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS)) {
indexes = in.readCollectionAsSet(StreamInput::readString);
dryRunMessage = in.readOptionalString();
} else {
indexes = Set.of();
dryRunMessage = null;
}

}

@Override
Expand All @@ -126,12 +142,18 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_ENHANCE_DELETE_ENDPOINT)) {
out.writeCollection(pipelineIds, StreamOutput::writeString);
}
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS)) {
out.writeCollection(indexes, StreamOutput::writeString);
out.writeOptionalString(dryRunMessage);
}
}

@Override
protected void addCustomFields(XContentBuilder builder, Params params) throws IOException {
super.addCustomFields(builder, params);
builder.field(PIPELINE_IDS, pipelineIds);
builder.field(REFERENCED_INDEXES, indexes);
builder.field(DRY_RUN_MESSAGE, dryRunMessage);
}

@Override
Expand All @@ -142,6 +164,11 @@ public String toString() {
for (String entry : pipelineIds) {
returnable.append(entry).append(", ");
}
returnable.append(", semanticTextFieldsByIndex: ");
for (String entry : indexes) {
returnable.append(entry).append(", ");
}
returnable.append(", dryRunMessage: ").append(dryRunMessage);
return returnable.toString();
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*
* this file was contributed to by a Generative AI
*/

package org.elasticsearch.xpack.core.ml.utils;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.cluster.metadata.Metadata;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Static helper that finds which indices reference a given set of inference endpoints
 * through their inference fields.
 */
public class SemanticTextInfoExtractor {
    private static final Logger logger = LogManager.getLogger(SemanticTextInfoExtractor.class);

    /**
     * Collects the names of every index in {@code metadata} that has at least one inference field
     * whose inference id is contained in {@code endpointIds}.
     *
     * @param metadata    the metadata whose indices are inspected; asserted to be non-null
     * @param endpointIds the inference endpoint ids to look for; asserted to be non-empty
     * @return a newly created, mutable set of matching index names (empty when nothing matches)
     */
    public static Set<String> extractIndexesReferencingInferenceEndpoints(Metadata metadata, Set<String> endpointIds) {
        assert endpointIds.isEmpty() == false;
        assert metadata != null;

        Set<String> matchingIndexNames = new HashSet<>();

        for (Map.Entry<String, IndexMetadata> indexEntry : metadata.indices().entrySet()) {
            Map<String, InferenceFieldMetadata> fieldsOfIndex = indexEntry.getValue().getInferenceFields();
            if (fieldsOfIndex == null) {
                continue;
            }
            for (InferenceFieldMetadata field : fieldsOfIndex.values()) {
                String inferenceId = field.getInferenceId();
                // The null check comes first so that contains() is never called with null.
                if (inferenceId != null && endpointIds.contains(inferenceId)) {
                    matchingIndexNames.add(indexEntry.getKey());
                    // One matching field is enough to record the index.
                    break;
                }
            }
        }

        return matchingIndexNames;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,25 @@ protected void deleteModel(String modelId, TaskType taskType) throws IOException
assertOkOrCreated(response);
}

/**
 * Sends a {@code PUT} request to the path {@code indexName} whose JSON body declares a mapping with a
 * single {@code semantic_text} property named {@code inference_field} pointing at {@code endpointId},
 * then checks the response with {@code assertOkOrCreated}.
 *
 * @param endpointId the inference id written into the field mapping
 * @param indexName  the request path, i.e. the name of the index to create
 * @throws IOException if performing the request fails
 */
protected void putSemanticText(String endpointId, String indexName) throws IOException {
    // Content lines and the closing delimiter share one indent, so the resulting string has no leading whitespace.
    String mappingJson = Strings.format("""
        {
        "mappings": {
        "properties": {
        "inference_field": {
        "type": "semantic_text",
        "inference_id": "%s"
        }
        }
        }
        }
        """, endpointId);

    var createIndexRequest = new Request("PUT", Strings.format("%s", indexName));
    createIndexRequest.setJsonEntity(mappingJson);
    assertOkOrCreated(client().performRequest(createIndexRequest));
}

protected Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
String endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
return putRequest(endpoint, modelConfig);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Set;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasSize;
Expand Down Expand Up @@ -124,14 +125,15 @@ public void testDeleteEndpointWhileReferencedByPipeline() throws IOException {
putPipeline(pipelineId, endpointId);

{
var errorString = new StringBuilder().append("Inference endpoint ")
.append(endpointId)
.append(" is referenced by pipelines: ")
.append(Set.of(pipelineId))
.append(". ")
.append("Ensure that no pipelines are using this inference endpoint, ")
.append("or use force to ignore this warning and delete the inference endpoint.");
var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId));
assertThat(
e.getMessage(),
containsString(
"Inference endpoint endpoint_referenced_by_pipeline is referenced by pipelines and cannot be deleted. "
+ "Use `force` to delete it anyway, or use `dry_run` to list the pipelines that reference it."
)
);
assertThat(e.getMessage(), containsString(errorString.toString()));
}
{
var response = deleteModel(endpointId, "dry_run=true");
Expand All @@ -146,4 +148,76 @@ public void testDeleteEndpointWhileReferencedByPipeline() throws IOException {
}
deletePipeline(pipelineId);
}

/**
 * Checks the delete behaviour for an inference endpoint that is referenced by a {@code semantic_text}
 * mapping: a plain delete fails with a message naming the referencing index, {@code dry_run=true}
 * returns an unacknowledged response that mentions the index, and {@code force=true} is acknowledged.
 */
public void testDeleteEndpointWhileReferencedBySemanticText() throws IOException {
    String endpointId = "endpoint_referenced_by_semantic_text";
    putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
    // Lower-case with Locale.ROOT so the generated index name does not depend on the JVM default locale
    // (e.g. a Turkish locale would map 'I' to the dotless 'ı').
    String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
    putSemanticText(endpointId, indexName);
    {
        // Expected fragment of the error returned when the delete is rejected.
        var errorString = new StringBuilder().append(" Inference endpoint ")
            .append(endpointId)
            .append(" is being used in the mapping for indexes: ")
            .append(Set.of(indexName))
            .append(". ")
            .append("Ensure that no index mappings are using this inference endpoint, ")
            .append("or use force to ignore this warning and delete the inference endpoint.");
        var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId));
        assertThat(e.getMessage(), containsString(errorString.toString()));
    }
    {
        // A dry run must not acknowledge the delete, but must report the referencing index.
        var response = deleteModel(endpointId, "dry_run=true");
        var entityString = EntityUtils.toString(response.getEntity());
        assertThat(entityString, containsString("\"acknowledged\":false"));
        assertThat(entityString, containsString(indexName));
    }
    {
        // Forcing the delete is acknowledged despite the reference.
        var response = deleteModel(endpointId, "force=true");
        var entityString = EntityUtils.toString(response.getEntity());
        assertThat(entityString, containsString("\"acknowledged\":true"));
    }
}

/**
 * Checks the delete behaviour for an inference endpoint that is referenced both by an ingest pipeline
 * and by a {@code semantic_text} mapping: a plain delete fails with a message that contains the pipeline
 * warning followed by the index warning, {@code dry_run=true} returns an unacknowledged response that
 * mentions both the index and the pipeline, and {@code force=true} is acknowledged.
 */
public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws IOException {
    String endpointId = "endpoint_referenced_by_semantic_text";
    putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
    // Lower-case with Locale.ROOT so the generated index name does not depend on the JVM default locale
    // (e.g. a Turkish locale would map 'I' to the dotless 'ı').
    String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT);
    putSemanticText(endpointId, indexName);
    var pipelineId = "pipeline_referencing_model";
    putPipeline(pipelineId, endpointId);
    {
        // Expected fragment of the error: the pipeline warning immediately followed by the index warning.
        var errorString = new StringBuilder().append("Inference endpoint ")
            .append(endpointId)
            .append(" is referenced by pipelines: ")
            .append(Set.of(pipelineId))
            .append(". ")
            .append("Ensure that no pipelines are using this inference endpoint, ")
            .append("or use force to ignore this warning and delete the inference endpoint.")
            .append(" Inference endpoint ")
            .append(endpointId)
            .append(" is being used in the mapping for indexes: ")
            .append(Set.of(indexName))
            .append(". ")
            .append("Ensure that no index mappings are using this inference endpoint, ")
            .append("or use force to ignore this warning and delete the inference endpoint.");

        var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId));
        assertThat(e.getMessage(), containsString(errorString.toString()));
    }
    {
        // A dry run must not acknowledge the delete, but must report both kinds of reference.
        var response = deleteModel(endpointId, "dry_run=true");
        var entityString = EntityUtils.toString(response.getEntity());
        assertThat(entityString, containsString("\"acknowledged\":false"));
        assertThat(entityString, containsString(indexName));
        assertThat(entityString, containsString(pipelineId));
    }
    {
        // Forcing the delete is acknowledged despite the references.
        var response = deleteModel(endpointId, "force=true");
        var entityString = EntityUtils.toString(response.getEntity());
        assertThat(entityString, containsString("\"acknowledged\":true"));
    }
    deletePipeline(pipelineId);
}
}
Loading

0 comments on commit 89a1bd9

Please sign in to comment.