Skip to content

Commit

Permalink
I18N-1336 - Improve AI checker perf
Browse files Browse the repository at this point in the history
Strings are checked concurrently
  • Loading branch information
DarKhaos committed Nov 8, 2024
1 parent 191ffee commit e87971f
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -101,15 +103,46 @@ public AICheckResponse executeAIChecks(AICheckRequest aiCheckRequest) {
textUnit -> textUnit,
(existing, replacement) -> existing,
HashMap::new));
Map<String, List<AICheckResult>> results = new HashMap<>();

Map<String, CompletableFuture<List<AICheckResult>>> checkResultsBySource = new HashMap<>();
textUnitsUniqueSource
.values()
.forEach(
textUnit -> {
List<AICheckResult> aiCheckResults = checkString(textUnit, prompts, repository);
results.put(textUnit.getSource(), aiCheckResults);
CompletableFuture<List<AICheckResult>> checkResultFuture =
CompletableFuture.supplyAsync(() -> checkString(textUnit, prompts, repository));
checkResultsBySource.put(textUnit.getSource(), checkResultFuture);
});

CompletableFuture<Void> combinedCheckResults =
CompletableFuture.allOf(checkResultsBySource.values().toArray(new CompletableFuture[0]));

CompletableFuture<Map<String, List<AICheckResult>>> checkResults =
combinedCheckResults.thenApply(
v -> {
Map<String, List<AICheckResult>> allResultsList = new HashMap<>();
for (Map.Entry<String, CompletableFuture<List<AICheckResult>>>
checkResultFutureBySource : checkResultsBySource.entrySet()) {
try {
allResultsList.put(
checkResultFutureBySource.getKey(),
checkResultFutureBySource.getValue().get());
} catch (InterruptedException | ExecutionException e) {
logger.error("Error while running a completable future", e);
}
}
return allResultsList;
});

Map<String, List<AICheckResult>> results = new HashMap<>();
checkResults.thenAccept(results::putAll);

try {
checkResults.get();
} catch (InterruptedException | ExecutionException e) {
logger.error("Error while running completable futures", e);
}

AICheckResponse aiCheckResponse = new AICheckResponse();
aiCheckResponse.setResults(results);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -830,4 +830,68 @@ void testPromptTemplatingInlineSentence() {
The comment is: A friendly greeting. The plural form is: one.""",
prompt);
}

@Test
void testExecuteAIChecksWithSleepTime() {
AIPrompt prompt = new AIPrompt();
prompt.setId(1L);
prompt.setUserPrompt("Check strings for spelling");
prompt.setModelName("gtp-3.5-turbo");
prompt.setPromptTemperature(0.0F);
prompt.setContextMessages(new ArrayList<>());
AssetExtractorTextUnit assetExtractorTextUnit = new AssetExtractorTextUnit();
assetExtractorTextUnit.setSource("A test string");
assetExtractorTextUnit.setName("A test string --- A test context");
assetExtractorTextUnit.setComments("A test comment");
AssetExtractorTextUnit assetExtractorTextUnit2 = new AssetExtractorTextUnit();
assetExtractorTextUnit2.setSource("A test string 2");
assetExtractorTextUnit2.setName("A test string --- A test context 2");
assetExtractorTextUnit2.setComments("A test comment 2");
List<AssetExtractorTextUnit> textUnits =
List.of(assetExtractorTextUnit, assetExtractorTextUnit2);
AICheckRequest aiCheckRequest = new AICheckRequest();
aiCheckRequest.setRepositoryName("testRepo");
aiCheckRequest.setTextUnits(textUnits);
List<OpenAIClient.ChatCompletionsResponse.Choice> choices =
List.of(
new OpenAIClient.ChatCompletionsResponse.Choice(
0,
new OpenAIClient.ChatCompletionsResponse.Choice.Message(
"test", "{\"success\": true, \"suggestedFix\": \"\"}"),
null));
Repository repository = new Repository();
repository.setName("testRepo");
repository.setId(1L);
OpenAIClient.ChatCompletionsResponse chatCompletionsResponse =
new OpenAIClient.ChatCompletionsResponse(null, null, null, null, choices, null, null);
CompletableFuture<OpenAIClient.ChatCompletionsResponse> futureResponse =
CompletableFuture.completedFuture(chatCompletionsResponse);
List<AIPrompt> prompts = List.of(prompt);

when(repositoryRepository.findByName("testRepo")).thenReturn(repository);
when(promptService.getPromptsByRepositoryAndPromptType(
repository, PromptType.SOURCE_STRING_CHECKER))
.thenReturn(prompts);
doAnswer(
invocation -> {
Thread.sleep(100);
return futureResponse;
})
.when(openAIClient)
.getChatCompletions(any(OpenAIClient.ChatCompletionsRequest.class));

AICheckResponse response = this.openAILLMService.executeAIChecks(aiCheckRequest);

assertNotNull(response);
assertEquals(2, response.getResults().size());
assertTrue(response.getResults().containsKey("A test string"));
assertTrue(response.getResults().get("A test string").getFirst().isSuccess());

assertTrue(response.getResults().containsKey("A test string 2"));
assertTrue(response.getResults().get("A test string 2").getFirst().isSuccess());

verify(aiStringCheckRepository, times(2)).save(any());
verify(meterRegistry, times(2))
.counter("OpenAILLMService.checks.result", "success", "true", "repository", "testRepo");
}
}

0 comments on commit e87971f

Please sign in to comment.