Skip to content

Commit

Permalink
Add more detailed error messages for KNN model training
Browse files Browse the repository at this point in the history
Signed-off-by: AnnTian Shao <anntians@amazon.com>
  • Loading branch information
AnnTian Shao committed Jan 9, 2025
1 parent dc369e6 commit ef0e7e5
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Introduced a writing layer in native engines that relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290)[https://github.com/opensearch-project/k-NN/pull/2290]
- Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305]
- Added more detailed error messages for KNN model training (#2378)[https://github.com/opensearch-project/k-NN/pull/2378]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
* Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.common.ValidationException;
import org.opensearch.common.inject.Inject;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportRequestOptions;
Expand All @@ -31,6 +32,9 @@
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.search.internal.SearchContext.DEFAULT_TERMINATE_AFTER;

/**
Expand Down Expand Up @@ -134,6 +138,30 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques
trainingVectors = trainingModelRequest.getMaximumVectorCount();
}

long minTrainingVectorCount = 1000;
MethodComponentContext encoderContext = (MethodComponentContext) trainingModelRequest.getKnnMethodContext()
.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER);

if (trainingModelRequest.getKnnMethodContext().getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST)
&& encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {

int nlist = ((Integer) trainingModelRequest.getKnnMethodContext()
.getMethodComponentContext()
.getParameters()
.get(METHOD_PARAMETER_NLIST));
int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size));
}

if (trainingVectors < minTrainingVectorCount) {
ValidationException exception = new ValidationException();
exception.addValidationError("Number of training points should be greater than " + minTrainingVectorCount);
listener.onFailure(exception);
return;
}

listener.onResponse(
estimateVectorSetSizeInKB(trainingVectors, trainingModelRequest.getDimension(), trainingModelRequest.getVectorDataType())
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@

import java.io.IOException;

import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;

/**
* Request to train and serialize a model
*/
Expand Down Expand Up @@ -283,6 +285,15 @@ public ActionRequestValidationException validate() {
exception.addValidationError("Description exceeds limit of " + KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH + " characters");
}

// Check that the vector dimension is divisible by ENCODER_PARAMETER_PQ_M
if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M)
&& knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(ENCODER_PARAMETER_PQ_M) != 0) {
exception = exception == null ? new ActionRequestValidationException() : exception;
exception.addValidationError("Training request ENCODER_PARAMETER_PQ_M is not divisible by vector dimensions");
}

// Validate training index exists
IndexMetadata indexMetadata = clusterService.state().metadata().index(trainingIndex);
if (indexMetadata == null) {
Expand Down
4 changes: 1 addition & 3 deletions src/main/java/org/opensearch/knn/training/TrainingJob.java
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,7 @@ public void run() {
} catch (Exception e) {
logger.error("Failed to run training job for model \"" + modelId + "\": ", e);
modelMetadata.setState(ModelState.FAILED);
modelMetadata.setError(
"Failed to execute training. May be caused by an invalid method definition or " + "not enough memory to perform training."
);
modelMetadata.setError("Failed to execute training. " + e.getMessage());

KNNCounter.TRAINING_ERRORS.increment();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.containsString;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -344,6 +345,55 @@ public void testTrainingIndexSize() {
transportAction.getTrainingIndexSizeInKB(trainingModelRequest, listener);
}

/**
 * Verifies that {@code getTrainingIndexSizeInKB} notifies the listener of a validation failure when the
 * mocked search reports only 100 training documents.
 *
 * The test fails if the listener receives a size instead of a failure, or if the listener is never
 * notified at all.
 */
public void testTrainingIndexSizeFailure() {

    String trainingIndexName = "training-index";
    int dimension = 133;
    int vectorCount = 100;

    // Setup the request
    TrainingModelRequest trainingModelRequest = new TrainingModelRequest(
        null,
        getDefaultKNNMethodContextForModel(),
        dimension,
        trainingIndexName,
        "training-field",
        null,
        "description",
        VectorDataType.DEFAULT,
        Mode.NOT_CONFIGURED,
        CompressionLevel.NOT_CONFIGURED
    );

    // Mock client so that the search reports vectorCount total hits
    TotalHits totalHits = new TotalHits(vectorCount, TotalHits.Relation.EQUAL_TO);
    SearchHits searchHits = new SearchHits(new SearchHit[2], totalHits, 1.0f);
    SearchResponse searchResponse = mock(SearchResponse.class);
    when(searchResponse.getHits()).thenReturn(searchHits);
    Client client = mock(Client.class);
    doAnswer(invocationOnMock -> {
        ((ActionListener<SearchResponse>) invocationOnMock.getArguments()[1]).onResponse(searchResponse);
        return null;
    }).when(client).search(any(), any());

    // Setup the action
    ClusterService clusterService = mock(ClusterService.class);
    TransportService transportService = mock(TransportService.class);
    TrainingJobRouterTransportAction transportAction = new TrainingJobRouterTransportAction(
        transportService,
        new ActionFilters(Collections.emptySet()),
        clusterService,
        client
    );

    // Single-element array so the failure callback can record that it ran.
    boolean[] failureReported = { false };

    ActionListener<Integer> listener = ActionListener.wrap(size -> {
        // A successful size estimate means the validation was skipped; AssertionError is an Error,
        // so it is not expected to be rerouted to the failure callback.
        throw new AssertionError("Expected a validation failure but received training index size " + size);
    }, e -> {
        failureReported[0] = true;
        assertThat(e.getMessage(), containsString("Number of training points should be greater than"));
    });

    // The mocked client answers the search inline, so the listener has been notified once this returns.
    transportAction.getTrainingIndexSizeInKB(trainingModelRequest, listener);

    if (failureReported[0] == false) {
        throw new AssertionError("Listener was never notified of the expected validation failure");
    }
}

public void testTrainIndexSize_whenDataTypeIsBinary() {
String trainingIndexName = "training-index";
int dimension = 8;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -621,11 +621,61 @@ public void testValidation_invalid_descriptionToLong() {
ActionRequestValidationException exception = trainingModelRequest.validate();
assertNotNull(exception);
List<String> validationErrors = exception.validationErrors();
logger.error("Validation errorsa " + validationErrors);
logger.error("Validation errors " + validationErrors);
assertEquals(1, validationErrors.size());
assertTrue(validationErrors.get(0).contains("Description exceeds limit"));
}

/**
 * Verifies that {@code validate()} reports an error when the vector dimension (10) is not evenly
 * divisible by the "m" parameter (3) of the method definition.
 */
public void testValidation_invalid_mNotDivisibleByDimension() {

    // Setup the training request
    String modelId = "test-model-id";
    int dimension = 10;
    String trainingIndex = "test-training-index";
    String trainingField = "test-training-field";
    String trainingFieldModeId = "training-field-model-id";

    // 10 % 3 != 0, so the divisibility check must fail
    Map<String, Object> parameters = Map.of("m", 3);

    MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, parameters);
    final KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.DEFAULT, methodComponentContext);

    TrainingModelRequest trainingModelRequest = new TrainingModelRequest(
        modelId,
        knnMethodContext,
        dimension,
        trainingIndex,
        trainingField,
        null,
        null,
        VectorDataType.DEFAULT,
        Mode.NOT_CONFIGURED,
        CompressionLevel.NOT_CONFIGURED
    );

    // Mock the model dao so that no existing model is found for modelId (i.e. it is not a duplicate),
    // while the training field model id resolves to metadata with the same dimension
    ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class);
    when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);

    ModelDao modelDao = mock(ModelDao.class);
    when(modelDao.getMetadata(modelId)).thenReturn(null);
    when(modelDao.getMetadata(trainingFieldModeId)).thenReturn(trainingFieldModelMetadata);

    // Cluster service that won't produce a validation exception
    ClusterService clusterService = getClusterServiceForValidReturns(trainingIndex, trainingField, dimension);

    // Initialize static components with the mocks
    TrainingModelRequest.initialize(modelDao, clusterService);

    // Test that validation produces the "m not divisible" error message
    ActionRequestValidationException exception = trainingModelRequest.validate();
    assertNotNull(exception);
    List<String> validationErrors = exception.validationErrors();

    // NOTE(review): two errors are expected; the second one presumably comes from the method
    // definition used in this request - confirm against validate().
    assertEquals("Unexpected validation errors: " + validationErrors, 2, validationErrors.size());

    // Match on content rather than position so the test does not depend on validation order
    assertTrue(
        "Missing ENCODER_PARAMETER_PQ_M validation error in: " + validationErrors,
        validationErrors.stream().anyMatch(error -> error.contains("Training request ENCODER_PARAMETER_PQ_M"))
    );
}

public void testValidation_valid_trainingIndexBuiltFromMethod() {
// This cluster service will result in no validation exceptions

Expand Down
94 changes: 91 additions & 3 deletions src/test/java/org/opensearch/knn/training/TrainingJobTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import java.util.UUID;
import java.util.concurrent.ExecutionException;

import static org.hamcrest.Matchers.containsString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -217,7 +218,6 @@ public void testRun_success() throws IOException, ExecutionException {

Model model = trainingJob.getModel();
assertNotNull(model);

assertEquals(ModelState.CREATED, model.getModelMetadata().getState());

// Simple test that creates the index from template and doesn't fail
Expand Down Expand Up @@ -308,6 +308,10 @@ public void testRun_failure_onGetTrainingDataAllocation() throws ExecutionExcept

Model model = trainingJob.getModel();
assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState());
assertThat(
"Failed to load training data into memory. " + "Check if there is enough memory to perform the request.",
containsString(trainingJob.getModel().getModelMetadata().getError())
);
assertNotNull(model);
assertFalse(model.getModelMetadata().getError().isEmpty());
}
Expand Down Expand Up @@ -382,6 +386,10 @@ public void testRun_failure_onGetModelAnonymousAllocation() throws ExecutionExce

Model model = trainingJob.getModel();
assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState());
assertThat(
"Failed to allocate space in native memory for the model. " + "Check if there is enough memory to perform the request.",
containsString(trainingJob.getModel().getModelMetadata().getError())
);
assertNotNull(model);
assertFalse(model.getModelMetadata().getError().isEmpty());
}
Expand Down Expand Up @@ -435,15 +443,91 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep
when(nativeMemoryAllocation.isClosed()).thenReturn(true);
when(nativeMemoryAllocation.getMemoryAddress()).thenReturn((long) 0);

// Throw error on getting data
// Throw error on allocation is closed
when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation);

TrainingJob trainingJob = new TrainingJob(
modelId,
knnMethodContext,
nativeMemoryCacheManager,
trainingDataEntryContext,
mock(NativeMemoryEntryContext.AnonymousEntryContext.class),
modelContext,
knnMethodConfigContext,
"",
"test-node",
Mode.NOT_CONFIGURED,
CompressionLevel.NOT_CONFIGURED
);

trainingJob.run();

Model model = trainingJob.getModel();
assertThat(
"Failed to execute training. Unable to load training data into memory: allocation is already closed",
containsString(trainingJob.getModel().getModelMetadata().getError())
);
assertNotNull(model);
assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState());
}

public void testRun_failure_closedModelAnonymousAllocation() throws ExecutionException {
// In this test, the model anonymous allocation should be closed. Then, run should fail and update the error of
// the model
String modelId = "test-model-id";

// Define the method setup for method that requires training
int nlists = 5;
int dimension = 16;
KNNEngine knnEngine = KNNEngine.FAISS;
KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder()
.vectorDataType(VectorDataType.FLOAT)
.dimension(dimension)
.versionCreated(Version.CURRENT)
.build();
KNNMethodContext knnMethodContext = new KNNMethodContext(
knnEngine,
SpaceType.INNER_PRODUCT,
new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists))
);

String tdataKey = "t-data-key";
NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = mock(
NativeMemoryEntryContext.TrainingDataEntryContext.class
);
when(trainingDataEntryContext.getKey()).thenReturn(tdataKey);

// Setup model manager
NativeMemoryCacheManager nativeMemoryCacheManager = mock(NativeMemoryCacheManager.class);

// Setup mock allocation for model that's closed
NativeMemoryAllocation modelAllocation = mock(NativeMemoryAllocation.class);
doAnswer(invocationOnMock -> null).when(modelAllocation).readLock();
doAnswer(invocationOnMock -> null).when(modelAllocation).readUnlock();
when(modelAllocation.isClosed()).thenReturn(true);

String modelKey = "model-test-key";
NativeMemoryEntryContext.AnonymousEntryContext modelContext = mock(NativeMemoryEntryContext.AnonymousEntryContext.class);
when(modelContext.getKey()).thenReturn(modelKey);

// Throw error on allocation is closed
when(nativeMemoryCacheManager.get(modelContext, false)).thenReturn(modelAllocation);
doAnswer(invocationOnMock -> null).when(nativeMemoryCacheManager).invalidate(modelKey);

// Setup mock allocation that's not closed
NativeMemoryAllocation nativeMemoryAllocation = mock(NativeMemoryAllocation.class);
doAnswer(invocationOnMock -> null).when(nativeMemoryAllocation).readLock();
doAnswer(invocationOnMock -> null).when(nativeMemoryAllocation).readUnlock();
when(nativeMemoryAllocation.isClosed()).thenReturn(false);
when(nativeMemoryAllocation.getMemoryAddress()).thenReturn((long) 0);

when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation);

TrainingJob trainingJob = new TrainingJob(
modelId,
knnMethodContext,
nativeMemoryCacheManager,
trainingDataEntryContext,
modelContext,
knnMethodConfigContext,
"",
"test-node",
Expand All @@ -454,6 +538,10 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep
trainingJob.run();

Model model = trainingJob.getModel();
assertThat(
"Failed to execute training. Unable to reserve memory for model: allocation is already closed",
containsString(trainingJob.getModel().getModelMetadata().getError())
);
assertNotNull(model);
assertEquals(ModelState.FAILED, trainingJob.getModel().getModelMetadata().getState());
}
Expand Down

0 comments on commit ef0e7e5

Please sign in to comment.